emanuelaboros commited on
Commit
8dd79b5
·
verified ·
1 Parent(s): 2b89904

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -12,7 +12,7 @@ nltk.download("averaged_perceptron_tagger_eng")
12
  NEL_MODEL = "nel-mgenre-multilingual"
13
 
14
  class NelPipeline:
15
- def __init__(self, model_name: str):
16
  self.model_name = model_name
17
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -149,7 +149,7 @@ def get_wikipedia_title(qid, language="en"):
149
  class EndpointHandler:
150
  def __init__(self, path: str = None):
151
  # Initialize the NelPipeline with the specified model
152
- self.pipeline = NelPipeline(NEL_MODEL)
153
 
154
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
155
  # Process incoming data
 
12
  NEL_MODEL = "nel-mgenre-multilingual"
13
 
14
  class NelPipeline:
15
+ def __init__(self, model_dir: str = "."):
16
  self.model_name = model_name
17
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
149
  class EndpointHandler:
150
  def __init__(self, path: str = None):
151
  # Initialize the NelPipeline with the specified model
152
+ self.pipeline = NelPipeline()
153
 
154
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
155
  # Process incoming data