radu.mutilica commited on
Commit
6bd2130
·
1 Parent(s): bdf9478

add custom handler

Browse files
Files changed (3) hide show
  1. handler.py +63 -0
  2. requirements.txt +3 -0
  3. test_endpoint_handler.py +19 -0
handler.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Dict
3
+
4
+ from pydantic import BaseModel
5
+ from sentence_transformers import CrossEncoder
6
+
7
+ import config
8
+
9
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
10
+ console_handler = logging.StreamHandler()
11
+ console_handler.setLevel(logging.DEBUG)
12
+ console_handler.setFormatter(formatter)
13
+ logger = logging.getLogger()
14
+ logger.addHandler(console_handler)
15
+ logger.setLevel(logging.DEBUG)
16
+
17
+
18
+ class DocsToRerank(BaseModel):
19
+ query: str
20
+ documents: List[str]
21
+
22
+
23
+ class Payload(BaseModel):
24
+ inputs: DocsToRerank
25
+ parameters: Dict[str, str] = None
26
+
27
+
28
+ class Rank(BaseModel):
29
+ corpus_id: str
30
+ score: float
31
+
32
+
33
+ class EndpointHandler:
34
+ def __init__(self, path=""):
35
+ crossencoder_model = config.CROSSENCODER_MODEL
36
+ self.top_k = config.CROSSENCODER_TOP_K
37
+ self.max_length = config.CROSSENCODER_MAX_LENGTH
38
+ self.max_workers = config.CROSSENCODER_MAX_WORKERS
39
+
40
+ self.model = CrossEncoder(
41
+ crossencoder_model,
42
+ max_length=self.max_length
43
+ )
44
+ logger.info(
45
+ f'Loaded {path}:top_k={self.top_k}:length={self.max_length}:workers={self.max_workers}')
46
+
47
+ def __call__(self, data: Payload) -> List[Rank]:
48
+ data = data.inputs
49
+
50
+ logger.info(f'Received new docs to rerank: {data}')
51
+ ranks = self.model.rank(
52
+ data.query,
53
+ data.documents,
54
+ top_k=self.top_k,
55
+ num_workers=self.max_workers
56
+ )
57
+
58
+ ranks = [Rank(
59
+ corpus_id=r['corpus_id'],
60
+ score=r['score']) for r in ranks
61
+ ]
62
+ logger.info(f'Computed ranks for each document: {ranks}')
63
+ return ranks
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ sentence-transformers~=3.0.0
2
+ pydantic
3
+ pytest
test_endpoint_handler.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from handler import EndpointHandler, Payload
4
+
5
+
6
+ @pytest.mark.parametrize('payload', [
7
+ {
8
+ 'inputs': {
9
+ 'query': 'I like apples',
10
+ 'documents': [
11
+ 'But I only eat the red ones',
12
+ 'Asia is a continent',
13
+ 'You can add them to smoothies'
14
+ ]
15
+ }
16
+ }
17
+ ])
18
+ def test_sanity(payload):
19
+ assert EndpointHandler(path=".")(Payload.model_validate(payload))