|
from transformers import Pipeline |
|
import nltk |
|
import numpy as np |
|
import torch |
|
|
|
# Fetch the NLTK perceptron part-of-speech tagger resources at import time.
# Both resource names are requested (presumably the legacy name and the newer
# "_eng" name, so that either NLTK release layout is covered).
# NOTE(review): nothing else visible in this file uses nltk; the downloads are
# kept because they are an import-time side effect other code may rely on.
for _tagger_resource in ("averaged_perceptron_tagger", "averaged_perceptron_tagger_eng"):
    nltk.download(_tagger_resource)
del _tagger_resource

# Name of the entity-linking model; not referenced by the code visible in this
# file.
NEL_MODEL = "nel-mgenre-multilingual"
|
|
|
|
|
class NelPipeline(Pipeline):
    """Entity-linking pipeline that predicts a Wikipedia title for a marked mention.

    The input text may contain a mention wrapped in ``[START]`` and ``[END]``
    markers. The whole text is fed to ``self.model.generate`` with 5-beam
    search, and each of the 5 returned beams becomes one result dict.
    """

    # Markers delimiting the entity mention inside the input text.
    _START_TOKEN = "[START]"
    _END_TOKEN = "[END]"

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to preprocess/forward/postprocess.

        Only ``text`` is recognised; it is forwarded to :meth:`preprocess`.
        All other keyword arguments are ignored.

        :return: ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``.
        """
        preprocess_kwargs = {}
        # NOTE(review): ``preprocess`` already receives the input positionally
        # as ``text``. Forwarding a ``text`` keyword as well looks like it would
        # raise "got multiple values for argument 'text'" — confirm intent.
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    @classmethod
    def _locate_entity(cls, text):
        """Return ``(entity, l_offset, r_offset)`` for the marked mention in *text*.

        ``l_offset`` is the index just past ``[START]``. ``r_offset`` is the
        index of the first ``[END]`` that follows it. Both are indices into
        *text* itself, markers included. ``entity`` is the text between them
        with surrounding whitespace stripped; the offsets are NOT adjusted for
        that stripping.

        Returns ``(None, None, None)`` when there is no ``[START]`` followed
        by an ``[END]``.
        """
        start = text.find(cls._START_TOKEN)
        if start == -1:
            return None, None, None
        l_offset = start + len(cls._START_TOKEN)
        # Search only after [START] so a stray earlier [END] cannot produce a
        # backwards (empty) slice with r_offset < l_offset.
        r_offset = text.find(cls._END_TOKEN, l_offset)
        if r_offset == -1:
            return None, None, None
        return text[l_offset:r_offset].strip(), l_offset, r_offset

    def preprocess(self, text, **kwargs):
        """Generate Wikipedia predictions for *text* and pair them with the mention.

        Generation is done here rather than in :meth:`_forward`; ``_forward``
        passes results through unchanged.

        :param text: input string, optionally containing a
            ``[START] ... [END]`` marked mention.
        :param kwargs: unused.
        :return: one ``(wikipedia_prediction, entity, l_offset, r_offset,
            probability)`` tuple per returned beam. ``probability`` is
            ``exp`` of the beam's sequence score, as a Python float.
        """
        enclosed_entity, l_offset, r_offset = self._locate_entity(text)

        # NOTE(review): the input is truncated to 30 tokens, so a mention that
        # lies past that point loses its markers — confirm this limit.
        model_inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=30
        ).to(self.device)
        outputs = self.model.generate(
            **model_inputs,
            num_beams=5,
            num_return_sequences=5,
            return_dict_in_generate=True,
            output_scores=True,
        )

        predictions = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )
        # Beam sequence scores are log-scale; exp() maps them to (0, 1].
        probabilities = np.exp(outputs.sequences_scores.cpu().numpy())

        # Keep the probability unrounded: rounding is done once, in postprocess.
        return [
            (prediction, enclosed_entity, l_offset, r_offset, float(probability))
            for prediction, probability in zip(predictions, probabilities)
        ]

    def _forward(self, inputs):
        """Return *inputs* unchanged; model generation already ran in :meth:`preprocess`."""
        return inputs

    def postprocess(self, outputs, **kwargs):
        """Convert the per-beam tuples from :meth:`preprocess` into result dicts.

        :param outputs: ``(wikipedia_prediction, entity, l_offset, r_offset,
            probability)`` tuples, with ``probability`` in ``[0, 1]``.
        :param kwargs: unused.
        :return: list of dicts with keys:

            - ``surface``: the mention text;
            - ``wkd_pred``: the predicted Wikipedia title;
            - ``type``: always ``"UNK"``;
            - ``confidence_nel``: the probability as a percentage, rounded to
              2 decimals;
            - ``lOffset`` and ``rOffset``: the mention offsets.
        """
        return [
            {
                "surface": enclosed_entity,
                "wkd_pred": wikipedia_prediction,
                "type": "UNK",
                # Scale to a percentage and round exactly once, so no
                # precision is lost before this point.
                "confidence_nel": round(float(probability) * 100.0, 2),
                "lOffset": l_offset,
                "rOffset": r_offset,
            }
            for wikipedia_prediction, enclosed_entity, l_offset, r_offset, probability in outputs
        ]
|
|