import torch
from transformers import AutoTokenizer

from extended_embeddings.token_classification import ExtendedEmbeddigsRobertaForTokenClassification
from data_manipulation.dataset_funcions import load_gazetteers, gazetteer_matching, align_gazetteers_with_tokens
from data_manipulation.preprocess_gazetteers import build_reverse_dictionary


def load():
    """Load the fine-tuned NER model, its tokenizer and the gazetteers used for matching."""
    model_name = "ufal/robeczech-base"
    model_path = "bettystr/NerRoB-czech"
    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()
    gazetteers_path = "gazz2.json"
    gazetteers_for_matching = load_gazetteers(gazetteers_path)
    # Build one reverse-lookup dictionary per entity type so matching can be done type by type.
    temp = []
    for key in gazetteers_for_matching.keys():
        temp.append(build_reverse_dictionary({key: gazetteers_for_matching[key]}))
    gazetteers_for_matching = temp
    return tokenizer, model, gazetteers_for_matching


def run(tokenizer, model, gazetteers_for_matching, text):
    """Run NER on a single text and return the detected entities with averaged confidence scores."""
    tokenized_inputs = tokenizer(
        text, truncation=True, is_split_into_words=False, return_offsets_mapping=True
    )
    # Match the raw text against the gazetteers and align the matches with the sub-word tokens.
    matches = gazetteer_matching(text, gazetteers_for_matching)
    word_ids = tokenized_inputs.word_ids()
    new_g = [align_gazetteers_with_tokens(matches, word_ids)]
    # Split the aligned matches into per-token PER, ORG and LOC indicator lists.
    p, o, l = [], [], []
    for sentence in new_g:
        p.append([x[0] for x in sentence])
        o.append([x[1] for x in sentence])
        l.append([x[2] for x in sentence])

    input_ids = torch.tensor(tokenized_inputs["input_ids"], device="cpu").unsqueeze(0)
    attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device="cpu").unsqueeze(0)
    per = torch.tensor(p, device="cpu")
    org = torch.tensor(o, device="cpu")
    loc = torch.tensor(l, device="cpu")
    # Forward pass with the extra gazetteer-indicator inputs; take the most probable tag per token.
    output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
    predictions = torch.argmax(output, dim=2).tolist()
    predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
    
    # Convert logits to per-token probabilities; the maximum probability is the token's confidence.
    softmax = torch.nn.Softmax(dim=2)
    scores = softmax(output).squeeze(0).tolist()

    # Merge sub-word tokens into entity spans (B-XXX followed by I-XXX) and average their scores.
    result = []
    temp = {
        "start": 0,
        "end": 0,
        "entity": "O",
        "score": 0,
        "word": "",
        "count": 0,
    }
    for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores):
        if pos[0] == pos[1] or entity == "O":  # skip special tokens and non-entity tokens
            continue
        if "I-" + temp["entity"] == entity:  # continuation of the current entity
            temp["word"] += text[temp["end"]:pos[0]] + text[pos[0]:pos[1]]
            temp["end"] = pos[1]
            temp["count"] += 1
            temp["score"] += max(score)
        else:  # start of a new entity
            if temp["count"] > 0:
                temp["score"] /= temp.pop("count")
                result.append(temp)
            temp = {
                "start": pos[0],
                "end": pos[1],
                "entity": entity[2:],
                "score": max(score),
                "word": text[pos[0]:pos[1]],
                "count": 1,
            }
    if temp["count"] > 0:  # flush the last open entity
        temp["score"] /= temp.pop("count")
        result.append(temp)
    return result
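

if __name__ == "__main__":
    # Minimal usage sketch (illustrative addition, not part of the original module):
    # the example sentence is an arbitrary placeholder; any Czech text can be passed to run().
    tokenizer, model, gazetteers_for_matching = load()
    entities = run(tokenizer, model, gazetteers_for_matching, "Karel Čapek se narodil v Malých Svatoňovicích.")
    for entity in entities:
        print(entity["entity"], entity["word"], round(entity["score"], 3))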