|
from typing import Dict, List, Any |
|
import time |
|
|
|
import torch |
|
from transformers import BertModel, BertTokenizerFast |
|
|
|
|
|
class EndpointHandler():
    """Serve BERT pooled embeddings, memoised in a size-bounded in-memory cache.

    Each request's strings are looked up in ``self.cache`` first. Only the
    misses are tokenized and run through the model. Once the cache holds more
    than ``max_cache_entries`` strings, the least recently accessed entries
    are evicted. Eviction runs at most once per
    ``CACHE_CLEANUP_INTERVAL_SECONDS``.
    """

    # Minimum number of seconds between two eviction passes.
    CACHE_CLEANUP_INTERVAL_SECONDS = 60

    def __init__(self, path_to_model: str = ".", max_cache_entries: int = 10000):
        """Load the tokenizer and model, and start with an empty cache.

        :param path_to_model: location passed to ``from_pretrained`` for both
            the tokenizer and the model.
        :param max_cache_entries: soft upper bound on the number of cached
            strings. The cache may exceed it between cleanup passes.
        """
        self.tokenizer = BertTokenizerFast.from_pretrained(path_to_model)
        self.model = BertModel.from_pretrained(path_to_model)
        # Inference mode: disables dropout so embeddings are deterministic.
        self.model = self.model.eval()
        # input string -> {"pooler_output": <embedding list>, "last_access": <epoch seconds>}
        self.cache = {}
        self.last_cache_cleanup = time.time()
        self.max_cache_entries = max_cache_entries

    def _lookup_cache(self, inputs):
        """Split ``inputs`` into cache hits and misses.

        Hits have their ``last_access`` timestamp refreshed. That keeps
        eviction in ``_cleanup_cache`` least-recently-used rather than
        oldest-inserted.

        :param inputs: sequence of input strings.
        :return: ``(uncached_inputs, cached_results)``.
            ``uncached_inputs`` is a list of ``(index, string)`` pairs for
            the misses. ``cached_results`` maps index -> cached embedding.
        """
        cached_results = {}
        uncached_inputs = []
        now = time.time()
        for index, input_string in enumerate(inputs):
            entry = self.cache.get(input_string)
            if entry is None:
                uncached_inputs.append((index, input_string))
            else:
                entry["last_access"] = now
                cached_results[index] = entry["pooler_output"]
        return uncached_inputs, cached_results

    def _store_to_cache(self, input_string, result):
        """Cache ``result`` (the embedding) under ``input_string``, stamped with the current time."""
        self.cache[input_string] = {
            "pooler_output": result,
            "last_access": time.time(),
        }

    def _cleanup_cache(self):
        """Evict least recently accessed entries down to ``max_cache_entries``.

        Does nothing when the last pass ran within
        ``CACHE_CLEANUP_INTERVAL_SECONDS``. Also does nothing when the cache
        is not over its limit.
        """
        current_time = time.time()
        if current_time - self.last_cache_cleanup <= self.CACHE_CLEANUP_INTERVAL_SECONDS:
            return
        excess = len(self.cache) - self.max_cache_entries
        if excess <= 0:
            return
        oldest_keys = sorted(self.cache, key=lambda key: self.cache[key]["last_access"])[:excess]
        for key in oldest_keys:
            del self.cache[key]
        # Record the pass; without this the interval throttle only ever applied once.
        self.last_cache_cleanup = current_time

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """
        This method is called whenever a request is made to the endpoint.

        :param data: { inputs [str]: list of strings to be encoded }
        :return: one pooled embedding (a list of floats) per input string,
            in input order. It will be serialized and returned.
        """
        inputs = data['inputs']
        uncached_inputs, cached_results = self._lookup_cache(inputs)
        output_results = dict(cached_results)

        if uncached_inputs:
            # Encode each distinct miss once, even if it repeats within the request.
            unique_strings = list(dict.fromkeys(input_string for _, input_string in uncached_inputs))
            # Truncate to the model's maximum length; longer inputs would otherwise
            # make the forward pass fail.
            tokenized = self.tokenizer(unique_strings, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():
                model_output = self.model(**tokenized)
            embeddings = dict(zip(unique_strings, model_output.pooler_output.tolist()))

            for input_string, embedding in embeddings.items():
                self._store_to_cache(input_string, embedding)
            for index, input_string in uncached_inputs:
                output_results[index] = embeddings[input_string]

            self._cleanup_cache()

        return [output_results[i] for i in range(len(inputs))]
|
|