###########################################################################
# NLP demo software by HyperbeeAI.                                        #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected] #
###########################################################################
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected]"
print("imported utils.py")
print(license_statement)
print("")

import time
import torch
import datasets
import layers
from tokenizers import Tokenizer
from tqdm import tqdm

MAX_LEN = 48  # fixed sequence length used throughout this module

tokenizer_en = None  # both tokenizers are loaded at the bottom of this file
tokenizer_es = None

def tokenize_es(text):
    # Keep at most MAX_LEN - 2 ids so the init and eos markers still fit.
    return tokenizer_es.encode(text).ids[:MAX_LEN - 2]

def tokenize_en(text):
    # Keep at most MAX_LEN - 1 ids, leaving room for one marker token.
    return tokenizer_en.encode(text).ids[:MAX_LEN - 1]

def translate_sentence(sentence, src_field, trg_field, model, device):

    model.eval()
    if isinstance(sentence, str):
        tokens = tokenize_es(sentence)
    else:
        tokens = sentence

    # Wrap the source in init/eos markers and pad it out to exactly MAX_LEN.
    tokens = [src_field.init_token] + tokens + [src_field.eos_token] \
             + [src_field.pad_token] * (MAX_LEN - 2 - len(tokens))
    src_tensor = torch.LongTensor(tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        enc_out, _ = model.encoder(src_tensor)

    # Target buffer: init token followed by padding, filled in greedily below.
    trg_indexes = [trg_field.init_token] + [trg_field.pad_token] * (MAX_LEN - 1)

    # Greedy decoding with a sliding window: the decoder attends over at most
    # the last 8 target positions. For step i, feed the window starting at
    # max(0, i - 7) and read the prediction at window position min(i, 7).
    for i in range(MAX_LEN - 1):
        start_idx = max(0, i - 7)

        trg_tensor = torch.LongTensor(trg_indexes[start_idx:start_idx + 8]).unsqueeze(0).to(device)

        with torch.no_grad():
            output, _, _ = model.decoder(trg_tensor, enc_out, start_idx)

        pred_token = output.argmax(2)[:, min(i, 7)].item()
        trg_indexes[i + 1] = pred_token
        if pred_token == trg_field.eos_token:
            break

    # Strip the init token and everything from the first eos token onward;
    # if no eos token was generated, keep everything after the init token.
    try:
        trg_indexes = trg_indexes[1:trg_indexes.index(trg_field.eos_token)]
    except ValueError:
        trg_indexes = trg_indexes[1:]

    trg_tokens = tokenizer_en.decode(trg_indexes, skip_special_tokens=False)

    return trg_tokens
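
# Worked example of the decoding window above: at step i = 0 the decoder sees
# trg_indexes[0:8] and the prediction is read at window position 0; at step
# i = 9 it sees trg_indexes[2:10] and the prediction is read at window
# position 7 (absolute position 9), which then becomes trg_indexes[10].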


def calculate_bleu(data, src_field, trg_field, model, device, spiece=False, output_file=None):

    # A default of f"test.{time.time()}.out" in the signature would be
    # evaluated once at import time; build the timestamped name per call.
    if output_file is None:
        output_file = f"test.{time.time()}.out"
    # Pick the pre-tokenizer used to split reference sentences into tokens.
    if spiece:
        from tokenizers import pre_tokenizers
        pre_tokenizer = pre_tokenizers.Digits(individual_digits=True)
    else:
        pre_tokenizer = tokenizer_en.pre_tokenizer

    trgs = []
    pred_trgs = []
    print('Evaluate on bleu:')
    # Note: the `data` argument is unused; test pairs are read from disk here.
    with open("news-comm-v15/news-comm-v15-all-test.es") as f_src, \
         open("news-comm-v15/news-comm-v15-all-test.en") as f_trg:
        for src, trg in tqdm(zip(f_src, f_trg)):

            # Skip near-empty lines and references longer than the model's
            # MAX_LEN window.
            if len(src) < 3 or len(trg) < 3:
                continue

            normalized = pre_tokenizer.pre_tokenize_str(tokenizer_en.normalizer.normalize_str(trg))

            if len(normalized) > MAX_LEN:
                continue

            # sacrebleu expects each reference as a list of alternatives.
            trgs.append([" ".join(map(lambda x: x[0], normalized))])

            pred_trg = translate_sentence(src, src_field, trg_field, model, device)
            pred_trgs.append(pred_trg)

    with open(output_file, "w") as fo:
        fo.write("\n".join(pred_trgs))

    sacrebleu = datasets.load_metric('sacrebleu')
    return sacrebleu.compute(predictions=pred_trgs, references=trgs)

# Load the pretrained tokenizers shipped under assets/.
tokenizer_es = Tokenizer.from_file("assets/es.json")
tokenizer_en = Tokenizer.from_file("assets/en.json")
TRG_PAD_IDX  = tokenizer_en.token_to_id("<PAD>")
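
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept commented out so importing
# this module stays side-effect free). The model construction is hypothetical:
# the real constructor and checkpoint path live elsewhere in this repo, and
# the "<BOS>"/"<EOS>" token names are assumptions (only "<PAD>" appears
# above). The field objects just need .init_token/.eos_token/.pad_token
# attributes, which is all translate_sentence() accesses.
#
#   from collections import namedtuple
#   Field = namedtuple("Field", ["init_token", "eos_token", "pad_token"])
#
#   src_field = Field(tokenizer_es.token_to_id("<BOS>"),
#                     tokenizer_es.token_to_id("<EOS>"),
#                     tokenizer_es.token_to_id("<PAD>"))
#   trg_field = Field(tokenizer_en.token_to_id("<BOS>"),
#                     tokenizer_en.token_to_id("<EOS>"),
#                     TRG_PAD_IDX)
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model  = layers.build_model().to(device)        # hypothetical factory
#   model.load_state_dict(torch.load("assets/model.pt", map_location=device))
#
#   print(translate_sentence("¿Cómo estás?", src_field, trg_field, model, device))
# ---------------------------------------------------------------------------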