# Decoding and evaluation utilities for grid-tagging ABSA triplet extraction.
import torch
class DecodeAndEvaluate:
    """Decode grid-tagging model outputs into (aspect, opinion, sentiment)
    triplets and compute batch-level evaluation counts.

    Tagging scheme (GTS-style grid): for a token whose first sub-word index
    is ``l``, the diagonal cell ``tags[l][l]`` marks the token type
    (1 = aspect, 2 = opinion).  The rectangle of cells spanned by an
    (aspect, opinion) span pair carries the sentiment id
    (3 = negative, 4 = neutral, 5 = positive).  Cells ``< 0`` are padding
    and are ignored.  Only the upper triangle of the grid is annotated.
    """

    def __init__(self, tokenizer):
        # tokenizer: HuggingFace-style tokenizer providing encode()/decode().
        self.tokenizer = tokenizer
        self.sentiment2id = {'negative': 3, 'neutral': 4, 'positive': 5}
        self.id2sentiment = {v: k for k, v in self.sentiment2id.items()}

    def get_span_from_tags(self, tags, token_range, tok_type):
        """Collect maximal runs of consecutive tokens tagged ``tok_type``.

        Args:
            tags: 2-D tag grid, indexable as ``tags[i][j]``.
            token_range: per-word ``[first_subword, last_subword]`` pairs.
            tok_type: 1 for aspect spans, 2 for opinion spans.

        Returns:
            List of ``[start, end]`` sub-word spans (inclusive).
        """
        sel_spans = []
        start_ind, end_ind = -1, -1
        has_prev = False
        for l, r in token_range:
            if tags[l][l] == tok_type:
                if not has_prev:
                    start_ind = l          # open a new run
                    has_prev = True
                end_ind = r                # extend run to this word's last sub-word
            elif has_prev:
                # Run just ended: flush the accumulated span.
                sel_spans.append([start_ind, end_ind])
                has_prev = False
        if has_prev:
            # Flush a run that extends to the final token.
            sel_spans.append([start_ind, end_ind])
        return sel_spans

    ## Corner cases where one sentiment span expresses over multiple sentiments
    # and one aspect has multiple sentiments expressed on it
    def find_triplet(self, tags, aspect_spans, opinion_spans):
        """Pair every aspect span with every opinion span and read the
        sentiment as the mode of the grid cells between them.

        Only the upper triangle of the grid is annotated, so the rectangle
        is always indexed with the leftmost span as the row range.

        Args:
            tags: 2-D integer tensor of predicted tags.
            aspect_spans: list of ``[l, r]`` aspect sub-word spans.
            opinion_spans: list of ``[l, r]`` opinion sub-word spans.

        Returns:
            List of ``[asp_l, asp_r, op_l, op_r, sentiment_id]``.
        """
        triplets = []
        for al, ar in aspect_spans:
            for pl, pr in opinion_spans:
                # Select rows from whichever span comes first so we stay in
                # the annotated upper triangle (otherwise the unannotated
                # lower triangle would be read).
                if al <= pl:
                    region = tags[al:ar + 1, pl:pr + 1]
                else:
                    region = tags[pl:pr + 1, al:ar + 1]
                flat = region.reshape(-1)
                flat = flat[flat >= 0]  # drop padding / ignored cells
                if flat.numel() == 0:
                    # Nothing annotated for this pair; the original code
                    # crashed here (torch.mode on an empty tensor).
                    continue
                val = torch.mode(flat).values.item()
                if val > 0:
                    triplets.append([al, ar, pl, pr, val])
        return triplets

    def decode_triplets(self, triplets, sent_tokens):
        """Convert index triplets into ``[aspect, opinion, polarity]`` strings.

        Triplets whose sentiment id is not a valid polarity (3/4/5) are
        silently dropped.
        """
        triplet_list = []
        for alt, art, olt, ort, pol in triplets:
            asp_string = self.tokenizer.decode(sent_tokens[alt:art + 1])
            op_string = self.tokenizer.decode(sent_tokens[olt:ort + 1])
            sentiment_pol = self.id2sentiment.get(pol)
            if sentiment_pol is not None:
                triplet_list.append([asp_string, op_string, sentiment_pol])
        return triplet_list

    def decode_predict_one(self, tags, token_range, sent_tokens):
        """Decode one sentence's tag grid into string triplets."""
        aspect_spans = self.get_span_from_tags(tags, token_range, 1)
        opinion_spans = self.get_span_from_tags(tags, token_range, 2)
        triplets = self.find_triplet(tags, aspect_spans, opinion_spans)
        return self.decode_triplets(triplets, sent_tokens)

    def decode_pred_batch(self, tags_batch, token_range_batch, sent_tokens):
        """Decode a whole batch; returns one triplet list per sentence."""
        return [
            self.decode_predict_one(tags_batch[i], token_range_batch[i], sent_tokens[i])
            for i in range(tags_batch.shape[0])
        ]

    def decode_predict_string_one(self, text_sent, model, max_len=64):
        """Run ``model`` on a raw sentence string and decode its triplets.

        Args:
            text_sent: input sentence; words are whitespace-separated.
            model: callable returning a dict with a ``'logits'`` entry of
                shape ``(batch, seq, seq, n_tags)`` — TODO confirm against
                the model implementation.
            max_len: maximum number of sub-word tokens (incl. specials).

        Raises:
            ValueError: if the tokenized sentence exceeds ``max_len``.
        """
        words = text_sent.strip().split()
        bert_tokens = self.tokenizer.encode(text_sent)  # sub-words incl. specials
        tok_length = len(bert_tokens)
        if tok_length > max_len:
            # Original message interpolated nothing and cited a bogus limit.
            raise ValueError(
                f'Sub-word length {tok_length} exceeds max_len ({max_len})')
        # Map each word to its (first, last) sub-word position; index 0 is
        # the leading special token ([CLS]), so word sub-words start at 1.
        token_range = []
        token_start = 1
        for w in words:
            n_sub = len(self.tokenizer.encode(w, add_special_tokens=False))
            token_range.append([token_start, token_start + n_sub - 1])
            token_start += n_sub
        bert_tokens_padding = torch.zeros(max_len).long()
        bert_tokens_padding[:tok_length] = torch.tensor(bert_tokens).long()
        attention_mask = torch.zeros(max_len).long()
        attention_mask[:tok_length] = 1
        tags_pred = model(bert_tokens_padding.unsqueeze(0),
                          attention_masks=attention_mask.unsqueeze(0))
        tags = tags_pred['logits'][0].argmax(dim=-1)
        return self.decode_predict_one(tags, token_range, bert_tokens)

    def get_batch_tp_fp_tn(self, tags_batch, token_range_batch, sent_tokens, gold_labels):
        """Return (tp, fp, fn) triplet counts for a batch.

        NOTE: despite the historical method name, the third value is the
        false-negative count, not true negatives.  Duplicate triplets are
        deduplicated via sets before counting.
        """
        batch_results = self.decode_pred_batch(tags_batch, token_range_batch, sent_tokens)
        pred_set = {"-".join(p) for preds in batch_results for p in preds}
        gold_set = {"-".join(g) for golds in gold_labels for g in golds}
        tp = len(gold_set & pred_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        return tp, fp, fn