File size: 5,837 Bytes
c490f2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135

import torch

class DecodeAndEvaluate:
    """Decode grid-tagging model output into (aspect, opinion, sentiment)
    triplets and score predictions against gold labels.

    Tagging scheme (as used by the visible code):
      * diagonal cells ``tags[i][i]``: 1 = aspect token, 2 = opinion token;
      * cells in the rectangle spanned by an aspect span and an opinion span
        carry a sentiment class id: 3 = negative, 4 = neutral, 5 = positive;
      * negative cell values mark un-annotated / padded positions.
    NOTE(review): only the upper triangle of the grid appears to be annotated
    (see ``find_triplet``) — confirm against the training-data builder.
    """

    def __init__(self, tokenizer):
        # Hugging-Face-style tokenizer: must provide .encode() and .decode().
        self.tokenizer = tokenizer
        self.sentiment2id = {'negative': 3, 'neutral': 4, 'positive': 5}
        self.id2sentiment = {v: k for k, v in self.sentiment2id.items()}

    def get_span_from_tags(self, tags, token_range, tok_type):
        """Collect maximal runs of words whose diagonal tag equals ``tok_type``.

        Args:
            tags: 2-D tag grid, indexable as ``tags[i][j]``.
            token_range: per-word ``[first, last]`` sub-word index pairs.
            tok_type: 1 to extract aspect spans, 2 for opinion spans.

        Returns:
            List of ``[start, end]`` sub-word index pairs (inclusive).
        """
        sel_spans = []
        start_ind, end_ind = -1, -1
        has_prev = False
        for l, r in token_range:
            if tags[l][l] != tok_type:
                # This word breaks the current run: flush any open span.
                if has_prev:
                    sel_spans.append([start_ind, end_ind])
                    start_ind, end_ind = -1, -1
                has_prev = False
            elif not has_prev:
                # First word of a new span.
                start_ind, end_ind = l, r
                has_prev = True
            else:
                # Continuation: extend the open span to this word's last sub-word.
                end_ind = r
        if has_prev:
            # Flush a span that runs to the end of the sentence.
            sel_spans.append([start_ind, end_ind])
        return sel_spans

    def _majority_sentiment(self, region):
        """Mode of the non-negative tag values in a rectangular grid region.

        Returns 0 ("no sentiment") when the region holds no valid tags — this
        also guards ``torch.mode`` against an empty tensor, which raises.
        """
        flat = region.reshape(-1)
        valid = flat[flat >= 0]
        if valid.numel() == 0:
            return 0
        return torch.mode(valid).values.item()

    # Corner cases: one opinion span may express multiple sentiments, and one
    # aspect may have several sentiments expressed on it — we keep the mode.
    def find_triplet(self, tags, aspect_spans, opinion_spans):
        """Pair every aspect span with every opinion span and keep pairs whose
        overlap region carries a majority sentiment tag.

        Returns:
            List of ``[asp_l, asp_r, op_l, op_r, sentiment_id]``.
        """
        triplets = []
        for al, ar in aspect_spans:
            for pl, pr in opinion_spans:
                # Only the upper triangle of the grid is annotated, so always
                # index with the lower span as rows and the higher as columns.
                if al <= pl:
                    region = tags[al:ar + 1, pl:pr + 1]
                else:
                    region = tags[pl:pr + 1, al:ar + 1]
                val = self._majority_sentiment(region)
                if val > 0:
                    triplets.append([al, ar, pl, pr, val])
        return triplets

    def decode_triplets(self, triplets, sent_tokens):
        """Convert index triplets to ``[aspect_text, opinion_text, polarity]``.

        Triplets whose sentiment id is not a known polarity (3/4/5) are
        silently dropped, matching the original filtering behavior.
        """
        triplet_list = []
        for alt, art, olt, ort, pol in triplets:
            if pol not in self.id2sentiment:
                continue  # not a sentiment id — skip before decoding any text
            asp_string = self.tokenizer.decode(sent_tokens[alt:art + 1])
            op_string = self.tokenizer.decode(sent_tokens[olt:ort + 1])
            triplet_list.append([asp_string, op_string, self.id2sentiment[pol]])
        return triplet_list

    def decode_predict_one(self, tags, token_range, sent_tokens):
        """Decode all predicted triplets for a single sentence."""
        aspect_spans = self.get_span_from_tags(tags, token_range, 1)
        opinion_spans = self.get_span_from_tags(tags, token_range, 2)
        triplets = self.find_triplet(tags, aspect_spans, opinion_spans)
        return self.decode_triplets(triplets, sent_tokens)

    def decode_pred_batch(self, tags_batch, token_range_batch, sent_tokens):
        """Decode every sentence in a batch; returns one triplet list each."""
        return [
            self.decode_predict_one(tags_batch[i], token_range_batch[i], sent_tokens[i])
            for i in range(tags_batch.shape[0])
        ]

    def decode_predict_string_one(self, text_sent, model, max_len=64):
        """Tokenize ``text_sent``, run ``model``, and decode its triplets.

        Args:
            text_sent: raw sentence; words are whitespace-separated.
            model: callable returning a dict with a ``'logits'`` tensor.
            max_len: padded sub-word sequence length.

        Raises:
            ValueError: if the sub-word sequence is longer than ``max_len``.
        """
        words = text_sent.strip().split()
        bert_tokens = self.tokenizer.encode(text_sent)  # sub-word ids incl. specials

        tok_length = len(bert_tokens)
        if tok_length > max_len:
            # The original message referenced a garbled constant "(>13,118)";
            # report the actual lengths instead.
            raise ValueError(
                f'Sub-word length {tok_length} exceeds max_len={max_len}'
            )

        # Map each word to its (first, last) sub-word positions; start at 1 to
        # skip the leading special token added by the tokenizer.
        token_range = []
        token_start = 1
        for w in words:
            n_sub = len(self.tokenizer.encode(w, add_special_tokens=False))
            token_range.append([token_start, token_start + n_sub - 1])
            token_start += n_sub

        bert_tokens_padding = torch.zeros(max_len).long()
        bert_tokens_padding[:tok_length] = torch.tensor(bert_tokens).long()
        attention_mask = torch.zeros(max_len).long()
        attention_mask[:tok_length] = 1

        tags_pred = model(bert_tokens_padding.unsqueeze(0),
                          attention_masks=attention_mask.unsqueeze(0))
        tags = tags_pred['logits'][0].argmax(dim=-1)
        return self.decode_predict_one(tags, token_range, bert_tokens)

    def get_batch_tp_fp_tn(self, tags_batch, token_range_batch, sent_tokens, gold_labels):
        """Exact-match triplet counts for one batch.

        Triplets are compared as de-duplicated ``"aspect-opinion-polarity"``
        strings (set semantics, so duplicates within a batch collapse).

        NOTE(review): despite the ``tn`` in the method name, the third value
        returned is the FALSE-NEGATIVE count; name kept for compatibility.

        Returns:
            Tuple ``(tp, fp, fn)``.
        """
        batch_results = self.decode_pred_batch(tags_batch, token_range_batch, sent_tokens)
        pred_set = {"-".join(pred) for preds in batch_results for pred in preds}
        gold_set = {"-".join(gold) for golds in gold_labels for gold in golds}
        tp = len(gold_set & pred_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        return tp, fp, fn