"""Inference helpers for the gazetteer-extended Czech NER model.

Loads the NerRoB-czech token-classification model together with its gazetteers,
runs gazetteer matching on the input text, and aggregates the token-level
predictions into labelled entity spans with averaged confidence scores.
"""
import json
import copy

import torch
from simplemma import lemmatize
from transformers import AutoTokenizer

from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens

# code originally from data_manipulation.creation_gazetteers
def lemmatizing(x):
    if x == "":
        return ""
    return lemmatize(x, lang="cs")

# code originally from data_manipulation.creation_gazetteers
def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary
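
# Illustrative example (the name below is a placeholder, not project data):
#   build_reverse_dictionary({"PER": ["Jan Novák"]}) -> {"Jan Novák": "PER"}
# With apply_lemmatizing=True, the lemmatized surface form is added as an
# extra key whenever it differs from the original one.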

def load_json(path):
    """
    Load gazetteers from a file
    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, 'r') as file:
        data = json.load(file)
    return data


def load():
    model_name = "ufal/robeczech-base"
    model_path = "bettystr/NerRoB-czech"
    gazetteers_path = "gazz2.json"

    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    gazetteers_for_matching = load_json(gazetteers_path)
    # Build one reverse dictionary (surface form -> entity type) per entity type.
    gazetteers_for_matching = [
        build_reverse_dictionary({key: values})
        for key, values in gazetteers_for_matching.items()
    ]
    return tokenizer, model, gazetteers_for_matching


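# The extra gazetteer files read below are expected to map entity types to
# surface forms, e.g. {"per": ["Jan Novák"], "loc": ["Brno"]} (illustrative
# placeholder entries); keys are upper-cased and must match an entity type
# already present in the loaded gazetteers.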
def add_additional_gazetteers(gazetteers_for_matching, file_names):
    """Return a copy of the reverse gazetteers extended with entries from ``file_names``."""
    if not file_names:
        return gazetteers_for_matching
    temp = copy.deepcopy(gazetteers_for_matching)
    for file_name in file_names:
        with open(file_name, 'r') as file:
            data = json.load(file)
        for key, value_lst in data.items():
            key = key.upper()
            # Add the new surface forms to the reverse dictionary that already
            # contains this entity type.
            for dictionary in temp:
                if key in dictionary.values():
                    for value in value_lst:
                        dictionary[value] = key
    return temp


def run(tokenizer, model, gazetteers, text, file_names=None):
    """Run NER inference on ``text`` and return the detected entity spans."""
    gazetteers_for_matching = add_additional_gazetteers(gazetteers, file_names)

    tokenized_inputs = tokenizer(
        text, truncation=True, is_split_into_words=False, return_offsets_mapping=True
    )
    # Match the text against the gazetteers and align the matches with the
    # subword tokens, yielding one (PER, ORG, LOC) feature triple per token.
    matches = gazetteer_matching(text, gazetteers_for_matching)
    word_ids = tokenized_inputs.word_ids()
    aligned = align_gazetteers_with_tokens(matches, word_ids)
    p = [[x[0] for x in aligned]]
    o = [[x[1] for x in aligned]]
    l = [[x[2] for x in aligned]]

    input_ids = torch.tensor(tokenized_inputs["input_ids"], device="cpu").unsqueeze(0)
    attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device="cpu").unsqueeze(0)
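    # Gazetteer-match features for persons, organisations, and locations are
    # passed to the extended-embedding model alongside the usual inputs.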
    per = torch.tensor(p, device="cpu")
    org = torch.tensor(o, device="cpu")
    loc = torch.tensor(l, device="cpu")
    output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
    predictions = torch.argmax(output, dim=2).tolist()
    predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
    
    # Aggregate the token-level predictions into entity spans: consecutive
    # B-/I- tokens of the same type are merged and their best softmax scores
    # are averaged over the span.
    softmax = torch.nn.Softmax(dim=2)
    scores = softmax(output).squeeze(0).tolist()
    result = []
    temp = {
        "start": 0,
        "end": 0,
        "entity": "O",
        "score": 0,
        "word": "",
        "count": 0,
    }
    for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores):
        # Skip special tokens (empty offsets) and tokens outside any entity.
        if pos[0] == pos[1] or entity == "O":
            continue
        if "I-" + temp["entity"] == entity:  # continuation of the current entity
            temp["word"] += text[temp["end"]:pos[1]]
            temp["end"] = pos[1]
            temp["count"] += 1
            temp["score"] += max(score)
        else:  # a new entity starts
            if temp["count"] > 0:  # flush the previous entity
                temp["score"] /= temp.pop("count")
                result.append(temp)
            temp = {
                "start": pos[0],
                "end": pos[1],
                "entity": entity[2:],
                "score": max(score),
                "word": text[pos[0]:pos[1]],
                "count": 1,
            }
    if temp["count"] > 0:  # flush the last open entity
        temp["score"] /= temp.pop("count")
        result.append(temp)
    return result
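

# Minimal usage sketch (illustrative; the sample sentence below is a
# placeholder and not part of the original module):
if __name__ == "__main__":
    tokenizer, model, gazetteers = load()
    sample_text = "Václav Havel se narodil v Praze."
    for entity in run(tokenizer, model, gazetteers, sample_text):
        print(entity["entity"], entity["word"], round(entity["score"], 3))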