nvedant07 committed (verified)
Commit 0ba5e52 · 1 Parent(s): 857cd41

minor bugfix

Files changed (1):
  1. model.py +207 -51
model.py CHANGED
@@ -37,9 +37,10 @@ def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
     return torch.argmax(logits, dim=-1)[:, -1]
 
 
-LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
-You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>
-{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
+LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n
+You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n
+{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
+
 
 
 class HATCache(Cache):
@@ -488,6 +489,7 @@ class HATEncoderConnector(nn.Module):
             device=self.latent_query.device,
             dtype=torch.int32,
         )
+
         word_embeddings = self.cross_attention_encoder_connector.forward(
             q_activations=latent_query_repeated,
             kv_activations=hidden_states,
@@ -607,7 +609,7 @@ class HATForCausalLM(PreTrainedModel):
         backbone_past_key_values = past_key_values.get_backbone_cache() if past_key_values is not None else None
         decoder_past_key_values = past_key_values.get_decoder_cache() if past_key_values is not None else None
 
-        encoder_output: BaseModelOutputWithPast = self.encoder(
+        encoder_output: BaseModelOutputWithPast = self.encoder.forward(
             input_ids=input_ids,
             cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
             byte_position_ids=byte_position_ids,
@@ -617,13 +619,13 @@ class HATForCausalLM(PreTrainedModel):
         )
         byte_level_activations = encoder_output.hidden_states
 
-        encoder_connector_output = self.encoder_connector(
+        encoder_connector_output = self.encoder_connector.forward(
             byte_level_activations,
             cumulative_seq_lengths_per_word,
             word_position_ids,
             byte_position_ids,
         )
-        backbone_output: CausalLMOutputWithPast = self.backbone(
+        backbone_output: CausalLMOutputWithPast = self.backbone.forward(
             hidden_states=encoder_connector_output,
             position_ids=word_position_ids,
             past_key_values=backbone_past_key_values,
@@ -658,7 +660,7 @@ class HATForCausalLM(PreTrainedModel):
     def _append_byte(self, words: list[list[int]], token: int) -> list[list[int]]:
         extended_last_word = words.pop() + [token]
         try:
-            text = self.splitter.decode(extended_last_word, errors='strict', skip_special_tokens=False)
+            text = self.splitter.decode(extended_last_word, errors="strict", skip_special_tokens=False)
             list_of_bytes = self.splitter.encode(text)
             words.extend([list(word_in_bytes) for word_in_bytes in list_of_bytes])
         except UnicodeDecodeError:
@@ -667,20 +669,70 @@ class HATForCausalLM(PreTrainedModel):
             words.append(extended_last_word)
         return words
 
+    def _split_encoder_activations(
+        self,
+        byte_encoder_activations: torch.Tensor,
+        words: list[list[int]],
+        previous_encoder_activations: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Split encoder activations between the first word and the next word.
+
+        Args:
+            byte_encoder_activations: Tensor of shape [batch_size, seq_len, hidden_size] with all encoder activations computed in the current iteration.
+            words: List of word byte sequences completed in the previous and current iterations.
+            previous_encoder_activations: Optional tensor of shape [batch_size, prev_seq_len, hidden_size] with precomputed activations from the previous iteration.
+
+        Returns:
+            tuple containing:
+                - first_word_encoder_activations: Tensor of shape [batch_size, first_word_len, hidden_size]
+                - next_word_encoder_activations: Tensor of shape [batch_size, remaining_len, hidden_size], or None
+        """
+
+        assert sum(len(word) for word in words) - 1 == byte_encoder_activations.shape[1] + (previous_encoder_activations.shape[1] if previous_encoder_activations is not None else 0), "Total number of bytes in words minus 1 must match the combined length of byte_encoder_activations and previous_encoder_activations"
+
+        next_word_encoder_activations = None
+        if previous_encoder_activations is not None:
+            # The first word's encoder activations were already partially precomputed in the previous iteration
+            new_bytes_of_first_word = len(words[0]) - previous_encoder_activations.shape[1]
+            # Concatenate the precomputed activations with the new activations that still belong to the first word
+            first_word_encoder_activations = torch.cat([previous_encoder_activations, byte_encoder_activations[:, :new_bytes_of_first_word]], dim=1)
+            if len(words[1]) > 1:
+                # The remaining activations belong to the next word
+                next_word_encoder_activations = byte_encoder_activations[:, new_bytes_of_first_word:]
+            else:
+                next_word_encoder_activations = None
+        else:
+            # No activations for the first word were precomputed previously
+            first_word_encoder_activations = byte_encoder_activations[:, : len(words[0])]
+
+            if len(words[1]) > 1:
+                next_word_encoder_activations = byte_encoder_activations[:, len(words[0]) :]
+            else:
+                next_word_encoder_activations = None
+
+        return first_word_encoder_activations, next_word_encoder_activations
+
     def _complete_word(
         self,
         input_ids: torch.Tensor,
         byte_position_ids: torch.Tensor,
-        backbone_word_prediction: torch.Tensor,
+        predictive_word_embeddings: torch.Tensor,
         word_position_id: torch.Tensor,
         encoder_cache: DynamicCache,
         decoder_cache: DynamicCache,
         sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
+        previous_encoder_activations: torch.Tensor | None = None,
     ):
         """Generate byte tokens until we hit the first byte of a new word."""
-        words = [input_ids.squeeze(0).tolist()]
-        byte_encoder_activations = []
-        completion_logits = []
+        words: list[list[int]] = [input_ids.squeeze(0).tolist()]
+        byte_encoder_activations: list[torch.Tensor] = []
+        completion_logits: list[torch.Tensor] = []
+
+        if previous_encoder_activations is not None:
+            # We need to pass all inputs to get the correct encoding/decoding from the splitter,
+            # but only the last byte is used for generation,
+            # since the cache is already populated with the first word's activations.
+            input_ids = input_ids[:, -1:]
 
         while True:
             encoder_output = self.encoder.forward(
@@ -692,7 +744,7 @@ class HATForCausalLM(PreTrainedModel):
             )
             byte_encoder_activations.append(encoder_output.hidden_states)
             decoder_output = self.decoder.forward(
-                backbone_word_prediction,
+                predictive_word_embeddings,
                 encoder_output.hidden_states,
                 byte_position_ids=None,
                 word_position_ids=word_position_id,
@@ -705,22 +757,112 @@ class HATForCausalLM(PreTrainedModel):
             next_byte = int(sample_fn(logits).item())
             words = self._append_byte(words, next_byte)
             if len(words) > 1 or next_byte == self.eos_token_id:
+                byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
+                first_word_encoder_activations, next_word_encoder_activations = self._split_encoder_activations(
+                    byte_encoder_activations,
+                    words,
+                    previous_encoder_activations,
+                )
                 break
             input_ids = torch.tensor([[next_byte]], dtype=input_ids.dtype, device=input_ids.device)
 
-        byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
         num_kv = encoder_cache.get_seq_length()
-        byte_position_ids = torch.arange(num_kv + 1 - byte_encoder_activations.shape[1], num_kv + 1, device=input_ids.device, dtype=torch.long).unsqueeze(0)
+
+        completion = sum(words, [])[-len(completion_logits) :]
+        if next_word_encoder_activations is not None:
+            start_idx = num_kv - first_word_encoder_activations.shape[1] - next_word_encoder_activations.shape[1]
+            end_idx = num_kv - next_word_encoder_activations.shape[1]
+            # We do not want to return the logits for the second word when we hit the multi-byte starting-character case.
+            # When that happens we remove the logits, fix the decoder cache post hoc, and compute new logits.
+            # This breaks causality, but we want to imitate uncached generation/training behavior.
+            completion_logits = completion_logits[:-next_word_encoder_activations.shape[1]]
+        else:
+            start_idx = num_kv - first_word_encoder_activations.shape[1]
+            end_idx = num_kv
+
+        byte_position_ids = torch.arange(start_idx, end_idx, device=input_ids.device, dtype=torch.long).unsqueeze(0)
         completed_word_embedding = self.encoder_connector.forward(
-            byte_encoder_activations,
-            cumulative_seq_lengths_per_word=torch.tensor([0, byte_encoder_activations.size(1)], dtype=torch.int32, device=input_ids.device),
+            first_word_encoder_activations,
+            cumulative_seq_lengths_per_word=torch.tensor([0, first_word_encoder_activations.size(1)], dtype=torch.int32, device=input_ids.device),
             word_position_ids=word_position_id,
             byte_position_ids=byte_position_ids,
         )
 
-        completion = sum(words, [])[-len(completion_logits) :]
-        first_byte_of_next_word = words[1]
-        return completion, completed_word_embedding, first_byte_of_next_word, byte_position_ids[:, -1].item() + 1, completion_logits
+        bytes_of_next_word = words[1]
+
+        return (
+            completion,
+            completed_word_embedding,
+            bytes_of_next_word,
+            byte_position_ids[:, -1].item() + 1,
+            completion_logits,
+            next_word_encoder_activations,
+        )
+
+    def _populate_cache(
+        self,
+        input_ids: torch.Tensor,
+        cumulative_seq_lengths_per_word: torch.Tensor,
+        byte_position_ids: torch.Tensor,
+        word_position_ids: torch.Tensor,
+    ):
+        last_word_start = cumulative_seq_lengths_per_word[-2]
+        last_word_end = cumulative_seq_lengths_per_word[-1]
+
+        # Populate the cache with everything except the last word
+        initial_forward_output = self.forward(
+            input_ids=input_ids[:, :last_word_start],
+            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word[:-1],
+            byte_position_ids=byte_position_ids[:, :last_word_start],
+            word_position_ids=word_position_ids[:, :-1],
+            past_key_values=None,
+            use_cache=True,
+        )
+        return initial_forward_output, last_word_start, last_word_end
+
+    def _initialize_generation_state(
+        self,
+        input_ids: torch.Tensor,
+        max_new_tokens: int,
+        cumulative_seq_lengths_per_word: torch.Tensor,
+        byte_position_ids: torch.Tensor | None = None,
+        word_position_ids: torch.Tensor | None = None,
+    ):
+        max_total_bytes = max_new_tokens + input_ids.shape[1]
+        if byte_position_ids is None:
+            byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
+
+        if word_position_ids is None:
+            word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
+
+        initial_forward_output, last_word_start, last_word_end = self._populate_cache(
+            input_ids=input_ids,
+            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
+            byte_position_ids=byte_position_ids,
+            word_position_ids=word_position_ids,
+        )
+
+        completion_bytes: list[int] = []
+        completion_logits: list[torch.Tensor] = []
+        # Slice input_ids and byte_position_ids so they only contain the last word for the generation loop
+        current_input_ids = input_ids[:, last_word_start:last_word_end]
+        next_byte_id = last_word_end.item()  # Ensure this is an int
+        current_byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
+        current_word_position_id = word_position_ids[:, -1].unsqueeze(-1)
+        backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]
+        next_word_encoder_activations = None
+        return (
+            initial_forward_output,
+            completion_bytes,
+            completion_logits,
+            current_input_ids,
+            next_byte_id,
+            current_byte_position_ids,
+            current_word_position_id,
+            backbone_last_hidden_state,
+            next_word_encoder_activations,
+            max_total_bytes,
+        )
 
     def generate(
         self,
@@ -756,6 +898,20 @@ class HATForCausalLM(PreTrainedModel):
             completion_logits=completion_logits,
         )
 
+    def _fix_decoder_cache(self, predictive_word_embeddings: torch.Tensor, encoder_activations: torch.Tensor, decoder_cache: DynamicCache, word_position_id: torch.Tensor):
+        decoder_cache.crop(decoder_cache.get_seq_length() - encoder_activations.shape[1])
+        real_decoder_logits = self.decoder.forward(
+            predictive_word_embeddings,
+            encoder_activations,
+            byte_position_ids=None,
+            word_position_ids=word_position_id,
+            past_key_values=decoder_cache,
+        ).last_hidden_state
+
+        decoder_output = self.layer_norm(real_decoder_logits)
+        logits = self.lm_head(decoder_output)
+        return logits
+
     @torch.no_grad()
     def _generate_cached(
         self,
@@ -767,43 +923,35 @@ class HATForCausalLM(PreTrainedModel):
         sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
         stop_sequences: Sequence[str] | None = None,
     ):
-        max_total_bytes = max_new_tokens + input_ids.shape[1]
-        if byte_position_ids is None:
-            byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
-
-        if word_position_ids is None:
-            word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
-
-        last_word_start, last_word_end = (
-            cumulative_seq_lengths_per_word[-2],
-            cumulative_seq_lengths_per_word[-1],
-        )
-        # Populate cache with everything except last word
-        initial_forward_output = self.forward(
-            input_ids=input_ids[:, :last_word_start],
-            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word[:-1],
-            byte_position_ids=byte_position_ids[:, :last_word_start],
-            word_position_ids=word_position_ids[:, :-1],
-            past_key_values=None,
-            use_cache=True,
+        (
+            initial_forward_output,
+            completion_bytes,  # empty list
+            completion_logits,  # empty list
+            input_ids,  # now the sliced input_ids for the last word
+            next_byte_id,
+            byte_position_ids,  # now the sliced byte_position_ids for the last word
+            word_position_id,
+            backbone_last_hidden_state,
+            next_word_encoder_activations,  # None for the first iteration
+            max_total_bytes,
+        ) = self._initialize_generation_state(
+            input_ids=input_ids,
+            max_new_tokens=max_new_tokens,
+            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
+            byte_position_ids=byte_position_ids,
+            word_position_ids=word_position_ids,
         )
 
-        completion_bytes = []
-        completion_logits = []
-        input_ids = input_ids[:, last_word_start:last_word_end]
-        next_byte_id = last_word_end
-        byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
-        word_position_id = word_position_ids[:, -1].unsqueeze(-1)
-        backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]
         while next_byte_id < max_total_bytes:
-            completion, completed_word_embedding, first_byte_of_next_word, next_byte_id, next_completion_logits = self._complete_word(
+            completion, completed_word_embedding, bytes_of_next_word, next_byte_id, next_completion_logits, next_word_encoder_activations = self._complete_word(
                 input_ids=input_ids,
                 byte_position_ids=byte_position_ids,
-                backbone_word_prediction=backbone_last_hidden_state,
+                predictive_word_embeddings=backbone_last_hidden_state,
                 word_position_id=word_position_id,
                 encoder_cache=initial_forward_output.past_key_values.get_encoder_cache(),
                 decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
                 sample_fn=sample_fn,
+                previous_encoder_activations=next_word_encoder_activations,
            )
             completion_logits.extend(next_completion_logits)
             completion_bytes.extend(completion)
@@ -828,11 +976,19 @@ class HATForCausalLM(PreTrainedModel):
             )
             backbone_last_hidden_state = backbone_output.hidden_states[:, -1, :].unsqueeze(1)
 
-            input_ids = torch.tensor([first_byte_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
-            byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)
             word_position_id = word_position_id + 1
+            if len(bytes_of_next_word) > 1:
+                real_decoder_logits = self._fix_decoder_cache(
+                    predictive_word_embeddings=backbone_last_hidden_state,
+                    encoder_activations=next_word_encoder_activations,
+                    decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
+                    word_position_id=word_position_id,
+                )
+                completion_logits.extend(real_decoder_logits)
+
+            input_ids = torch.tensor([bytes_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
+            byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)
 
-        completion_bytes.extend(first_byte_of_next_word)
         completion_bytes = completion_bytes[:max_new_tokens]
         completion_logits = torch.cat(completion_logits[:max_new_tokens], dim=0)
         completion_text = self.splitter.decode(completion_bytes)
@@ -847,7 +1003,7 @@ class HATForCausalLM(PreTrainedModel):
         cumulative_seq_lengths_per_word: torch.Tensor,
         byte_position_ids: torch.Tensor | None = None,
         word_position_ids: torch.Tensor | None = None,
-        sample_fn=sample_argmax,
+        sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
         stop_sequences: Sequence[str] | None = None,
     ):
         if byte_position_ids is None:
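Note on the case this bugfix targets: the first byte of a multi-byte UTF-8 character cannot be decoded on its own, which is why _append_byte catches UnicodeDecodeError and why the new _split_encoder_activations / _fix_decoder_cache path carries the spill-over bytes into the next word. A minimal standalone sketch of that boundary condition (not part of model.py; the example word and variable names are illustrative only):

    # Illustration: a word ending in a 2-byte UTF-8 character (0xC3 0xA0)
    word_bytes = list("voilà".encode("utf-8"))
    partial = word_bytes[:-1]  # generation has emitted only the first byte of the final character

    try:
        bytes(partial).decode("utf-8", errors="strict")
    except UnicodeDecodeError:
        # Mirrors _append_byte: the bytes do not decode yet, so the splitter cannot
        # re-segment them into words. Generation keeps sampling bytes, and any bytes
        # that turn out to belong to the *next* word are handed back via
        # next_word_encoder_activations and reconciled by _fix_decoder_cache.
        print("incomplete multi-byte character - keep sampling")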