File size: 4,553 Bytes
9f86c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from transformers import Pipeline
from typing import Dict, Any, Union
from transformers.pipelines.base import GenericTensor
from transformers.modeling_outputs import ModelOutput
import torch

class NERPredictorPipe(Pipeline):
    """NER pipeline: tokenize -> model forward (CRF best path) -> entity spans.

    The model is expected to return ``(loss_like, (best_path, extra))`` where
    ``best_path`` is a batch of label-id sequences; ``model.config.id2tag``
    maps stringified ids to BIOES-style tags ("B-X"/"I-X"/"M-X"/"E-X"/"S-X"/"O").
    """

    def _sanitize_parameters(self, **kwargs):
        """Route supported call-time kwargs to the pipeline stages.

        BUGFIX: previously all kwargs were discarded, so ``max_length=`` passed
        at call time never reached ``preprocess``.
        """
        preprocess_kwargs = {}
        if "max_length" in kwargs:
            preprocess_kwargs["max_length"] = kwargs["max_length"]
        return preprocess_kwargs, {}, {}

    def __token_preprocess(self, text, tokenizer, max_length=512):
        """Tokenize ``text`` padded/truncated to ``max_length``, as PyTorch tensors."""
        return tokenizer(
            text,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

    def preprocess(self, sentence: Union[str, list], max_length: int = 512) -> Dict[str, GenericTensor]:
        """Tokenize one sentence (or a batch) and move all tensors to the pipeline device."""
        input_tensors = self.__token_preprocess(
            sentence,
            self.tokenizer,
            max_length=max_length,
        )
        # 1 on padding positions, 0 on real tokens. Equivalent to the original
        # (~(input_ids > 0)).long().
        # NOTE(review): assumes the pad token id is 0 (true for BERT-style
        # vocabularies) — confirm for this tokenizer.
        input_tensors["input_mask"] = (input_tensors["input_ids"] <= 0).long()
        for key, value in input_tensors.items():
            if value is not None:
                input_tensors[key] = value.to(self.device)
        return input_tensors

    def _forward(self, input_tensors: Dict[str, GenericTensor]) -> ModelOutput:
        """Run the model in eval mode; return (input_ids as lists, decoded best path)."""
        self.model.eval()
        with torch.no_grad():
            _, (best_path, _) = self.model(**input_tensors)
        return (input_tensors["input_ids"].tolist(), best_path)

    def __format_output(self, start, end, text, label):
        """Build the public entity dict for one span."""
        return {
            "text": text,
            "start": start,
            "end": end,
            "label": label,
        }

    def __make_entity(self, input_ids, start, end, label):
        """Build an entity dict for label positions [start, end].

        Token ids are read from ``input_ids[start + 1 : end + 2]`` — the +1
        shift skips the leading special token ([CLS]) that the label sequence
        does not cover. NOTE(review): assumes exactly one special token is
        prepended — confirm against the tokenizer configuration.
        """
        text = ''.join(self.tokenizer.convert_ids_to_tokens(input_ids[start + 1:end + 2]))
        return self.__format_output(start, end, text, label)

    def postprocess(self, model_outputs: ModelOutput) -> Any:
        """Decode BIOES tag sequences into entity span dicts, one list per sample."""
        batch_slices = []
        input_ids_list, label_ids_list = model_outputs
        for input_ids, label_ids in zip(input_ids_list, label_ids_list):
            slices = []
            labels = [self.model.config.id2tag[str(id)] for id in label_ids]
            past = "O"      # entity type currently being accumulated
            start = end = -1  # open span bounds; -1 means no open span
            for i, label in enumerate(labels):
                if label.startswith("B-"):
                    # New entity begins: flush any open span first.
                    if start != -1 and end != -1:
                        slices.append(self.__make_entity(input_ids, start, end, past))
                    start = end = i
                    past = label[2:]
                elif label.startswith(("I-", "M-", "E-")):
                    cur = label[2:]
                    if cur != past:
                        # Type mismatch: cut the open span and start a new one here.
                        if start != -1 and end != -1:
                            slices.append(self.__make_entity(input_ids, start, end, past))
                        start = i
                        past = cur
                    end = i
                elif label.startswith("S-"):
                    # Single-token entity: flush any open span, then emit it alone.
                    if start != -1 and end != -1:
                        slices.append(self.__make_entity(input_ids, start, end, past))
                    # BUGFIX: label the single-token entity with its OWN type
                    # (from the S- tag); previously `past` was used, which is the
                    # preceding entity's type or "O".
                    slices.append(self.__make_entity(input_ids, i, i, label[2:]))
                    start = end = -1
                    past = "O"
            # Flush a span still open at the end of the sequence.
            if start != -1 and end != -1:
                slices.append(self.__make_entity(input_ids, start, end, past))
            batch_slices.append(slices)
        return batch_slices