# layoutlm-funsd-2 / handler.py
# Alexander Slessor -- "added more files" (commit 8a5956b)
# File size: 2.59 kB
from typing import Dict, List, Any
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
import torch
from subprocess import run
# Install the Tesseract OCR binary when this module is imported, before any
# handler is constructed (presumably needed by the processor's OCR step --
# confirm against the LayoutLMv2 processor documentation).
# The command is given as an argument list so no shell is spawned to parse it;
# check=True makes a non-zero exit status raise CalledProcessError.
run(["apt", "install", "-y", "tesseract-ocr"], check=True)
class HugEndpointException(Exception):
    """Exception raised by the endpoint handler to wrap an underlying error.

    The wrapped value is kept on ``self.e`` and rendered inside a fixed
    prefix by ``__str__``.
    """

    def __init__(self, e):
        """Store *e*, the original error (or message) being wrapped."""
        super().__init__(e)
        # Kept under its own attribute because __str__ reads it directly.
        self.e = e

    def __str__(self):
        """Return the wrapped error prefixed with the endpoint marker text."""
        return f'Custom Endpoint Exception: {self.e}'
def unnormalize_box(bbox, width, height):
    """Scale a box from the 0-1000 normalized grid to image coordinates.

    Args:
        bbox: Indexable with at least four entries; entries 0 and 2 are
            scaled by *width*, entries 1 and 3 by *height*.
        width: Horizontal extent to scale against.
        height: Vertical extent to scale against.

    Returns:
        A new four-element list with each entry divided by 1000 and
        multiplied by the matching dimension.
    """
    # Coordinates alternate horizontal/vertical, so the scaling dimensions do too.
    dims = (width, height, width, height)
    return [dim * (bbox[idx] / 1000) for idx, dim in enumerate(dims)]
# Choose the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Callable handler that runs LayoutLM token classification on one image.

    The model and processor are loaded once in ``__init__``; each call
    processes a single request payload and returns the tokens whose predicted
    label is not ``"O"``.
    """

    def __init__(self, path=""):
        """Load the model and processor from *path*.

        Args:
            path: Location passed unchanged to ``from_pretrained`` for both
                the model and the processor.
        """
        self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        self.processor = LayoutLMv2Processor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Run the model on the image in *data* and return labelled tokens.

        Args:
            data: Request payload. The image is taken from (and removed from)
                its ``"inputs"`` key; when that key is absent the payload
                itself is used as the image. The image must expose ``width``
                and ``height`` whenever a token gets a label other than
                ``"O"`` (the original docstring describes it as a
                deserialized ``PIL.Image``).

        Returns:
            ``{"predictions": [...]}`` where every entry is a dict with the
            keys ``"label"``, ``"score"``, ``"text"`` and ``"bbox"`` (the box
            scaled to the image size by ``unnormalize_box``).

        Raises:
            HugEndpointException: Wraps any exception raised by the forward
                pass or the post-processing. Errors raised by the processor
                call itself happen before the ``try`` and propagate as-is.
        """
        # Note: pop() mutates the caller's dict when "inputs" is present.
        image = data.pop("inputs", data)

        # Turn the image into model tensors.
        # NOTE(review): no truncation argument is passed here -- confirm that
        # inputs longer than the model's maximum length cannot occur.
        encoding = self.processor(image, return_tensors="pt")

        try:
            # Forward pass without autograd bookkeeping.
            with torch.inference_mode():
                outputs = self.model(
                    input_ids=encoding.input_ids.to(device),
                    bbox=encoding.bbox.to(device),
                    attention_mask=encoding.attention_mask.to(device),
                    token_type_ids=encoding.token_type_ids.to(device),
                )
                predictions = outputs.logits.softmax(-1)

            # Drop the leading (batch) dimension and move everything to the
            # CPU once, so the loop below works on plain CPU tensors.
            probs = predictions.squeeze(0).cpu()
            token_ids = encoding.input_ids.squeeze(0).cpu()
            boxes = encoding.bbox.squeeze(0).cpu()

            # Loop-invariant lookups.
            id2label = self.model.config.id2label
            decode = self.processor.tokenizer.decode

            result = []
            for item, inp_ids, bbox in zip(probs, token_ids, boxes):
                label = id2label[int(item.argmax())]
                if label == "O":
                    # Tokens labelled "O" are left out of the response.
                    continue
                result.append(
                    {
                        "label": label,
                        "score": item.max().item(),
                        "text": decode(inp_ids),
                        "bbox": unnormalize_box(bbox.tolist(), image.width, image.height),
                    }
                )
            return {"predictions": result}
        except Exception as e:
            # Chain explicitly so the original traceback is reported as the
            # direct cause of the endpoint error.
            raise HugEndpointException(e) from e