KoichiYasuoka commited on
Commit
a63bb49
·
1 Parent(s): f321a01

algorithm improved

Browse files
Files changed (1) hide show
  1. ud.py +11 -1
ud.py CHANGED
@@ -7,10 +7,13 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
7
  x=self.model.config.label2id
8
  y=[k for k in x if k.find("|")<0 and not k.startswith("I-")]
9
  self.transition=numpy.full((len(x),len(x)),-numpy.inf)
 
10
  for k,v in x.items():
11
  if k.find("|")<0:
12
  for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
13
  self.transition[v,x[j]]=0
 
 
14
  def check_model_type(self,supported_models):
15
  pass
16
  def postprocess(self,model_outputs,**kwargs):
@@ -19,6 +22,10 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
19
  return self.bellman_ford_token_classification(model_outputs,**kwargs)
20
  def bellman_ford_token_classification(self,model_outputs,**kwargs):
21
  m=model_outputs["logits"][0].numpy()
 
 
 
 
22
  e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
23
  z=e/e.sum(axis=-1,keepdims=True)
24
  for i in range(m.shape[0]-1,0,-1):
@@ -26,13 +33,16 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
26
  k=[numpy.argmax(m[0]+self.transition[0])]
27
  for i in range(1,m.shape[0]):
28
  k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
29
- w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<e]
30
  if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
31
  for i,t in reversed(list(enumerate(w))):
32
  p=t.pop("entity")
33
  if p.startswith("I-"):
34
  w[i-1]["score"]=min(w[i-1]["score"],t["score"])
35
  w[i-1]["end"]=w.pop(i)["end"]
 
 
 
36
  elif p.startswith("B-"):
37
  t["entity_group"]=p[2:]
38
  else:
 
7
  x=self.model.config.label2id
8
  y=[k for k in x if k.find("|")<0 and not k.startswith("I-")]
9
  self.transition=numpy.full((len(x),len(x)),-numpy.inf)
10
+ self.ilabel=numpy.full(len(x),-numpy.inf)
11
  for k,v in x.items():
12
  if k.find("|")<0:
13
  for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
14
  self.transition[v,x[j]]=0
15
+ if k.startswith("I-"):
16
+ self.ilabel[v]=0
17
  def check_model_type(self,supported_models):
18
  pass
19
  def postprocess(self,model_outputs,**kwargs):
 
22
  return self.bellman_ford_token_classification(model_outputs,**kwargs)
23
  def bellman_ford_token_classification(self,model_outputs,**kwargs):
24
  m=model_outputs["logits"][0].numpy()
25
+ x=model_outputs["offset_mapping"][0].tolist()
26
+ for i,(s,e) in enumerate(x):
27
+ if i>0 and s<e and x[i-1][1]>s:
28
+ m[i]+=self.ilabel
29
  e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
30
  z=e/e.sum(axis=-1,keepdims=True)
31
  for i in range(m.shape[0]-1,0,-1):
 
33
  k=[numpy.argmax(m[0]+self.transition[0])]
34
  for i in range(1,m.shape[0]):
35
  k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
36
+ w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(x,k)) if s<e]
37
  if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
38
  for i,t in reversed(list(enumerate(w))):
39
  p=t.pop("entity")
40
  if p.startswith("I-"):
41
  w[i-1]["score"]=min(w[i-1]["score"],t["score"])
42
  w[i-1]["end"]=w.pop(i)["end"]
43
+ elif i>0 and w[i-1]["end"]>t["start"]:
44
+ w[i-1]["score"]=min(w[i-1]["score"],t["score"])
45
+ w[i-1]["end"]=w.pop(i)["end"]
46
  elif p.startswith("B-"):
47
  t["entity_group"]=p[2:]
48
  else: