radu.mutilica
commited on
Commit
·
3e7aea0
1
Parent(s):
13e7c0f
removed mp
Browse files- handler.py +2 -18
- test_endpoint_handler.py +1 -1
handler.py
CHANGED
@@ -2,13 +2,12 @@ import logging
|
|
2 |
from typing import List, Dict
|
3 |
|
4 |
import torch.multiprocessing as mp
|
5 |
-
from pydantic import BaseModel
|
6 |
from sentence_transformers import CrossEncoder
|
7 |
|
8 |
import config
|
9 |
|
10 |
# Used for CUDA multiprocessing
|
11 |
-
mp.set_start_method('spawn', force=True)
|
12 |
|
13 |
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
14 |
console_handler = logging.StreamHandler()
|
@@ -19,21 +18,6 @@ logger.addHandler(console_handler)
|
|
19 |
logger.setLevel(logging.DEBUG)
|
20 |
|
21 |
|
22 |
-
class DocsToRerank(BaseModel):
|
23 |
-
query: str
|
24 |
-
documents: List[str]
|
25 |
-
|
26 |
-
|
27 |
-
class Payload(BaseModel):
|
28 |
-
inputs: DocsToRerank
|
29 |
-
parameters: Dict[str, str] = None
|
30 |
-
|
31 |
-
|
32 |
-
class Rank(BaseModel):
|
33 |
-
corpus_id: int
|
34 |
-
score: float
|
35 |
-
|
36 |
-
|
37 |
class EndpointHandler:
|
38 |
def __init__(self, path):
|
39 |
self.top_k = config.CROSSENCODER_TOP_K
|
@@ -55,7 +39,7 @@ class EndpointHandler:
|
|
55 |
inputs['query'],
|
56 |
inputs['documents'],
|
57 |
top_k=self.top_k,
|
58 |
-
num_workers=self.max_workers
|
59 |
)
|
60 |
logger.info(f'New ranks: {ranks}')
|
61 |
|
|
|
2 |
from typing import List, Dict
|
3 |
|
4 |
import torch.multiprocessing as mp
|
|
|
5 |
from sentence_transformers import CrossEncoder
|
6 |
|
7 |
import config
|
8 |
|
9 |
# Used for CUDA multiprocessing
|
10 |
+
# mp.set_start_method('spawn', force=True)
|
11 |
|
12 |
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
13 |
console_handler = logging.StreamHandler()
|
|
|
18 |
logger.setLevel(logging.DEBUG)
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
class EndpointHandler:
|
22 |
def __init__(self, path):
|
23 |
self.top_k = config.CROSSENCODER_TOP_K
|
|
|
39 |
inputs['query'],
|
40 |
inputs['documents'],
|
41 |
top_k=self.top_k,
|
42 |
+
# num_workers=self.max_workers
|
43 |
)
|
44 |
logger.info(f'New ranks: {ranks}')
|
45 |
|
test_endpoint_handler.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import pytest
|
2 |
|
3 |
-
from handler import EndpointHandler
|
4 |
|
5 |
|
6 |
@pytest.mark.parametrize('payload', [
|
|
|
1 |
import pytest
|
2 |
|
3 |
+
from handler import EndpointHandler
|
4 |
|
5 |
|
6 |
@pytest.mark.parametrize('payload', [
|