|
|
from typing import Dict, List |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
import os |
|
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint handler for GPT-J-6B with PEFT/LoRA adapter weights.

    Loads the 8-bit quantized base model once at construction, applies the
    adapter weights found in ``path``, and serves question-answering
    requests through ``__call__``.
    """

    def __init__(self, path: str):
        """Load the base model, adapter weights, and tokenizer.

        Args:
            path: Directory containing the PEFT adapter weights.

        Raises:
            Exception: Any loading failure is logged and re-raised so the
                serving framework can report a startup error.
        """
        try:
            logger.info("Loading base model...")
            base_model = AutoModelForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B",
                load_in_8bit=True,  # bitsandbytes quantization to fit in GPU memory
                device_map="auto",
                torch_dtype=torch.float16,
            )

            logger.info("Loading adapter weights...")
            self.model = PeftModel.from_pretrained(base_model, path)

            self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
            # GPT-J ships without a pad token; reuse EOS so padding/truncation work.
            self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def __call__(self, data: Dict) -> List[str]:
        """Generate an answer for the question carried in ``data``.

        Args:
            data: Request payload. The question is read from
                ``data["inputs"]`` (a string, or a list whose first element
                is used); if the key is absent, ``data`` itself is used.

        Returns:
            A one-element list with the generated text, or with an error
            message if generation fails (the handler never raises).
        """
        try:
            # .get() instead of .pop(): do not mutate the caller's dict.
            question = data.get("inputs", data)
            if isinstance(question, list):
                question = question[0]

            prompt = f"Question: {question}\nAnswer:"

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.model.device)

            # torch.amp.autocast("cuda") is the non-deprecated spelling of
            # torch.cuda.amp.autocast().
            with torch.inference_mode(), torch.amp.autocast("cuda"):
                outputs = self.model.generate(
                    **inputs,
                    # BUG FIX: the original passed max_length=512, which
                    # counts prompt tokens too — a prompt near the 512-token
                    # truncation limit left no room for the answer.
                    # max_new_tokens bounds only the generated continuation.
                    max_new_tokens=256,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    use_cache=True,
                )

            # NOTE(review): decodes prompt + answer together, matching the
            # original behavior; callers appear to expect the full text.
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            return [response]

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            # Deliberate best-effort: surface the error in the payload so the
            # serving layer still returns a well-formed response.
            return [f"Error generating response: {str(e)}"]

    def preprocess(self, request):
        """Pre-process request for API compatibility.

        Returns the parsed JSON body for ``application/json`` requests;
        passes any other request object through unchanged.
        """
        if request.content_type == "application/json":
            return request.json
        return request

    def postprocess(self, response):
        """Post-process response for API compatibility (identity pass-through)."""
        return response