from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM


# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# set model paths
query_emb_model_path = "/splade_query"
doc_emb_model_path = "/splade_doc"
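# (joined to `path` by plain string concatenation: with the default path=""
#  these resolve to absolute directories, e.g. "/splade_query")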


class EndpointHandler:
    def __init__(self, path=""):
        # load the SPLADE query and document encoders and their tokenizers
        self.query_model = AutoModelForMaskedLM.from_pretrained(path + query_emb_model_path).to(device)
        self.query_tokenizer = AutoTokenizer.from_pretrained(path + query_emb_model_path)
        self.doc_model = AutoModelForMaskedLM.from_pretrained(path + doc_emb_model_path).to(device)
        self.doc_tokenizer = AutoTokenizer.from_pretrained(path + doc_emb_model_path)


    def __call__(self, data: Dict[str, Any]) -> Tuple[List[List[int]], List[List[float]]]:
        """
       data args:
            inputs (:obj: `List[str]`)
            task (:obj: `str`)
      Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # get inputs; fall back to the raw payload if no "inputs" key is given
        texts = data.pop("inputs", data)
        task = data.pop("task", None)
        emb_model = None
        tokenizer = None
        
        if task == "query_emb":
            emb_model = self.query_model
            tokenizer = self.query_tokenizer
        elif task == "doc_emb":
            emb_model = self.doc_model
            tokenizer = self.doc_tokenizer
        else:
            raise ValueError("task must be either 'query_emb' or 'doc_emb'")
        
        tokens = tokenizer(
            texts, truncation=True, padding=True, return_tensors="pt"
        ).to(device)

        # run inference without tracking gradients to save memory
        with torch.no_grad():
            output = emb_model(**tokens)

        # SPLADE pooling: saturate the MLM logits with log(1 + ReLU(x)),
        # zero out padding positions with the attention mask, then max-pool
        # over the sequence dimension to get one sparse vector per input
        logits, attention_mask = output.logits, tokens.attention_mask
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        tvecs, _ = torch.max(weighted_log, dim=1)

        # extract the vectors that are non-zero and their indices
        indices = []
        vecs = []
        for batch in tvecs:
            indices.append(batch.nonzero(as_tuple=True)[0].tolist())
            vecs.append(batch[indices[-1]].tolist())

        # release cached GPU memory that is no longer needed
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return indices, vecs
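

# ---------------------------------------------------------------------------
# Usage sketch (a minimal local test, not part of the endpoint contract):
# assumes the two SPLADE checkpoints sit under ./splade_query and
# ./splade_doc relative to the working directory, matching the paths above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    indices, vecs = handler(
        {"inputs": ["what is splade?"], "task": "query_emb"}
    )
    # one (vocabulary-index, weight) list pair per input text
    print(indices[0][:5], vecs[0][:5])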