chadboyda commited on
Commit
5280e8d
·
1 Parent(s): 643b8b8

fix cpu error

Browse files
Files changed (1) hide show
  1. handler.py +19 -3
handler.py CHANGED
@@ -1,6 +1,7 @@
1
  from pylate import models
2
  from transformers import AutoTokenizer
3
  import torch
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
@@ -8,6 +9,21 @@ class EndpointHandler:
8
  self.model = models.ColBERT(model_name_or_path=path)
9
  self.model.eval()
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def __call__(self, data):
12
  texts = data.get("inputs") or data.get("text") or data
13
  if isinstance(texts, str):
@@ -16,8 +32,8 @@ class EndpointHandler:
16
  with torch.no_grad():
17
  emb = self.model.encode(
18
  texts,
19
- is_query=True, # query-style encoding
20
  batch_size=32,
21
  )
22
- # TEI expects JSON-serialisable output
23
- return emb.cpu().tolist()
 
1
  from pylate import models
2
  from transformers import AutoTokenizer
3
  import torch
4
+ import numpy as np
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
 
9
  self.model = models.ColBERT(model_name_or_path=path)
10
  self.model.eval()
11
 
12
+ def _to_list(self, emb):
13
+ """
14
+ Make the output JSON-serialisable:
15
+ – torch.Tensor ➜ emb.cpu().tolist()
16
+ – np.ndarray ➜ emb.tolist()
17
+ – list[...] ➜ recurse
18
+ """
19
+ if isinstance(emb, torch.Tensor):
20
+ return emb.cpu().tolist()
21
+ if isinstance(emb, np.ndarray):
22
+ return emb.tolist()
23
+ if isinstance(emb, list):
24
+ return [self._to_list(e) for e in emb]
25
+ return emb # already plain Python
26
+
27
  def __call__(self, data):
28
  texts = data.get("inputs") or data.get("text") or data
29
  if isinstance(texts, str):
 
32
  with torch.no_grad():
33
  emb = self.model.encode(
34
  texts,
35
+ is_query=True,
36
  batch_size=32,
37
  )
38
+
39
+ return self._to_list(emb)