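"""Custom inference handler serving a PEFT adapter on top of EleutherAI/gpt-j-6B.

The class layout follows the Hugging Face Inference Endpoints custom-handler
convention: EndpointHandler(path) is constructed once at startup with the
repository path (where the adapter weights live), and each request is passed
to __call__ as a dict such as {"inputs": "What is quantization?"}.
Requires a CUDA device and bitsandbytes for the 8-bit base-model load.
"""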
import logging
from typing import Dict, List

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    def __init__(self, path: str):
        try:
            logger.info("Loading base model...")
            # Load the base model in 8-bit (requires bitsandbytes);
            # device_map="auto" places layers across the available GPUs.
            base_model = AutoModelForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B",
                load_in_8bit=True,
                device_map="auto",
                torch_dtype=torch.float16,
            )

            logger.info("Loading adapter weights...")
            # Attach the fine-tuned PEFT adapter stored at the handler path.
            self.model = PeftModel.from_pretrained(base_model, path)

            # GPT-J ships without a pad token; reuse EOS so padding and
            # truncation work.
            self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
            self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise
    def __call__(self, data: Dict) -> List[str]:
        try:
            # Accept either {"inputs": question} or a bare payload.
            question = data.pop("inputs", data)
            if isinstance(question, list):
                question = question[0]

            # Prompt format must match the one used during fine-tuning.
            prompt = f"Question: {question}\nAnswer:"

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.model.device)

            # Note: max_length counts prompt tokens too, so a long prompt
            # leaves little room for the answer; sampling at temperature 0.7.
            with torch.inference_mode(), torch.cuda.amp.autocast():
                outputs = self.model.generate(
                    **inputs,
                    max_length=512,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    use_cache=True,
                )

            # The decoded string contains the prompt followed by the answer.
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [response]
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return [f"Error generating response: {str(e)}"]
    def preprocess(self, request):
        """Pre-process request for API compatibility"""
        if request.content_type == "application/json":
            return request.json
        return request

    def postprocess(self, response):
        """Post-process response for API compatibility"""
        return response
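
# Illustrative local smoke test, not part of the Inference Endpoints request
# cycle. The adapter directory "./adapter" is a hypothetical path; point it at
# wherever the PEFT weights actually live.
if __name__ == "__main__":
    handler = EndpointHandler("./adapter")
    print(handler({"inputs": "What is quantization?"}))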