KoichiYasuoka
committed on
Commit
•
b201666
1
Parent(s):
5f198ab
support transformers>=4.28
Browse files
ud.py
CHANGED
@@ -16,6 +16,8 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
16 |
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
17 |
def postprocess(self,model_outputs,**kwargs):
|
18 |
import numpy
|
|
|
|
|
19 |
e=model_outputs["logits"].numpy()
|
20 |
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
21 |
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|
@@ -73,11 +75,8 @@ class MecabPreTokenizer(MecabTokenizer):
|
|
73 |
e=0
|
74 |
for c in self.tokenize(t):
|
75 |
s=t.find(c,e)
|
76 |
-
if s<0
|
77 |
-
|
78 |
-
else:
|
79 |
-
e=s+len(c)
|
80 |
-
z.append((s,e))
|
81 |
return [normalized_string[s:e] for s,e in z if e>0]
|
82 |
def pre_tokenize(self,pretok):
|
83 |
pretok.split(self.mecab_split)
|
|
|
16 |
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
17 |
def postprocess(self,model_outputs,**kwargs):
|
18 |
import numpy
|
19 |
+
if "logits" not in model_outputs:
|
20 |
+
return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
|
21 |
e=model_outputs["logits"].numpy()
|
22 |
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
23 |
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|
|
|
75 |
e=0
|
76 |
for c in self.tokenize(t):
|
77 |
s=t.find(c,e)
|
78 |
+
e=e if s<0 else s+len(c)
|
79 |
+
z.append((0,0) if s<0 else (s,e))
|
|
|
|
|
|
|
80 |
return [normalized_string[s:e] for s,e in z if e>0]
|
81 |
def pre_tokenize(self,pretok):
    # Entry point called by the tokenizers custom-PreTokenizer protocol.
    # Delegates splitting to self.mecab_split, which is invoked once per
    # NormalizedString piece; pretok is mutated in place, so no return value.
    # NOTE(review): assumes pretok is a tokenizers.PreTokenizedString — confirm
    # against the PreTokenizer.custom registration site (outside this view).
    pretok.split(self.mecab_split)
|