from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths (under the handler's root) to the two SPLADE checkpoints.
query_emb_model_path = "/splade_query"
doc_emb_model_path = "/splade_doc"


class EndpointHandler:
    def __init__(self, path: str = ""):
        # SPLADE uses separate query and document encoders; load both once at
        # startup so each request only pays for tokenization and a forward pass.
        self.query_model = AutoModelForMaskedLM.from_pretrained(path + query_emb_model_path).to(device)
        self.query_tokenizer = AutoTokenizer.from_pretrained(path + query_emb_model_path)
        self.doc_model = AutoModelForMaskedLM.from_pretrained(path + doc_emb_model_path).to(device)
        self.doc_tokenizer = AutoTokenizer.from_pretrained(path + doc_emb_model_path)

    def __call__(self, data: Dict[str, Any]) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Args:
            data: request payload with keys
                inputs (:obj:`List[str]`): texts to embed
                task (:obj:`str`): "query_emb" or "doc_emb"
        Returns:
            An (indices, values) pair of per-text sparse vectors: for each
            input, the non-zero vocabulary indices and their SPLADE weights.
            The pair is serialized and returned to the caller.
        """
        texts = data.pop("inputs", data)
        task = data.pop("task", None)

        # Pick the encoder matching the requested task.
        if task == "query_emb":
            emb_model = self.query_model
            tokenizer = self.query_tokenizer
        elif task == "doc_emb":
            emb_model = self.doc_model
            tokenizer = self.doc_tokenizer
        else:
            raise ValueError("task must be either 'query_emb' or 'doc_emb'")

        tokens = tokenizer(
            texts, truncation=True, padding=True, return_tensors="pt"
        ).to(device)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            output = emb_model(**tokens)
        logits, attention_mask = output.logits, tokens.attention_mask

        # SPLADE pooling: log-saturate the MLM logits, zero out padding
        # positions, then max-pool over the sequence dimension to get one
        # vocabulary-sized sparse vector per input text.
        relu_log = torch.log1p(torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        tvecs, _ = torch.max(weighted_log, dim=1)

        # Keep only the non-zero entries (index, weight) of each sparse vector.
        indices = []
        vecs = []
        for vec in tvecs:
            nz = vec.nonzero(as_tuple=True)[0]
            indices.append(nz.tolist())
            vecs.append(vec[nz].tolist())

        # Release whatever cached GPU memory can be freed between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return indices, vecs
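
# --- Usage sketch (illustrative, not part of the handler) ---
# A minimal local smoke test, assuming the two SPLADE checkpoints live at
# "./splade_query" and "./splade_doc". The example texts and the dot-product
# scoring below are assumptions for illustration, not part of the deployed
# endpoint. It embeds one query and one document, then scores them by summing
# the products of weights on shared vocabulary indices, which is the usual way
# to compare these sparse vectors.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")

    q_idx, q_val = handler({"inputs": ["what is a sparse embedding?"], "task": "query_emb"})
    d_idx, d_val = handler({"inputs": ["SPLADE produces sparse lexical embeddings."], "task": "doc_emb"})

    # Sparse dot product over the vocabulary indices both vectors share.
    doc_weights = dict(zip(d_idx[0], d_val[0]))
    score = sum(v * doc_weights.get(i, 0.0) for i, v in zip(q_idx[0], q_val[0]))
    print(f"non-zero dims: query={len(q_idx[0])}, doc={len(d_idx[0])}, score={score:.4f}")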