Update handler.py
Browse files- handler.py +2 -2
handler.py
CHANGED
@@ -12,7 +12,7 @@ nltk.download("averaged_perceptron_tagger_eng")
|
|
12 |
NEL_MODEL = "nel-mgenre-multilingual"
|
13 |
|
14 |
class NelPipeline:
|
15 |
-
def __init__(self,
|
16 |
self.model_name = model_name
|
17 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
@@ -149,7 +149,7 @@ def get_wikipedia_title(qid, language="en"):
|
|
149 |
class EndpointHandler:
|
150 |
def __init__(self, path: str = None):
|
151 |
# Initialize the NelPipeline with the specified model
|
152 |
-
self.pipeline = NelPipeline(
|
153 |
|
154 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
155 |
# Process incoming data
|
|
|
12 |
NEL_MODEL = "nel-mgenre-multilingual"
|
13 |
|
14 |
class NelPipeline:
|
15 |
+
def __init__(self, model_dir: str = "."):
|
16 |
self.model_name = model_name
|
17 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
149 |
class EndpointHandler:
|
150 |
def __init__(self, path: str = None):
|
151 |
# Initialize the NelPipeline with the specified model
|
152 |
+
self.pipeline = NelPipeline()
|
153 |
|
154 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
155 |
# Process incoming data
|