Commit
·
a63bb49
1
Parent(s):
f321a01
algorithm improved
Browse files
ud.py
CHANGED
@@ -7,10 +7,13 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
|
|
7 |
x=self.model.config.label2id
|
8 |
y=[k for k in x if k.find("|")<0 and not k.startswith("I-")]
|
9 |
self.transition=numpy.full((len(x),len(x)),-numpy.inf)
|
|
|
10 |
for k,v in x.items():
|
11 |
if k.find("|")<0:
|
12 |
for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
|
13 |
self.transition[v,x[j]]=0
|
|
|
|
|
14 |
def check_model_type(self,supported_models):
|
15 |
pass
|
16 |
def postprocess(self,model_outputs,**kwargs):
|
@@ -19,6 +22,10 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
|
|
19 |
return self.bellman_ford_token_classification(model_outputs,**kwargs)
|
20 |
def bellman_ford_token_classification(self,model_outputs,**kwargs):
|
21 |
m=model_outputs["logits"][0].numpy()
|
|
|
|
|
|
|
|
|
22 |
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
23 |
z=e/e.sum(axis=-1,keepdims=True)
|
24 |
for i in range(m.shape[0]-1,0,-1):
|
@@ -26,13 +33,16 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
|
|
26 |
k=[numpy.argmax(m[0]+self.transition[0])]
|
27 |
for i in range(1,m.shape[0]):
|
28 |
k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
|
29 |
-
w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(
|
30 |
if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
|
31 |
for i,t in reversed(list(enumerate(w))):
|
32 |
p=t.pop("entity")
|
33 |
if p.startswith("I-"):
|
34 |
w[i-1]["score"]=min(w[i-1]["score"],t["score"])
|
35 |
w[i-1]["end"]=w.pop(i)["end"]
|
|
|
|
|
|
|
36 |
elif p.startswith("B-"):
|
37 |
t["entity_group"]=p[2:]
|
38 |
else:
|
|
|
7 |
x=self.model.config.label2id
|
8 |
y=[k for k in x if k.find("|")<0 and not k.startswith("I-")]
|
9 |
self.transition=numpy.full((len(x),len(x)),-numpy.inf)
|
10 |
+
self.ilabel=numpy.full(len(x),-numpy.inf)
|
11 |
for k,v in x.items():
|
12 |
if k.find("|")<0:
|
13 |
for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
|
14 |
self.transition[v,x[j]]=0
|
15 |
+
if k.startswith("I-"):
|
16 |
+
self.ilabel[v]=0
|
17 |
def check_model_type(self,supported_models):
|
18 |
pass
|
19 |
def postprocess(self,model_outputs,**kwargs):
|
|
|
22 |
return self.bellman_ford_token_classification(model_outputs,**kwargs)
|
23 |
def bellman_ford_token_classification(self,model_outputs,**kwargs):
|
24 |
m=model_outputs["logits"][0].numpy()
|
25 |
+
x=model_outputs["offset_mapping"][0].tolist()
|
26 |
+
for i,(s,e) in enumerate(x):
|
27 |
+
if i>0 and s<e and x[i-1][1]>s:
|
28 |
+
m[i]+=self.ilabel
|
29 |
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
30 |
z=e/e.sum(axis=-1,keepdims=True)
|
31 |
for i in range(m.shape[0]-1,0,-1):
|
|
|
33 |
k=[numpy.argmax(m[0]+self.transition[0])]
|
34 |
for i in range(1,m.shape[0]):
|
35 |
k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
|
36 |
+
w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(x,k)) if s<e]
|
37 |
if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
|
38 |
for i,t in reversed(list(enumerate(w))):
|
39 |
p=t.pop("entity")
|
40 |
if p.startswith("I-"):
|
41 |
w[i-1]["score"]=min(w[i-1]["score"],t["score"])
|
42 |
w[i-1]["end"]=w.pop(i)["end"]
|
43 |
+
elif i>0 and w[i-1]["end"]>t["start"]:
|
44 |
+
w[i-1]["score"]=min(w[i-1]["score"],t["score"])
|
45 |
+
w[i-1]["end"]=w.pop(i)["end"]
|
46 |
elif p.startswith("B-"):
|
47 |
t["entity_group"]=p[2:]
|
48 |
else:
|