radu.mutilica commited on
Commit
3e7aea0
·
1 Parent(s): 13e7c0f

removed mp

Browse files
Files changed (2) hide show
  1. handler.py +2 -18
  2. test_endpoint_handler.py +1 -1
handler.py CHANGED
@@ -2,13 +2,12 @@ import logging
2
  from typing import List, Dict
3
 
4
  import torch.multiprocessing as mp
5
- from pydantic import BaseModel
6
  from sentence_transformers import CrossEncoder
7
 
8
  import config
9
 
10
  # Used for CUDA multiprocessing
11
- mp.set_start_method('spawn', force=True)
12
 
13
  formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
14
  console_handler = logging.StreamHandler()
@@ -19,21 +18,6 @@ logger.addHandler(console_handler)
19
  logger.setLevel(logging.DEBUG)
20
 
21
 
22
- class DocsToRerank(BaseModel):
23
- query: str
24
- documents: List[str]
25
-
26
-
27
- class Payload(BaseModel):
28
- inputs: DocsToRerank
29
- parameters: Dict[str, str] = None
30
-
31
-
32
- class Rank(BaseModel):
33
- corpus_id: int
34
- score: float
35
-
36
-
37
  class EndpointHandler:
38
  def __init__(self, path):
39
  self.top_k = config.CROSSENCODER_TOP_K
@@ -55,7 +39,7 @@ class EndpointHandler:
55
  inputs['query'],
56
  inputs['documents'],
57
  top_k=self.top_k,
58
- num_workers=self.max_workers
59
  )
60
  logger.info(f'New ranks: {ranks}')
61
 
 
2
  from typing import List, Dict
3
 
4
  import torch.multiprocessing as mp
 
5
  from sentence_transformers import CrossEncoder
6
 
7
  import config
8
 
9
  # Used for CUDA multiprocessing
10
+ # mp.set_start_method('spawn', force=True)
11
 
12
  formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
13
  console_handler = logging.StreamHandler()
 
18
  logger.setLevel(logging.DEBUG)
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class EndpointHandler:
22
  def __init__(self, path):
23
  self.top_k = config.CROSSENCODER_TOP_K
 
39
  inputs['query'],
40
  inputs['documents'],
41
  top_k=self.top_k,
42
+ # num_workers=self.max_workers
43
  )
44
  logger.info(f'New ranks: {ranks}')
45
 
test_endpoint_handler.py CHANGED
@@ -1,6 +1,6 @@
1
  import pytest
2
 
3
- from handler import EndpointHandler, Payload
4
 
5
 
6
  @pytest.mark.parametrize('payload', [
 
1
  import pytest
2
 
3
+ from handler import EndpointHandler
4
 
5
 
6
  @pytest.mark.parametrize('payload', [