Spaces:
Runtime error
Runtime error
| import abc | |
| from typing import Any | |
| import numpy as np | |
| import numpy.typing as npt | |
class LlamaDraftModel(abc.ABC):
    """Abstract interface for a draft model used to propose candidate tokens.

    Subclasses must override ``__call__``; the base class cannot be
    instantiated on its own.
    """

    @abc.abstractmethod
    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        """Return draft token ids proposed for ``input_ids``.

        Args:
            input_ids: Token ids seen so far (positional-only).
            **kwargs: Extra implementation-specific options.

        Returns:
            An array of proposed token ids.

        Raises:
            NotImplementedError: If a subclass delegates to this base
                implementation via ``super()``.
        """
        raise NotImplementedError()
class LlamaPromptLookupDecoding(LlamaDraftModel):
    """Based on https://github.com/apoorvumang/prompt-lookup-decoding

    Proposes draft tokens by looking for an earlier occurrence of the
    trailing n-gram of the input and returning the tokens that followed it.
    """

    def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
        """Store the lookup configuration.

        Args:
            max_ngram_size: Largest trailing n-gram length to try matching.
            num_pred_tokens: Maximum number of tokens to return per call.
        """
        self.max_ngram_size = max_ngram_size
        self.num_pred_tokens = num_pred_tokens

    # Must be a staticmethod: the function takes no ``self`` and ``__call__``
    # passes ``input_ids`` by keyword; without the decorator the bound
    # instance would collide with ``input_ids`` and raise ``TypeError``.
    @staticmethod
    def find_candidate_pred_tokens(
        input_ids: npt.NDArray[np.intc],
        max_ngram_size: int,
        num_pred_tokens: int,
    ) -> npt.NDArray[np.intc]:
        """Find the continuation of an earlier match of the trailing n-gram.

        N-gram sizes are tried from largest to smallest; for each size the
        matches are scanned from the start of ``input_ids`` and the first
        one that has at least one following token wins.

        Args:
            input_ids: One-dimensional array of token ids.
            max_ngram_size: Largest n-gram length to try (capped at
                ``len(input_ids) - 1``).
            num_pred_tokens: Maximum number of continuation tokens to return.

        Returns:
            A slice of ``input_ids`` holding up to ``num_pred_tokens`` tokens
            that followed the match, or an empty ``np.intc`` array when no
            match with a continuation exists.
        """
        input_length = input_ids.shape[0]

        for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
            # Every contiguous window of length ngram_size, one per row.
            windows = np.lib.stride_tricks.sliding_window_view(
                input_ids, (ngram_size,)
            )

            # The n-gram to look for is the tail of the input.
            ngram_array = input_ids[-ngram_size:]

            # Row-wise equality against the tail, then the matching offsets.
            matches = np.all(windows == ngram_array, axis=1)
            match_indices = np.nonzero(matches)[0]

            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = min(start_idx + num_pred_tokens, input_length)

                # The final window is the tail itself and has nothing after
                # it; this check skips it (and any other empty continuation).
                if start_idx < end_idx:
                    return input_ids[start_idx:end_idx]

        # If no match is found, return an empty array
        return np.array([], dtype=np.intc)

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        """Return draft tokens for ``input_ids`` using the stored settings.

        ``kwargs`` are accepted for interface compatibility and ignored.
        """
        return self.find_candidate_pred_tokens(
            input_ids=input_ids,
            max_ngram_size=self.max_ngram_size,
            num_pred_tokens=self.num_pred_tokens,
        )