# handler.py
from typing import Any, Dict

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
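
# Custom handler for Hugging Face Inference Endpoints: the service imports
# this module, instantiates EndpointHandler(path) once at startup, and then
# calls it once per request with the parsed JSON payload.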


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load LoRA config to get base model
        self.peft_config = PeftConfig.from_pretrained(path)

        # Load tokenizer from base model
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.peft_config.base_model_name_or_path,
            use_fast=False,
        )

        # Load base model (half precision on GPU, full precision on CPU)
        base_model = AutoModelForCausalLM.from_pretrained(
            self.peft_config.base_model_name_or_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
        )

        # Load LoRA adapter weights on top of the base model
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()
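
    # Example request payloads this handler accepts (shapes only; the
    # field values are illustrative, not taken from the repo):
    #   {"messages": [{"role": "user", "content": "..."}]}
    #   {"inputs": "..."}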
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # Handle both chat-style ({"messages": [...]}) and plain
        # ({"inputs": "..."}) input
        if "messages" in data:
            messages = data["messages"]
        else:
            user_input = data.get("inputs", "")
            if not user_input:
                return {"error": "No input provided."}
            messages = [{"role": "user", "content": user_input}]

        # Render the messages with the model's chat template
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize input and move it to the model's device
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # Generate output (sampling enabled)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens so the prompt is not
        # echoed back in the response
        new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return {"generated_text": output_text}