File size: 1,814 Bytes
45b837c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import Dict, List, Any
import torch
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np


def softmax(x):
    z = x - max(x)
    numerator = np.exp(z)
    denominator = np.sum(numerator)
    softmax = numerator/denominator
    return softmax

class EndpointHandler():
    def __init__(self, path=""):
        self.accelerator = Accelerator()
        self.device = self.accelerator.device
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path, trust_remote_code=True, device_map="auto")
        self.model = self.accelerator.prepare(self.model)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.options_tokens = [self.tokenizer.encode(choice)[0] for choice in ["A", "B", "C", "D"]]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
       data args:
            inputs (:obj: `str` | `PIL.Image` | `np.array`)
            kwargss
      Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        with torch.no_grad():
            prompt = data.pop("prompt")
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            input_size = inputs['input_ids'].size(1)
            input_ids = inputs["input_ids"].to(self.device)
            start_token = self.tokenizer('<pad>', return_tensors="pt").to(self.device)
            outputs = self.model(**inputs, decoder_input_ids=start_token['input_ids'])
            last_token_logits = outputs.logits[:, -1, :]
            options_tokens_logits = last_token_logits[:, self.options_tokens].detach().cpu().numpy()
            conf = softmax(options_tokens_logits[0])
            pred = np.argmax(options_tokens_logits[0])
        return [{"pred": pred, "conf":conf}]