# HenriAI / handler.py
# Custom HuggingFace Inference Endpoints handler for the HenriAI model
# (GPT-J-6B base + PEFT adapter). Commit e054408.
import contextlib
import logging
import os
from typing import Dict, List

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Configure root logging once at import time so handler logs reach the
# endpoint's console/log stream.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference endpoint handler: GPT-J-6B base model with a PEFT adapter.

    Loads the 8-bit-quantized base model once at startup, applies the
    adapter weights found at ``path``, and serves question/answer-style
    text generation.
    """

    def __init__(self, path: str):
        """Load the base model, the adapter weights, and the tokenizer.

        Args:
            path: Directory containing the PEFT adapter weights.

        Raises:
            Exception: re-raised after logging if any loading step fails.
        """
        try:
            logger.info("Loading base model...")
            # 8-bit quantization keeps the 6B-parameter model within a single
            # GPU's memory; device_map="auto" lets accelerate place the layers.
            # NOTE(review): load_in_8bit= is deprecated in recent transformers
            # in favor of quantization_config=BitsAndBytesConfig(...) —
            # migrate when the pinned transformers version supports it.
            base_model = AutoModelForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B",
                load_in_8bit=True,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            logger.info("Loading adapter weights...")
            # Wrap the frozen base model with the fine-tuned PEFT adapter.
            self.model = PeftModel.from_pretrained(base_model, path)
            self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
            # GPT-J ships without a pad token; reuse EOS so padding and
            # truncation in the tokenizer call below work.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def __call__(self, data: Dict) -> List[str]:
        """Generate an answer for the question contained in ``data``.

        Args:
            data: Request payload. The question is read from data["inputs"]
                (a string, or a list of strings of which only the first is
                used). If the key is absent, ``data`` itself is used.

        Returns:
            A single-element list holding the generated text, or an error
            message string on failure (errors are never raised to the caller).
        """
        try:
            # Use .get (not .pop) so the caller's payload dict is not mutated.
            question = data.get("inputs", data) if isinstance(data, dict) else data
            if isinstance(question, list):
                # Batch payloads are not supported; answer the first entry.
                # Guard against an empty list instead of raising IndexError.
                question = question[0] if question else ""
            # Prompt format must match the one used during fine-tuning.
            prompt = f"Question: {question}\nAnswer:"
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.model.device)
            # Autocast is CUDA-only; fall back to a no-op context on CPU so
            # the handler still works on hosts without a GPU.
            amp_ctx = (
                torch.cuda.amp.autocast()
                if torch.cuda.is_available()
                else contextlib.nullcontext()
            )
            with torch.inference_mode(), amp_ctx:
                outputs = self.model.generate(
                    **inputs,
                    # NOTE(review): max_length counts the prompt tokens too,
                    # so a 512-token prompt leaves no room to generate; kept
                    # for parity with the original test setup — consider
                    # max_new_tokens= instead.
                    max_length=512,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    use_cache=True,
                )
            # outputs[0]: the single returned sequence (prompt + completion).
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [response]
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return [f"Error generating response: {str(e)}"]

    def preprocess(self, request):
        """Extract the JSON body for JSON requests; pass others through."""
        if request.content_type == "application/json":
            return request.json
        return request

    def postprocess(self, response):
        """Identity hook kept for inference-API compatibility."""
        return response