{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup & Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install the Tesseract OCR engine (system package) and its Python wrapper;\n",
    "# %pip (not a bare `pip`) guarantees the package lands in this kernel's environment\n",
    "!apt-get install -y tesseract-ocr\n",
    "%pip install pytesseract"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
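    "## 2. Create Custom Handler for Inference Endpoints\n",
    "\n",
    "The next cell writes `handler.py`, which defines an `EndpointHandler` class: `__init__` loads the LayoutLM model and processor, and `__call__` takes `{\"inputs\": <PIL.Image>}` and returns the labelled tokens with their pixel-space bounding boxes."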
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting handler.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile handler.py\n",
    "import sys\n",
    "from subprocess import run\n",
    "from typing import Any, Dict, List\n",
    "\n",
    "import torch\n",
    "from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor\n",
    "\n",
    "# install tesseract-ocr and pytesseract when this module is imported, so they exist\n",
    "# before the processor is called in EndpointHandler.__call__.\n",
    "# argv lists avoid shell=True; `sys.executable -m pip` installs into the interpreter\n",
    "# actually running this handler instead of whichever `pip` is first on PATH.\n",
    "run([\"apt-get\", \"install\", \"-y\", \"tesseract-ocr\"], check=True)\n",
    "run([sys.executable, \"-m\", \"pip\", \"install\", \"pytesseract\"], check=True)\n",
    "\n",
    "def unnormalize_box(bbox, width, height):\n",
    "    \"\"\"Scale a box from the 0-1000 normalized grid back to pixel coordinates.\n",
    "\n",
    "    Args:\n",
    "        bbox: indexable of four numbers (x0, y0, x1, y1) on a 0-1000 scale.\n",
    "        width: image width in pixels.\n",
    "        height: image height in pixels.\n",
    "\n",
    "    Returns:\n",
    "        [x0, y0, x1, y1] in pixel units, suitable for drawing onto the image.\n",
    "    \"\"\"\n",
    "    # x coordinates scale with the width, y coordinates with the height\n",
    "    scales = (width, height, width, height)\n",
    "    return [scale * (bbox[i] / 1000) for i, scale in enumerate(scales)]\n",
    "\n",
    "\n",
    "# use a CUDA device when one is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "\n",
    "class EndpointHandler:\n",
    "    \"\"\"Custom Inference Endpoints handler: OCR + LayoutLM token classification of a document image.\"\"\"\n",
    "\n",
    "    def __init__(self, path=\"\"):\n",
    "        \"\"\"Load the token-classification model and the processor.\n",
    "\n",
    "        Args:\n",
    "            path: accepted for the handler interface; see the review note below.\n",
    "        \"\"\"\n",
    "        # NOTE(review): `path` is currently ignored - model and processor are always loaded from\n",
    "        # the \"philschmid/layoutlm-funsd\" Hub repo. Confirm this is intended before deploying.\n",
    "        self.model = LayoutLMForTokenClassification.from_pretrained(\"philschmid/layoutlm-funsd\").to(device)\n",
    "        self.processor = LayoutLMv2Processor.from_pretrained(\"philschmid/layoutlm-funsd\")\n",
    "\n",
    "    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:\n",
    "        \"\"\"Classify the OCR'd tokens of one document image.\n",
    "\n",
    "        Args:\n",
    "            data: request payload; the image (a PIL.Image) is read from data[\"inputs\"],\n",
    "                or `data` itself is used as the image when that key is absent.\n",
    "\n",
    "        Returns:\n",
    "            {\"predictions\": [...]} with one dict per non-special token whose predicted label\n",
    "            is not \"O\", holding its \"label\", softmax \"score\", decoded \"text\" and\n",
    "            pixel-space \"bbox\" [x0, y0, x1, y1].\n",
    "        \"\"\"\n",
    "        # read without popping so the caller's request dict is left untouched\n",
    "        image = data.get(\"inputs\", data)\n",
    "\n",
    "        # the processor turns the raw image into token ids, 0-1000 boxes and masks\n",
    "        encoding = self.processor(image, return_tensors=\"pt\")\n",
    "\n",
    "        with torch.inference_mode():\n",
    "            outputs = self.model(\n",
    "                input_ids=encoding.input_ids.to(device),\n",
    "                bbox=encoding.bbox.to(device),\n",
    "                attention_mask=encoding.attention_mask.to(device),\n",
    "                token_type_ids=encoding.token_type_ids.to(device),\n",
    "            )\n",
    "            # per-token class probabilities for the single image in the batch\n",
    "            probabilities = outputs.logits.softmax(-1).squeeze(0).cpu()\n",
    "\n",
    "        special_ids = set(self.processor.tokenizer.all_special_ids)\n",
    "        id2label = self.model.config.id2label\n",
    "        token_ids = encoding.input_ids.squeeze(0).tolist()\n",
    "        token_boxes = encoding.bbox.squeeze(0).tolist()\n",
    "\n",
    "        result = []\n",
    "        for probs, token_id, norm_box in zip(probabilities, token_ids, token_boxes):\n",
    "            # special tokens (e.g. [CLS]/[SEP]) are not words on the page; never report them\n",
    "            if token_id in special_ids:\n",
    "                continue\n",
    "            score, class_idx = probs.max(-1)\n",
    "            label = id2label[int(class_idx)]\n",
    "            if label == \"O\":\n",
    "                continue\n",
    "            result.append(\n",
    "                {\n",
    "                    \"label\": label,\n",
    "                    \"score\": score.item(),\n",
    "                    \"text\": self.processor.tokenizer.decode(token_id),\n",
    "                    \"bbox\": unnormalize_box(norm_box, image.width, image.height),\n",
    "                }\n",
    "            )\n",
    "        return {\"predictions\": result}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Test the custom handler locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the handler class saved to handler.py earlier in this notebook\n",
    "# (importing it also runs handler.py's module-level install commands)\n",
    "from handler import EndpointHandler\n",
    "\n",
    "my_handler = EndpointHandler(path=\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "\n",
    "# read the example invoice from disk; convert to RGB so OCR and the coloured boxes drawn\n",
    "# later operate on a plain 3-channel image regardless of the PNG's stored mode\n",
    "image = Image.open(\"invoice_example.png\").convert(\"RGB\")\n",
    "request = {\"inputs\": image}\n",
    "\n",
    "# call the handler with the same payload shape its __call__ reads (image under \"inputs\")\n",
    "pred = my_handler(request)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image, ImageDraw, ImageFont\n",
    "\n",
    "\n",
    "def draw_result(image, result):\n",
    "    \"\"\"Draw predicted boxes and their labels onto a copy of `image`.\n",
    "\n",
    "    Args:\n",
    "        image: PIL.Image the predictions were computed on (left unmodified).\n",
    "        result: list of prediction dicts with a \"label\" and a pixel-space \"bbox\" [x0, y0, x1, y1].\n",
    "\n",
    "    Returns:\n",
    "        A new PIL.Image with one coloured rectangle and caption per prediction.\n",
    "    \"\"\"\n",
    "    label2color = {\n",
    "        \"B-HEADER\": \"blue\",\n",
    "        \"B-QUESTION\": \"red\",\n",
    "        \"B-ANSWER\": \"green\",\n",
    "        \"I-HEADER\": \"blue\",\n",
    "        \"I-QUESTION\": \"red\",\n",
    "        \"I-ANSWER\": \"green\",\n",
    "    }\n",
    "\n",
    "    # draw on a copy so the source image (and any re-run of this cell) stays clean\n",
    "    annotated = image.copy()\n",
    "    draw = ImageDraw.Draw(annotated)\n",
    "    font = ImageFont.load_default()\n",
    "    for res in result:\n",
    "        x0, y0 = res[\"bbox\"][0], res[\"bbox\"][1]\n",
    "        # labels missing from the colour map are drawn in black instead of raising KeyError\n",
    "        color = label2color.get(res[\"label\"], \"black\")\n",
    "        draw.rectangle(res[\"bbox\"], outline=color)\n",
    "        # clamp so captions of boxes touching the top edge stay inside the image\n",
    "        draw.text((x0 + 10, max(0, y0 - 10)), text=res[\"label\"], fill=color, font=font)\n",
    "    return annotated\n",
    "\n",
    "draw_result(image, pred[\"predictions\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('dev': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}