# nanotranslator-hf / utils.py
# sonebu
# update email
# fd48f4d
###########################################################################
# NLP demo software by HyperbeeAI. #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected] #
###########################################################################
# Banner emitted once, when this module is first imported.
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected]"
# One call, same output as three consecutive prints: header line,
# license line, then an empty line.
print("imported utils.py", license_statement, "", sep="\n")
import torch
import layers
from tokenizers import Tokenizer
import time, torch, datasets
from tqdm import tqdm
# Module-level tokenizer handles read by tokenize_es / tokenize_en and the
# translation helpers below. Both start unset; the loading code at the end
# of this module assigns them.
tokenizer_en = tokenizer_es = None
def tokenize_es(text):
    """Return the Spanish token ids for *text*, keeping at most 46 of them.

    translate_sentence frames the source as start token + ids + end token
    inside a 48-position sequence, hence the 48 - 2 cap.
    """
    limit = 48 - 2
    encoding = tokenizer_es.encode(text)
    return encoding.ids[:limit]
def tokenize_en(text):
    """Return the English token ids for *text*, keeping at most 47 of them.

    The 48 - 1 cap presumably leaves room for one special token in a
    48-position target sequence (cf. the start token + 47 slots layout used
    in translate_sentence) — confirm against the training code.
    """
    limit = 48 - 1
    encoding = tokenizer_en.encode(text)
    return encoding.ids[:limit]
def translate_sentence(sentence, src_field, trg_field, model, device):
    """Greedily translate one Spanish sentence and return the English text.

    The source is framed as ``[init] + ids + [eos]`` and right-padded with
    ``pad`` ids to 48 positions. Decoding is greedy: at step ``i`` the decoder
    receives a window of at most 8 target ids ending at position ``i``, and
    the argmax at that position becomes target token ``i + 1``. Decoding stops
    at ``trg_field.eos_token`` or after 47 steps.

    Args:
        sentence: Raw Spanish text, tokenized with ``tokenize_es``, or an
            already-tokenized sequence of source ids. Sequences longer than 46
            ids are truncated, matching ``tokenize_es``.
        src_field: Object whose ``init_token``, ``eos_token`` and ``pad_token``
            attributes are integer source-vocabulary ids.
        trg_field: Same, for the target vocabulary.
        model: Object exposing ``encoder(src) -> (enc_out, _)`` and
            ``decoder(trg_window, enc_out, offset) -> (logits, _, _)``. It is
            switched to eval mode as a side effect.
        device: Torch device the input tensors are moved to.

    Returns:
        ``tokenizer_en.decode`` of the predicted ids between the start token
        and the first end token (special tokens are not skipped).
    """
    max_len = 48  # source and target sequence length used throughout this module
    window = 8    # number of target positions fed to the decoder per step

    model.eval()

    if isinstance(sentence, str):
        tokens = tokenize_es(sentence)
    else:
        # Apply the same cap as tokenize_es. Without it, a long id list yields
        # a negative pad count and a source sequence longer than max_len.
        tokens = list(sentence)[:max_len - 2]

    src_ids = (
        [src_field.init_token]
        + tokens
        + [src_field.eos_token]
        + [src_field.pad_token] * (max_len - 2 - len(tokens))
    )
    src_tensor = torch.LongTensor(src_ids).unsqueeze(0).to(device)

    # Target buffer: start token followed by pad slots that are filled one
    # prediction at a time.
    trg_indexes = [trg_field.init_token] + [trg_field.pad_token] * (max_len - 1)

    with torch.no_grad():
        enc_out, _ = model.encoder(src_tensor)
        for i in range(max_len - 1):
            # Window of at most `window` ids ending at position i. start_idx
            # is also passed to the decoder, presumably as a position offset
            # (TODO confirm against the decoder implementation).
            start_idx = max(0, i - (window - 1))
            trg_window = trg_indexes[start_idx:start_idx + window]
            trg_tensor = torch.LongTensor(trg_window).unsqueeze(0).to(device)
            output, _, _ = model.decoder(trg_tensor, enc_out, start_idx)
            # Position i lies at index i - start_idx (== min(i, window - 1))
            # inside the window.
            pred_token = output.argmax(2)[:, i - start_idx].item()
            trg_indexes[i + 1] = pred_token
            if pred_token == trg_field.eos_token:
                break

    # Drop the start token, and everything from the first end token onward
    # if one was produced.
    try:
        end = trg_indexes.index(trg_field.eos_token)
    except ValueError:
        end = len(trg_indexes)
    return tokenizer_en.decode(trg_indexes[1:end], skip_special_tokens=False)
