File size: 3,767 Bytes
6343a22 85f56d8 35876e5 85f56d8 35876e5 85f56d8 0a0f890 35876e5 6343a22 35876e5 7f016dc 35876e5 85f56d8 6343a22 853e97f 6343a22 3556302 6343a22 3556302 35876e5 3556302 ae8c845 3556302 b2e4b11 3556302 b2e4b11 6343a22 35876e5 6343a22 85f56d8 3556302 85f56d8 3556302 6343a22 3556302 6343a22 3556302 6343a22 3556302 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import numpy
from transformers import TokenClassificationPipeline
class UniversalDependenciesPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that emits a dependency parse as CoNLL-U text.

    Every label predicted by the model is split on "|" in postprocess():
    - the first field is written as UPOS;
    - the last field is written as DEPREL;
    - any fields in between are re-joined with "|" and written as FEATS.
    An example is the label "X|_|goeswith" looked up below.

    Heads are chosen by a Chu-Liu/Edmonds-style maximum spanning tree
    search (chu_liu_edmonds) over a head-by-dependent score matrix.
    That matrix is built in _forward().
    """
    def _forward(self,model_inputs):
        """Score every (head, dependent) token pair with one batched model call.

        For each position i strictly between the first and last input ids
        (presumably the special start/end tokens -- TODO confirm for the
        tokenizer in use), one copy of the sequence is built:
        - position i is replaced by the mask token;
        - the original token at i is appended at the end.

        All n = len(input_ids) - 2 copies run as a single batch of n
        sequences of n + 3 ids, so cost grows quadratically with sentence
        length.

        Returns the incoming model_inputs plus "logits" of shape
        (n, n, number_of_labels). The first position and the last two
        positions (the original last id and the appended token) are sliced
        off. postprocess() treats axis 0 as the candidate head (the masked
        token) and axis 1 as the dependent.
        """
        import torch
        # Only the first sequence in "input_ids" is used.
        v=model_inputs["input_ids"][0].tolist()
        with torch.no_grad():
            e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
        # Pass the other inputs through; postprocess() reads "offset_mapping"
        # and "sentence" from them (assumed to be supplied by the parent
        # preprocess() -- TODO confirm).
        return {"logits":e.logits[:,1:-2,:],**model_inputs}
    def check_model_type(self,supported_models):
        """Override the parent hook with a no-op; supported_models is ignored."""
        pass
    def postprocess(self,model_outputs,**kwargs):
        """Decode the score tensor from _forward() into one CoNLL-U sentence block.

        If model_outputs has no "logits" key, it is treated as an iterable
        of outputs. Each item is decoded recursively and the results are
        concatenated.

        Steps:
          1. Mask labels that are not allowed: root vs. non-root placement,
             label id 0, and goeswith outside a contiguous run.
          2. Take the best label per (head, dependent) cell and find a tree
             over those scores with chu_liu_edmonds(). If it yields several
             roots, retry with penalties.
          3. If an aggregation_strategy other than "none" is passed, merge
             goeswith runs and overlapping tokens.
          4. Trim whitespace from token spans and drop empty tokens.
          5. Format tab-separated CoNLL-U lines.

        Returns the formatted text, terminated by an empty line.
        """
        if "logits" not in model_outputs:
            return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
        e=model_outputs["logits"].numpy()
        # Per-label mask vector indexed by label id. This assumes id2label
        # keys are 0..L-1 (TODO confirm).
        #   label id 0             -> 1
        #   labels ending "|root"  -> -1
        #   all other labels       -> 0
        r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
        # identity + r is 0 exactly where a label is allowed:
        # - root labels only on the diagonal (a token heading itself);
        # - non-root labels only off the diagonal;
        # - label id 0 never.
        # Every other cell gets -inf.
        e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,-numpy.inf)
        g=self.model.config.label2id["X|_|goeswith"]
        # r[i,j] == 0 marks where goeswith is allowed for head i and dependent j:
        # - never on or left of the diagonal (numpy.tri sets those cells to 1);
        # - always at j == i+1;
        # - at j >= i+2 only while every cell (i, i+1..j-1) argmaxes to
        #   goeswith in the already root-masked scores.
        r=numpy.tri(e.shape[0])
        for i in range(e.shape[0]):
            for j in range(i+2,e.shape[1]):
                r[i,j]=r[i,j-1] if numpy.argmax(e[i,j-1])==g else 1
        e[:,:,g]+=numpy.where(r==0,0,-numpy.inf)
        # m: best score per (head, dependent) cell; p: the label achieving it.
        m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
        # h[d] is the head index of token d; h[d] == d marks a root.
        h=self.chu_liu_edmonds(m)
        z=[i for i,j in enumerate(h) if i==j]
        if len(z)>1:
            # Several roots. Keep k, the candidate with the best self-score.
            # In the candidates' columns, add the penalty min(m)-max(m) to:
            # - every self-loop except k's;
            # - every head outside z.
            # The penalty is stored in h temporarily before the tree is
            # recomputed. A single root is not re-verified afterwards.
            k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
            m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
            h=self.chu_liu_edmonds(m)
        # Character spans of the tokens, skipping empty spans. This assumes
        # the spans line up one-to-one with the n rows/columns of m, i.e.
        # only special tokens have empty spans -- TODO confirm.
        v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
        # Label of each token under its chosen head, split into fields.
        q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
        if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
            # Right to left, merge token i into token i-1 when either:
            # - every token from its head+1 through i has DEPREL goeswith; or
            # - its span starts before the previous span ends.
            # Re-indexing: heads after i shift down by one, heads equal to i
            # are redirected to i-1, and the merged token keeps i-1's label.
            # NOTE(review): if token i-1's own head was i, it now heads
            # itself, i.e. it becomes an extra root.
            for i,j in reversed(list(enumerate(q[1:],1))):
                if j[-1]=="goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"goeswith"}:
                    h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
                    v[i-1]=(v[i-1][0],v.pop(i)[1])
                    q.pop(i)
                elif v[i-1][1]>v[i][0]:
                    h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
                    v[i-1]=(v[i-1][0],v.pop(i)[1])
                    q.pop(i)
        # Newlines become spaces, presumably so the "# text =" line and
        # token forms stay on one line.
        t=model_outputs["sentence"].replace("\n"," ")
        # Right to left, trim surrounding whitespace off each token span.
        # Trimming is triggered by a leading/trailing space, after which
        # lstrip/rstrip remove all leading/trailing whitespace. Tokens that
        # end up empty are dropped, with heads re-indexed by the same rule
        # as the merge step.
        # NOTE(review): if token 0 is dropped, heads that pointed at it
        # become -1 and are written as HEAD 0 (root) below.
        for i,(s,e) in reversed(list(enumerate(v))):
            w=t[s:e]
            if w.startswith(" "):
                j=len(w)-len(w.lstrip())
                w=w.lstrip()
                v[i]=(v[i][0]+j,v[i][1])
            if w.endswith(" "):
                j=len(w)-len(w.rstrip())
                w=w.rstrip()
                v[i]=(v[i][0],v[i][1]-j)
            if w.strip()=="":
                h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
                v.pop(i)
                q.pop(i)
        u="# text = "+t+"\n"
        # One tab-separated line per token, with these fields:
        #   ID (1-based), FORM, LEMMA "_", UPOS, XPOS "_", FEATS,
        #   HEAD (0 for a self-headed root, else 1-based head),
        #   DEPREL, DEPS "_", MISC.
        # MISC is "_" only when the next token starts after this one ends;
        # otherwise (including for the last token) it is "SpaceAfter=No".
        for i,(s,e) in enumerate(v):
            u+="\t".join([str(i+1),t[s:e],"_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
        return u+"\n"
    def chu_liu_edmonds(self,matrix):
        """Pick a head for every node so the total score is maximal and acyclic.

        matrix is a square score matrix: candidate heads on axis 0,
        dependents on axis 1. A diagonal entry scores the dependent as a root.

        Returns heads, where heads[d] is the head index of d and
        heads[d] == d marks a root:
        - the numpy argmax array when the greedy choice has no cycle;
        - otherwise a Python list.

        The number of roots is not constrained here; postprocess() deals
        with multiple roots. An empty matrix is not handled.
        """
        # Greedy choice: best-scoring head for every dependent.
        h=numpy.argmax(matrix,axis=0)
        # Head pointers, with self-heads (roots) replaced by -1.
        x=[-1 if i==j else j for i,j in enumerate(h)]
        # Two passes, each repeated until x stops changing. Both update x
        # in place while iterating over it.
        #  1. Set to -1 every node that no entry of x points at. Repeating
        #     this peels off the nodes that are not on a cycle.
        #  2. Replace each pointer by its target's pointer (-1 stays -1).
        #     This is intended to collapse each cycle onto one shared value.
        for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
            y=[]
            while x!=y:
                y=list(x)
                for i,j in enumerate(x):
                    x[i]=b(x,i,j)
        # Only -1 left: the greedy heads contain no cycle.
        if max(x)<0:
            return h
        # Contract one cycle:
        # - y = nodes whose collapsed value equals max(x) (the cycle);
        # - x = all nodes with a smaller value (the rest).
        # Both comprehensions read the old x before either name is rebound.
        y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
        # Scores relative to each dependent's best head (column max -> 0).
        z=matrix-numpy.max(matrix,axis=0)
        # Contracted matrix: the x nodes plus one extra last node standing
        # for the cycle.
        # - Edges into the cycle node: best relative score into any member.
        # - Edges out of it: best score from any member.
        # - Its diagonal entry: the best root (diagonal) score of any member.
        m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
        # Solve the contracted problem recursively and map back:
        # - entries for x nodes become original head indices;
        # - a head equal to the cycle node is resolved to the member with
        #   the best score for that dependent;
        # - the cycle node's own entry, k[-1], stays in contracted indices.
        k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
        # Cycle members keep their greedy heads; other nodes take the
        # recursive solution.
        h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
        # Break the cycle at one member:
        # - if the contracted cycle node got a head, the member with the
        #   best relative score from that head, which is re-attached to it;
        # - if it was a root, the member with the best root score, which
        #   becomes a root.
        i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
        h[i]=x[k[-1]] if k[-1]<len(x) else i
        return h
|