emanuelaboros commited on
Commit
090afab
·
verified ·
1 Parent(s): 76805b5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -52
handler.py CHANGED
@@ -96,60 +96,15 @@ class NelPipeline:
96
  device=self.device)
97
 
98
  def preprocess(self, text: str):
99
- start_token = "[START]"
100
- end_token = "[END]"
101
-
102
- if start_token in text and end_token in text:
103
- start_idx = text.index(start_token) + len(start_token)
104
- end_idx = text.index(end_token)
105
- enclosed_entity = text[start_idx:end_idx].strip()
106
- lOffset = start_idx
107
- rOffset = end_idx
108
- else:
109
- enclosed_entity = None
110
- lOffset = None
111
- rOffset = None
112
-
113
- outputs = self.model.generate(
114
- **self.tokenizer([text], return_tensors="pt").to(self.device),
115
- num_beams=1,
116
- num_return_sequences=1,
117
- max_new_tokens=30,
118
- return_dict_in_generate=True,
119
- output_scores=True,
120
- )
121
- wikipedia_prediction = self.tokenizer.batch_decode(
122
- outputs.sequences, skip_special_tokens=True
123
- )[0]
124
-
125
- transition_scores = self.model.compute_transition_scores(
126
- outputs.sequences, outputs.scores, normalize_logits=True
127
- )
128
- log_prob_sum = sum(transition_scores[0])
129
- sequence_confidence = torch.exp(log_prob_sum)
130
- percentage = sequence_confidence.cpu().numpy() * 100.0
131
-
132
- return wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage
133
 
134
  def postprocess(self, outputs):
135
- wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage = outputs
136
-
137
- qid, language = get_wikipedia_page_props(wikipedia_prediction)
138
- title, url = get_wikipedia_title(qid, language=language)
139
-
140
- results = [
141
- {
142
- "surface": enclosed_entity,
143
- "wkd_id": qid,
144
- "wkpedia_pagename": title,
145
- "wkpedia_url": url,
146
- "type": "UNK",
147
- "confidence_nel": round(percentage, 2),
148
- "lOffset": lOffset,
149
- "rOffset": rOffset,
150
- }
151
- ]
152
- return results
153
 
154
 
155
  class EndpointHandler:
 
96
  device=self.device)
97
 
98
  def preprocess(self, text: str):
99
+
100
+ linked_entity = nel_pipeline(text)
101
+
102
+ return linked_entity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  def postprocess(self, outputs):
105
+ linked_entity = outputs
106
+
107
+ return linked_entity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  class EndpointHandler: