from typing import Dict, List
import contextlib
import logging
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference-endpoint handler: GPT-J-6B base model + PEFT adapter weights for Q&A generation."""

    def __init__(self, path: str):
        """Load the 8-bit base model, attach the adapter stored at *path*, and set up the tokenizer.

        Re-raises any exception from model/tokenizer loading after logging it,
        so a broken deployment fails fast instead of serving errors per-request.
        """
        try:
            logger.info("Loading base model...")
            # 8-bit quantization keeps the 6B model inside typical endpoint GPU memory;
            # device_map="auto" lets accelerate place layers on available devices.
            base_model = AutoModelForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B",
                load_in_8bit=True,
                device_map="auto",
                torch_dtype=torch.float16,
            )

            logger.info("Loading adapter weights...")
            # Wrap the frozen base model with the fine-tuned PEFT adapter from `path`.
            self.model = PeftModel.from_pretrained(base_model, path)

            # GPT-J ships without a pad token; reuse EOS so padded tokenization works.
            self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
            self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Model loaded successfully!")
        except Exception:
            # logger.exception records the full traceback (logger.error with an
            # f-string dropped it); re-raise so the endpoint reports the failure.
            logger.exception("Error initializing model")
            raise

    def __call__(self, data: Dict) -> List[str]:
        """Generate an answer for the question in ``data["inputs"]``.

        Returns a one-element list with the decoded model output; on failure,
        a one-element list containing an error message (endpoint stays up).
        """
        try:
            # BUG FIX: .get() instead of .pop() — .pop() mutated the caller's
            # request payload as a side effect of handling it.
            question = data.get("inputs", data)
            if isinstance(question, list):
                question = question[0]

            # Prompt template must match the one used during fine-tuning.
            prompt = f"Question: {question}\nAnswer:"

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.model.device)

            # BUG FIX: torch.cuda.amp.autocast() was entered unconditionally,
            # which warns/fails on CPU-only hosts; only enable it under CUDA.
            amp = (
                torch.cuda.amp.autocast()
                if torch.cuda.is_available()
                else contextlib.nullcontext()
            )
            with torch.inference_mode(), amp:
                # NOTE(review): max_length counts the prompt tokens too — a prompt
                # truncated to 512 tokens leaves no room for generated text.
                # Consider max_new_tokens=...; kept as-is to preserve output
                # semantics of the deployed model.
                outputs = self.model.generate(
                    **inputs,
                    max_length=512,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    use_cache=True,
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [response]
        except Exception as e:
            logger.exception("Error generating response")
            return [f"Error generating response: {str(e)}"]

    def preprocess(self, request):
        """Pre-process request for API compatibility: unwrap JSON bodies, pass others through."""
        if request.content_type == "application/json":
            return request.json
        return request

    def postprocess(self, response):
        """Post-process response for API compatibility (identity pass-through)."""
        return response