chadboyda commited on
Commit
643b8b8
·
1 Parent(s): b4840c5

Add handler

Browse files
Files changed (2) hide show
  1. handler.py +23 -0
  2. requirements.txt +2 -0
handler.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pylate import models
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
8
+ self.model = models.ColBERT(model_name_or_path=path)
9
+ self.model.eval()
10
+
11
+ def __call__(self, data):
12
+ texts = data.get("inputs") or data.get("text") or data
13
+ if isinstance(texts, str):
14
+ texts = [texts]
15
+
16
+ with torch.no_grad():
17
+ emb = self.model.encode(
18
+ texts,
19
+ is_query=True, # query-style encoding
20
+ batch_size=32,
21
+ )
22
+ # TEI expects JSON-serialisable output
23
+ return emb.cpu().tolist()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pylate>=0.4.0
2
+ torch>=2.2