AgentisLabs commited on
Commit
867fc5e
·
verified ·
1 Parent(s): 6902106

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +34 -0
handler.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ # Load the tokenizer and model
8
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
9
+ self.model = AutoModelForCausalLM.from_pretrained(path)
10
+ self.model.eval()
11
+
12
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
13
+ """
14
+ Args:
15
+ data: A dictionary with the key 'inputs' containing the input text.
16
+ Returns:
17
+ A dictionary with the generated text under the key 'generated_text'.
18
+ """
19
+ # Extract input text
20
+ input_text = data.get("inputs", "")
21
+ if not input_text:
22
+ return {"error": "No input provided"}
23
+
24
+ # Tokenize the input
25
+ inputs = self.tokenizer(input_text, return_tensors="pt")
26
+
27
+ # Generate text
28
+ with torch.no_grad():
29
+ outputs = self.model.generate(**inputs, max_length=100)
30
+
31
+ # Decode the generated tokens
32
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
33
+
34
+ return {"generated_text": generated_text}