def calculate_bleu(data, src_field, trg_field, model, device, spiece=False, output_file=None):
    """Translate the News Commentary v15 es->en test set and score it with sacreBLEU.

    Each Spanish line is translated with ``translate_sentence``. The matching
    English line is normalized with ``tokenizer_en.normalizer``, then
    pre-tokenized, and its pieces are joined with spaces to form the single
    reference for that line. A pair is skipped when either raw line is shorter
    than 3 characters (the trailing newline counts) or when the reference
    pre-tokenizes into more than 48 pieces.

    Args:
        data: Unused. Pairs are always read from
            ``news-comm-v15/news-comm-v15-all-test.es`` / ``.en`` relative to
            the current working directory. Kept for interface compatibility.
        src_field, trg_field, model, device: Forwarded to ``translate_sentence``.
        spiece: If True, references are split with
            ``pre_tokenizers.Digits(individual_digits=True)`` instead of
            ``tokenizer_en.pre_tokenizer``.
        output_file: Path the predictions are written to, one per line.
            ``None`` (default) means ``test.<time.time()>.out``, with the
            timestamp taken when the call is made.

    Returns:
        The result of the ``datasets`` sacrebleu metric's
        ``compute(predictions=..., references=...)``.
    """
    if output_file is None:
        # Compute the timestamp per call. As a default argument it was
        # evaluated once at import time, so every call overwrote one file.
        output_file = f"test.{time.time()}.out"

    if spiece:
        from tokenizers import pre_tokenizers
        pre_tokenizer = pre_tokenizers.Digits(individual_digits=True)
    else:
        pre_tokenizer = tokenizer_en.pre_tokenizer

    trgs = []
    pred_trgs = []
    print('Evaluate on bleu:')
    # Explicit UTF-8 (assumes the corpus is UTF-8) so accented Spanish text
    # does not depend on the platform's locale encoding. `with` closes both
    # files, which were previously leaked.
    with open("news-comm-v15/news-comm-v15-all-test.es", encoding="utf-8") as src_file, \
            open("news-comm-v15/news-comm-v15-all-test.en", encoding="utf-8") as trg_file:
        for src, trg in tqdm(zip(src_file, trg_file)):
            if len(src) < 3 or len(trg) < 3:
                continue
            normalized = pre_tokenizer.pre_tokenize_str(tokenizer_en.normalizer.normalize_str(trg))
            if len(normalized) > 48:
                continue
            # pre_tokenize_str yields (piece, offsets) pairs; keep the pieces.
            trgs.append([" ".join(piece for piece, _ in normalized)])
            pred_trgs.append(translate_sentence(src, src_field, trg_field, model, device))

    with open(output_file, "w", encoding="utf-8") as fo:
        fo.write("\n".join(pred_trgs))

    # NOTE(review): datasets.load_metric is deprecated and was removed in
    # datasets 3.x; `evaluate.load("sacrebleu")` is the replacement if this breaks.
    sacrebleu = datasets.load_metric('sacrebleu')
    return sacrebleu.compute(predictions=pred_trgs, references=trgs)
# Load the Spanish (source) and English (target) tokenizers, in that order.
# The paths are resolved against the current working directory, not against
# this file's location.
tokenizer_es, tokenizer_en = (
    Tokenizer.from_file(path) for path in ("assets/es.json", "assets/en.json")
)
# Id that the English tokenizer assigns to the "<PAD>" token.
TRG_PAD_IDX = tokenizer_en.token_to_id("<PAD>")