minor bugfix

model.py CHANGED
@@ -37,9 +37,10 @@ def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
     return torch.argmax(logits, dim=-1)[:, -1]


-LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id
-You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id
-{input}<|eot_id|><|start_header_id|>assistant<|end_header_id
+LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n
+You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n
+{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
+


 class HATCache(Cache):
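Note: the fixed template is an ordinary Python format string with a single {input} placeholder, so a prompt is presumably built with one format call. A minimal sketch, assuming the template is consumed as-is (the usage below is illustrative, not part of the diff):

    prompt = LLAMA_TEMPLATE.format(input="What does this model do?")
    # prompt now wraps the system text and user input in the Llama 3 chat
    # markers (<|begin_of_text|>, <|start_header_id|>, <|eot_id|>).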
@@ -488,6 +489,7 @@ class HATEncoderConnector(nn.Module):
             device=self.latent_query.device,
             dtype=torch.int32,
         )
+
         word_embeddings = self.cross_attention_encoder_connector.forward(
             q_activations=latent_query_repeated,
             kv_activations=hidden_states,
@@ -607,7 +609,7 @@ class HATForCausalLM(PreTrainedModel):
         backbone_past_key_values = past_key_values.get_backbone_cache() if past_key_values is not None else None
         decoder_past_key_values = past_key_values.get_decoder_cache() if past_key_values is not None else None

-        encoder_output: BaseModelOutputWithPast = self.encoder(
+        encoder_output: BaseModelOutputWithPast = self.encoder.forward(
             input_ids=input_ids,
             cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
             byte_position_ids=byte_position_ids,
@@ -617,13 +619,13 @@ class HATForCausalLM(PreTrainedModel):
         )
         byte_level_activations = encoder_output.hidden_states

-        encoder_connector_output = self.encoder_connector(
+        encoder_connector_output = self.encoder_connector.forward(
             byte_level_activations,
             cumulative_seq_lengths_per_word,
             word_position_ids,
             byte_position_ids,
         )
-        backbone_output: CausalLMOutputWithPast = self.backbone(
+        backbone_output: CausalLMOutputWithPast = self.backbone.forward(
             hidden_states=encoder_connector_output,
             position_ids=word_position_ids,
             past_key_values=backbone_past_key_values,
@@ -658,7 +660,7 @@ class HATForCausalLM(PreTrainedModel):
     def _append_byte(self, words: list[list[int]], token: int) -> list[list[int]]:
         extended_last_word = words.pop() + [token]
         try:
-            text = self.splitter.decode(extended_last_word, errors=
+            text = self.splitter.decode(extended_last_word, errors="strict", skip_special_tokens=False)
             list_of_bytes = self.splitter.encode(text)
             words.extend([list(word_in_bytes) for word_in_bytes in list_of_bytes])
         except UnicodeDecodeError:
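For intuition on the try/except above: a byte sequence that ends in the middle of a multi-byte UTF-8 character cannot be strictly decoded, which is what keeps an incomplete word from being re-split. A self-contained sketch with plain Python bytes, assuming the splitter's decode behaves like strict UTF-8 decoding here:

    word = [0xC3]  # first byte of the two-byte UTF-8 encoding of "é"
    try:
        bytes(word).decode("utf-8", errors="strict")  # raises: sequence incomplete
    except UnicodeDecodeError:
        pass  # keep accumulating bytes into the same word
    word.append(0xA9)
    assert bytes(word).decode("utf-8") == "é"  # decodable again, word can be re-split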
@@ -667,20 +669,70 @@ class HATForCausalLM(PreTrainedModel):
             words.append(extended_last_word)
         return words

+    def _split_encoder_activations(
+        self,
+        byte_encoder_activations: torch.Tensor,
+        words: list[list[int]],
+        previous_encoder_activations: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Split encoder activations between the first word and the next word.
+
+        Args:
+            byte_encoder_activations: Tensor of shape [batch_size, seq_len, hidden_size] containing all encoder activations computed in the current iteration
+            words: List of word byte sequences completed in the previous and current iterations
+            previous_encoder_activations: Optional tensor of shape [batch_size, prev_seq_len, hidden_size] containing precomputed activations from the previous iteration
+
+        Returns:
+            tuple containing:
+                - first_word_encoder_activations: Tensor of shape [batch_size, first_word_len, hidden_size]
+                - next_word_encoder_activations: Tensor of shape [batch_size, remaining_len, hidden_size]
+        """
+
+        assert sum(len(word) for word in words) - 1 == byte_encoder_activations.shape[1] + (previous_encoder_activations.shape[1] if previous_encoder_activations is not None else 0), "Total byte length of words minus 1 must match the combined sequence length of byte_encoder_activations and previous_encoder_activations"
+
+        next_word_encoder_activations = None
+        if previous_encoder_activations is not None:
+            # We already partially precomputed the first word's encoder activations in the previous iteration
+            new_bytes_of_first_words = len(words[0]) - previous_encoder_activations.shape[1]
+            # Concatenate the precomputed activations with the new activations that still belong to the first word
+            first_word_encoder_activations = torch.cat([previous_encoder_activations, byte_encoder_activations[:, :new_bytes_of_first_words]], dim=1)
+            if len(words[1]) > 1:
+                # The remaining activations belong to the next word
+                next_word_encoder_activations = byte_encoder_activations[:, new_bytes_of_first_words:]
+            else:
+                next_word_encoder_activations = None
+        else:
+            # We have not precomputed any activations for the first word previously
+            first_word_encoder_activations = byte_encoder_activations[:, : len(words[0])]
+
+            if len(words[1]) > 1:
+                next_word_encoder_activations = byte_encoder_activations[:, len(words[0]) :]
+            else:
+                next_word_encoder_activations = None
+
+        return first_word_encoder_activations, next_word_encoder_activations
+
     def _complete_word(
         self,
         input_ids: torch.Tensor,
         byte_position_ids: torch.Tensor,
-
+        predictive_word_embeddings: torch.Tensor,
         word_position_id: torch.Tensor,
         encoder_cache: DynamicCache,
         decoder_cache: DynamicCache,
         sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
+        previous_encoder_activations: torch.Tensor | None = None,
     ):
         """Generate byte tokens until we hit the first byte of a new word."""
-        words = [input_ids.squeeze(0).tolist()]
-        byte_encoder_activations = []
-        completion_logits = []
+        words: list[list[int]] = [input_ids.squeeze(0).tolist()]
+        byte_encoder_activations: list[torch.Tensor] = []
+        completion_logits: list[torch.Tensor] = []
+
+        if previous_encoder_activations is not None:
+            # We need to pass all inputs to get the correct encoding/decoding by the splitter,
+            # but only the last byte is used for the generation,
+            # since the cache is already populated with the first word's activations
+            input_ids = input_ids[:, -1:]

         while True:
             encoder_output = self.encoder.forward(
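A shape-level trace of the no-precomputation branch of _split_encoder_activations, as a standalone sketch with assumed toy sizes (not from the diff). The invariant behind the assert is that an activation exists for every byte seen so far except the last sampled one:

    import torch

    B, H = 1, 8
    words = [[10, 11], [12, 13]]      # 4 bytes total; the last byte has no activation yet
    acts = torch.randn(B, 3, H)       # sum(len(w) for w in words) - 1 == 3

    first = acts[:, : len(words[0])]  # [1, 2, H]: the completed first word
    spill = acts[:, len(words[0]) :]  # [1, 1, H]: bytes already belonging to the next word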
@@ -692,7 +744,7 @@ class HATForCausalLM(PreTrainedModel):
             )
             byte_encoder_activations.append(encoder_output.hidden_states)
             decoder_output = self.decoder.forward(
-
+                predictive_word_embeddings,
                 encoder_output.hidden_states,
                 byte_position_ids=None,
                 word_position_ids=word_position_id,
@@ -705,22 +757,112 @@ class HATForCausalLM(PreTrainedModel):
             next_byte = int(sample_fn(logits).item())
             words = self._append_byte(words, next_byte)
             if len(words) > 1 or next_byte == self.eos_token_id:
+                byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
+                first_word_encoder_activations, next_word_encoder_activations = self._split_encoder_activations(
+                    byte_encoder_activations,
+                    words,
+                    previous_encoder_activations,
+                )
                 break
             input_ids = torch.tensor([[next_byte]], dtype=input_ids.dtype, device=input_ids.device)

-        byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
         num_kv = encoder_cache.get_seq_length()
-
+
+        completion = sum(words, [])[-len(completion_logits) :]
+        if next_word_encoder_activations is not None:
+            start_idx = num_kv - first_word_encoder_activations.shape[1] - next_word_encoder_activations.shape[1]
+            end_idx = num_kv - next_word_encoder_activations.shape[1]
+            # We do not want to return the logits for the second word when we went into the multibyte starting-character case
+            # When that happens we remove the logits, post-hoc fix the decoder cache, and compute new logits
+            # This breaks causality, but we want to imitate uncached generation/training behavior
+            completion_logits = completion_logits[:-next_word_encoder_activations.shape[1]]
+        else:
+            start_idx = num_kv - first_word_encoder_activations.shape[1]
+            end_idx = num_kv
+
+        byte_position_ids = torch.arange(start_idx, end_idx, device=input_ids.device, dtype=torch.long).unsqueeze(0)
         completed_word_embedding = self.encoder_connector.forward(
-
-            cumulative_seq_lengths_per_word=torch.tensor([0,
+            first_word_encoder_activations,
+            cumulative_seq_lengths_per_word=torch.tensor([0, first_word_encoder_activations.size(1)], dtype=torch.int32, device=input_ids.device),
             word_position_ids=word_position_id,
             byte_position_ids=byte_position_ids,
         )

-
-
-        return
+        bytes_of_next_word = words[1]
+
+        return (
+            completion,
+            completed_word_embedding,
+            bytes_of_next_word,
+            byte_position_ids[:, -1].item() + 1,
+            completion_logits,
+            next_word_encoder_activations,
+        )
+
+    def _populate_cache(
+        self,
+        input_ids: torch.Tensor,
+        cumulative_seq_lengths_per_word: torch.Tensor,
+        byte_position_ids: torch.Tensor,
+        word_position_ids: torch.Tensor,
+    ):
+        last_word_start = cumulative_seq_lengths_per_word[-2]
+        last_word_end = cumulative_seq_lengths_per_word[-1]
+
+        # Populate cache with everything except the last word
+        initial_forward_output = self.forward(
+            input_ids=input_ids[:, :last_word_start],
+            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word[:-1],
+            byte_position_ids=byte_position_ids[:, :last_word_start],
+            word_position_ids=word_position_ids[:, :-1],
+            past_key_values=None,
+            use_cache=True,
+        )
+        return initial_forward_output, last_word_start, last_word_end
+
+    def _initialize_generation_state(
+        self,
+        input_ids: torch.Tensor,
+        max_new_tokens: int,
+        cumulative_seq_lengths_per_word: torch.Tensor,
+        byte_position_ids: torch.Tensor | None = None,
+        word_position_ids: torch.Tensor | None = None,
+    ):
+        max_total_bytes = max_new_tokens + input_ids.shape[1]
+        if byte_position_ids is None:
+            byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
+
+        if word_position_ids is None:
+            word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
+
+        initial_forward_output, last_word_start, last_word_end = self._populate_cache(
+            input_ids=input_ids,
+            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
+            byte_position_ids=byte_position_ids,
+            word_position_ids=word_position_ids,
+        )
+
+        completion_bytes: list[int] = []
+        completion_logits: list[torch.Tensor] = []
+        # Slice input_ids and byte_position_ids to only contain the last word for the generation loop
+        current_input_ids = input_ids[:, last_word_start:last_word_end]
+        next_byte_id = last_word_end.item()  # Ensure this is an int
+        current_byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
+        current_word_position_id = word_position_ids[:, -1].unsqueeze(-1)
+        backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]
+        next_word_encoder_activations = None
+        return (
+            initial_forward_output,
+            completion_bytes,
+            completion_logits,
+            current_input_ids,
+            next_byte_id,
+            current_byte_position_ids,
+            current_word_position_id,
+            backbone_last_hidden_state,
+            next_word_encoder_activations,
+            max_total_bytes,
+        )

     def generate(
         self,
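_populate_cache leans on cumulative_seq_lengths_per_word being a prefix sum over word lengths in bytes, so its last two entries bracket the (possibly still growing) last word. A small illustration with assumed toy values:

    import torch

    # Two words of 2 and 6 bytes: prefix sum over word boundaries.
    cumulative = torch.tensor([0, 2, 8], dtype=torch.int32)
    last_word_start, last_word_end = cumulative[-2], cumulative[-1]  # 2, 8
    # The caches are warmed on bytes [0:2) only; bytes [2:8) are re-fed through
    # _complete_word so the unfinished last word can keep growing.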
@@ -756,6 +898,20 @@ class HATForCausalLM(PreTrainedModel):
             completion_logits=completion_logits,
         )

+    def _fix_decoder_cache(self, predictive_word_embeddings: torch.Tensor, encoder_activations: torch.Tensor, decoder_cache: DynamicCache, word_position_id: torch.Tensor):
+        decoder_cache.crop(decoder_cache.get_seq_length() - encoder_activations.shape[1])
+        real_decoder_logits = self.decoder.forward(
+            predictive_word_embeddings,
+            encoder_activations,
+            byte_position_ids=None,
+            word_position_ids=word_position_id,
+            past_key_values=decoder_cache,
+        ).last_hidden_state
+
+        decoder_output = self.layer_norm(real_decoder_logits)
+        logits = self.lm_head(decoder_output)
+        return logits
+
     @torch.no_grad()
     def _generate_cached(
         self,
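_fix_decoder_cache rolls the decoder cache back before replaying the decoder on the corrected inputs. A minimal sketch of the crop semantics, assuming a transformers release where DynamicCache.crop is available:

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()
    k = v = torch.randn(1, 4, 12, 64)  # [batch, heads, seq_len, head_dim]
    cache.update(k, v, layer_idx=0)
    assert cache.get_seq_length() == 12
    cache.crop(12 - 3)                 # drop the 3 most recent positions
    assert cache.get_seq_length() == 9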
@@ -767,43 +923,35 @@ class HATForCausalLM(PreTrainedModel):
         sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
         stop_sequences: Sequence[str] | None = None,
     ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word
-            byte_position_ids=byte_position_ids
-            word_position_ids=word_position_ids
-            past_key_values=None,
-            use_cache=True,
+        (
+            initial_forward_output,
+            completion_bytes,  # empty list
+            completion_logits,  # empty list
+            input_ids,  # This is now the sliced input_ids for the last word
+            next_byte_id,
+            byte_position_ids,  # This is now the sliced byte_position_ids for the last word
+            word_position_id,
+            backbone_last_hidden_state,
+            next_word_encoder_activations,  # None for the first iteration
+            max_total_bytes,
+        ) = self._initialize_generation_state(
+            input_ids=input_ids,
+            max_new_tokens=max_new_tokens,
+            cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
+            byte_position_ids=byte_position_ids,
+            word_position_ids=word_position_ids,
         )

-        completion_bytes = []
-        completion_logits = []
-        input_ids = input_ids[:, last_word_start:last_word_end]
-        next_byte_id = last_word_end
-        byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
-        word_position_id = word_position_ids[:, -1].unsqueeze(-1)
-        backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]
         while next_byte_id < max_total_bytes:
-            completion, completed_word_embedding,
+            completion, completed_word_embedding, bytes_of_next_word, next_byte_id, next_completion_logits, next_word_encoder_activations = self._complete_word(
                 input_ids=input_ids,
                 byte_position_ids=byte_position_ids,
-
+                predictive_word_embeddings=backbone_last_hidden_state,
                 word_position_id=word_position_id,
                 encoder_cache=initial_forward_output.past_key_values.get_encoder_cache(),
                 decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
                 sample_fn=sample_fn,
+                previous_encoder_activations=next_word_encoder_activations,
             )
             completion_logits.extend(next_completion_logits)
             completion_bytes.extend(completion)
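The position-id defaults in _initialize_generation_state are plain aranges: byte positions run over all bytes, word positions over the number of words (one fewer than the prefix-sum entries). A toy check under the same assumed values as above:

    import torch

    cumulative = torch.tensor([0, 2, 8], dtype=torch.int32)
    byte_position_ids = torch.arange(0, cumulative[-1].item(), dtype=torch.int32).unsqueeze(0)
    word_position_ids = torch.arange(0, cumulative.shape[0] - 1, dtype=torch.int32).unsqueeze(0)
    assert byte_position_ids.tolist() == [[0, 1, 2, 3, 4, 5, 6, 7]]
    assert word_position_ids.tolist() == [[0, 1]]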
@@ -828,11 +976,19 @@ class HATForCausalLM(PreTrainedModel):
             )
             backbone_last_hidden_state = backbone_output.hidden_states[:, -1, :].unsqueeze(1)

-            input_ids = torch.tensor([first_byte_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
-            byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)
             word_position_id = word_position_id + 1
+            if len(bytes_of_next_word) > 1:
+                real_decoder_logits = self._fix_decoder_cache(
+                    predictive_word_embeddings=backbone_last_hidden_state,
+                    encoder_activations=next_word_encoder_activations,
+                    decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
+                    word_position_id=word_position_id,
+                )
+                completion_logits.extend(real_decoder_logits)
+
+            input_ids = torch.tensor([bytes_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
+            byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)

-            completion_bytes.extend(first_byte_of_next_word)
         completion_bytes = completion_bytes[:max_new_tokens]
         completion_logits = torch.cat(completion_logits[:max_new_tokens], dim=0)
         completion_text = self.splitter.decode(completion_bytes)
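The len(bytes_of_next_word) > 1 check distinguishes the ordinary case, where the new word so far is a single boundary byte, from the spill-over case, where _complete_word already consumed several bytes of the next word and the decoder cache must be cropped and replayed. A hedged illustration of the two cases with plain byte lists:

    bytes_of_next_word = [32]          # just a space: nothing to fix up
    bytes_of_next_word = [0xC3, 0xA9]  # "é": two bytes already belong to the
                                       # next word, so the cache is repaired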
@@ -847,7 +1003,7 @@ class HATForCausalLM(PreTrainedModel):
         cumulative_seq_lengths_per_word: torch.Tensor,
         byte_position_ids: torch.Tensor | None = None,
         word_position_ids: torch.Tensor | None = None,
-        sample_fn=sample_argmax,
+        sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
         stop_sequences: Sequence[str] | None = None,
     ):
         if byte_position_ids is None: