emanuelaboros commited on
Commit
0162bcf
·
verified ·
1 Parent(s): c2349bd

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +68 -63
handler.py CHANGED
@@ -11,69 +11,6 @@ nltk.download("averaged_perceptron_tagger_eng")
11
  # Define your model name
12
  NEL_MODEL = "nel-mgenre-multilingual"
13
 
14
class NelPipeline:
    """Named-entity-linking pipeline around an mGENRE-style seq2seq model.

    Loads a tokenizer and a seq2seq LM from ``model_dir``, extracts the
    entity enclosed between ``[START]``/``[END]`` markers in the input
    text, generates a Wikipedia page-name prediction, and resolves it to
    a Wikidata QID plus Wikipedia title/URL.
    """

    def __init__(self, model_dir: str = "."):
        # BUG FIX: the original referenced an undefined name `model_name`
        # (NameError on instantiation). Load from the `model_dir` argument
        # and keep the canonical model name from the module constant.
        self.model_name = NEL_MODEL
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(self.device)

    def preprocess(self, text: str):
        """Generate a Wikipedia prediction for *text*.

        Returns:
            tuple: ``(wikipedia_prediction, enclosed_entity, lOffset,
            rOffset, percentage)`` where ``percentage`` is the generation
            confidence scaled to [0, 100]. The entity/offset fields are
            ``None`` when the ``[START]``/``[END]`` markers are absent.
        """
        start_token = "[START]"
        end_token = "[END]"

        if start_token in text and end_token in text:
            start_idx = text.index(start_token) + len(start_token)
            end_idx = text.index(end_token)
            enclosed_entity = text[start_idx:end_idx].strip()
            lOffset = start_idx
            rOffset = end_idx
        else:
            # No markers: nothing to anchor the entity span to.
            enclosed_entity = None
            lOffset = None
            rOffset = None

        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt").to(self.device),
            num_beams=1,
            num_return_sequences=1,
            max_new_tokens=30,
            return_dict_in_generate=True,
            output_scores=True,  # required by compute_transition_scores below
        )
        wikipedia_prediction = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )[0]

        # Per-token log-probabilities -> sequence probability -> percentage.
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        log_prob_sum = sum(transition_scores[0])
        sequence_confidence = torch.exp(log_prob_sum)
        percentage = sequence_confidence.cpu().numpy() * 100.0

        return wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage

    def postprocess(self, outputs):
        """Resolve the predicted page name to QID/title/URL and build results."""
        wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage = outputs

        qid, language = get_wikipedia_page_props(wikipedia_prediction)
        title, url = get_wikipedia_title(qid, language=language)

        results = [
            {
                "surface": enclosed_entity,
                "wkd_id": qid,
                "wkpedia_pagename": title,
                "wkpedia_url": url,
                "type": "UNK",
                # float() converts the numpy scalar so the value is
                # JSON-serializable by downstream endpoint code.
                "confidence_nel": round(float(percentage), 2),
                "lOffset": lOffset,
                "rOffset": rOffset,
            }
        ]
        return results
77
 
78
  def get_wikipedia_page_props(input_str: str):
79
  if ">>" not in input_str:
@@ -146,6 +83,74 @@ def get_wikipedia_title(qid, language="en"):
146
  return "NIL", "None"
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  class EndpointHandler:
150
  def __init__(self, path: str = None):
151
  # Initialize the NelPipeline with the specified model
 
11
  # Define your model name
12
  NEL_MODEL = "nel-mgenre-multilingual"
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def get_wikipedia_page_props(input_str: str):
16
  if ">>" not in input_str:
 
83
  return "NIL", "None"
84
 
85
 
86
class NelPipeline:
    """Named-entity-linking pipeline around an mGENRE-style seq2seq model.

    Loads a tokenizer and a seq2seq LM from ``model_dir``, extracts the
    entity enclosed between ``[START]``/``[END]`` markers in the input
    text, generates a Wikipedia page-name prediction, and resolves it to
    a Wikidata QID plus Wikipedia title/URL.
    """

    def __init__(self, model_dir: str = "."):
        self.model_name = NEL_MODEL
        print(f"Loading {model_dir}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # BUG FIX: the original built a transformers `pipeline(...)` using
        # two undefined names (`NEL_MODEL_NAME`, `nel_tokenizer`), and a
        # Pipeline object does not expose `.generate()` or
        # `.compute_transition_scores()`, both of which `preprocess` calls.
        # Load the raw seq2seq model instead, which provides both.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(self.device)

    def preprocess(self, text: str):
        """Generate a Wikipedia prediction for *text*.

        Returns:
            tuple: ``(wikipedia_prediction, enclosed_entity, lOffset,
            rOffset, percentage)`` where ``percentage`` is the generation
            confidence scaled to [0, 100]. The entity/offset fields are
            ``None`` when the ``[START]``/``[END]`` markers are absent.
        """
        start_token = "[START]"
        end_token = "[END]"

        if start_token in text and end_token in text:
            start_idx = text.index(start_token) + len(start_token)
            end_idx = text.index(end_token)
            enclosed_entity = text[start_idx:end_idx].strip()
            lOffset = start_idx
            rOffset = end_idx
        else:
            # No markers: nothing to anchor the entity span to.
            enclosed_entity = None
            lOffset = None
            rOffset = None

        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt").to(self.device),
            num_beams=1,
            num_return_sequences=1,
            max_new_tokens=30,
            return_dict_in_generate=True,
            output_scores=True,  # required by compute_transition_scores below
        )
        wikipedia_prediction = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )[0]

        # Per-token log-probabilities -> sequence probability -> percentage.
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        log_prob_sum = sum(transition_scores[0])
        sequence_confidence = torch.exp(log_prob_sum)
        percentage = sequence_confidence.cpu().numpy() * 100.0

        return wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage

    def postprocess(self, outputs):
        """Resolve the predicted page name to QID/title/URL and build results."""
        wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage = outputs

        qid, language = get_wikipedia_page_props(wikipedia_prediction)
        title, url = get_wikipedia_title(qid, language=language)

        results = [
            {
                "surface": enclosed_entity,
                "wkd_id": qid,
                "wkpedia_pagename": title,
                "wkpedia_url": url,
                "type": "UNK",
                # float() converts the numpy scalar so the value is
                # JSON-serializable by downstream endpoint code.
                "confidence_nel": round(float(percentage), 2),
                "lOffset": lOffset,
                "rOffset": rOffset,
            }
        ]
        return results
154
  class EndpointHandler:
155
  def __init__(self, path: str = None):
156
  # Initialize the NelPipeline with the specified model