|
|
|
|
|
|
|
|
|
# Licence banner; it is echoed to stdout every time this module is imported.
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected]"

# One call emits the import notice, the licence text and a trailing blank line.
print("imported utils.py", license_statement, "", sep="\n")
|
|
|
import torch |
|
import layers |
|
from tokenizers import Tokenizer |
|
import time, torch, datasets |
|
from tqdm import tqdm |
|
|
|
# Placeholders for the English and Spanish tokenizers used by the helpers
# below; both names are rebound to loaded tokenizers at the end of this file.
tokenizer_en, tokenizer_es = None, None
|
|
|
def tokenize_es(text):
    """Encode *text* with ``tokenizer_es`` and return at most 46 token ids.

    The limit is written as ``48 - 2``: ``translate_sentence`` adds an init
    and an eos token around these ids and pads the result up to 48 entries.
    """
    max_ids = 48 - 2
    encoding = tokenizer_es.encode(text)
    return encoding.ids[:max_ids]
|
|
|
def tokenize_en(text):
    """Encode *text* with ``tokenizer_en`` and return at most 47 token ids."""
    limit = 48 - 1
    token_ids = tokenizer_en.encode(text).ids
    return token_ids[:limit]
|
|
|
def translate_sentence(sentence, src_field, trg_field, model, device):
    """Greedily translate one source sentence and return the decoded text.

    Args:
        sentence: Either a raw string (tokenized here with ``tokenize_es``)
            or an already-tokenized sequence of source token ids.
        src_field: Object exposing ``init_token``, ``eos_token`` and
            ``pad_token`` for the source side.
        trg_field: Object exposing ``init_token``, ``eos_token`` and
            ``pad_token`` for the target side.
        model: Object providing ``eval()``, ``encoder(src)`` and
            ``decoder(trg_window, enc_out, window_start)``.
        device: Device the input tensors are moved to.

    Returns:
        The string produced by ``tokenizer_en.decode`` (special tokens kept)
        for the predicted ids, without the leading init token and without
        anything from the first eos token onwards.

    Note:
        The model is switched to evaluation mode and is left in that mode.
    """
    max_len = 48  # fixed length of the padded source and target sequences
    window = 8    # number of target positions handed to the decoder per step

    model.eval()

    if isinstance(sentence, str):
        tokens = tokenize_es(sentence)
    else:
        tokens = sentence

    # tokenize_es already clips strings to max_len - 2 ids, but pre-tokenized
    # input used to pass through unclipped: the pad count below then went
    # negative and the source tensor grew past max_len. Clip both the same way.
    tokens = list(tokens[:max_len - 2])

    pad_count = max_len - 2 - len(tokens)
    src_ids = (
        [src_field.init_token]
        + tokens
        + [src_field.eos_token]
        + [src_field.pad_token] * pad_count
    )
    src_tensor = torch.LongTensor(src_ids).unsqueeze(0).to(device)

    # Target buffer: init token followed by padding that is overwritten with
    # predictions as decoding proceeds.
    trg_indexes = [trg_field.init_token] + [trg_field.pad_token] * (max_len - 1)

    with torch.no_grad():
        enc_out, _ = model.encoder(src_tensor)

        for i in range(max_len - 1):
            # The decoder only sees a sliding window of `window` positions
            # ending at (or containing) position i.
            start_idx = max(0, i - (window - 1))
            trg_window = trg_indexes[start_idx:start_idx + window]
            trg_tensor = torch.LongTensor(trg_window).unsqueeze(0).to(device)

            # The window start is passed along with the window itself
            # (presumably a position offset for the decoder -- confirm there).
            output, _, _ = model.decoder(trg_tensor, enc_out, start_idx)

            # Position i sits at index min(i, window - 1) inside the window;
            # its argmax is the prediction for position i + 1.
            pred_token = output.argmax(2)[:, min(i, window - 1)].item()
            trg_indexes[i + 1] = pred_token

            if pred_token == trg_field.eos_token:
                break

    # Drop the init token and everything from the first eos onwards; when no
    # eos was produced, keep every predicted position.
    try:
        eos_pos = trg_indexes.index(trg_field.eos_token)
    except ValueError:
        eos_pos = len(trg_indexes)
    trg_indexes = trg_indexes[1:eos_pos]

    return tokenizer_en.decode(trg_indexes, skip_special_tokens=False)
|
|
|
|
|
def calculate_bleu(data, src_field, trg_field, model, device, spiece=False, output_file=None):
    """Translate the news-commentary test files and score them with sacreBLEU.

    Args:
        data: Not used by this function; sentence pairs are read from the
            hard-coded ``news-comm-v15`` test files. Kept so existing callers
            keep working.
        src_field: Forwarded to ``translate_sentence``.
        trg_field: Forwarded to ``translate_sentence``.
        model: Forwarded to ``translate_sentence``.
        device: Forwarded to ``translate_sentence``.
        spiece: When true, references are pre-tokenized with
            ``pre_tokenizers.Digits(individual_digits=True)``; otherwise the
            English tokenizer's own pre-tokenizer is used.
        output_file: Path the predictions are written to, joined by newlines.
            When omitted, ``test.<current time>.out`` is used.

    Returns:
        Whatever the ``sacrebleu`` metric's ``compute`` returns for the
        collected predictions and references.
    """
    if output_file is None:
        # Resolved per call. As a default-argument expression the timestamp
        # was evaluated only once, when the function was defined, so every
        # call in the same process overwrote the same file.
        output_file = f"test.{time.time()}.out"

    if spiece:
        from tokenizers import pre_tokenizers
        pre_tokenizer = pre_tokenizers.Digits(individual_digits=True)
    else:
        pre_tokenizer = tokenizer_en.pre_tokenizer

    trgs = []
    pred_trgs = []
    print('Evaluate on bleu:')

    # Context managers close both corpus files (they used to be left open),
    # and the explicit encoding keeps results independent of the platform
    # default.
    with open("news-comm-v15/news-comm-v15-all-test.es", encoding="utf-8") as src_file, \
            open("news-comm-v15/news-comm-v15-all-test.en", encoding="utf-8") as trg_file:
        for src, trg in tqdm(zip(src_file, trg_file)):
            # Skip (near-)empty lines; the lengths include the line terminator.
            if len(src) < 3 or len(trg) < 3:
                continue

            normalized = pre_tokenizer.pre_tokenize_str(tokenizer_en.normalizer.normalize_str(trg))

            # Skip references with more than 48 pre-tokenized pieces.
            if len(normalized) > 48:
                continue

            # Each reference is the space-joined piece texts, wrapped in a
            # list because the metric takes a list of references per example.
            trgs.append([" ".join(piece[0] for piece in normalized)])

            pred_trgs.append(translate_sentence(src, src_field, trg_field, model, device))

    with open(output_file, "w", encoding="utf-8") as fo:
        fo.write("\n".join(pred_trgs))

    # NOTE(review): `datasets.load_metric` appears to be deprecated in newer
    # `datasets` releases in favour of the separate `evaluate` package --
    # confirm against the pinned version before upgrading.
    sacrebleu = datasets.load_metric('sacrebleu')
    return sacrebleu.compute(predictions=pred_trgs, references=trgs)
|
|
|
# Replace the None placeholders with the tokenizers stored under assets/.
tokenizer_es = Tokenizer.from_file("assets/es.json")
tokenizer_en = Tokenizer.from_file("assets/en.json")

# Id that the English tokenizer assigns to the "<PAD>" token.
TRG_PAD_IDX = tokenizer_en.token_to_id("<PAD>")
|
|