File size: 2,978 Bytes
a0939c3
 
e054408
a0939c3
e054408
 
 
 
 
a0939c3
 
 
e054408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0939c3
 
e054408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0939c3
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
import logging

# Configure the root logger at INFO level so the progress messages emitted
# through `logger.info` in this module are not filtered out, then create the
# module-level logger used by the handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    """Inference handler that answers questions with a PEFT-adapted causal LM.

    The base model is loaded in 8-bit, the adapter weights are applied on top
    of it, and every call builds a "Question: ... / Answer:" prompt and samples
    one completion.
    """

    # Hub id shared by the base model and its tokenizer.
    BASE_MODEL_ID = "EleutherAI/gpt-j-6B"
    # Token budget used for prompt truncation and for generate(max_length=...).
    MAX_LENGTH = 512
    # Sampling temperature passed to generate().
    TEMPERATURE = 0.7

    def __init__(self, path: str):
        """Load the quantized base model, the adapter weights and the tokenizer.

        Args:
            path: Location of the adapter weights, forwarded to
                ``PeftModel.from_pretrained``.

        Raises:
            Exception: Any loading failure is logged (with traceback) and
                re-raised unchanged.
        """
        try:
            logger.info("Loading base model...")
            # Base model in 8-bit, placed automatically across available devices.
            base_model = AutoModelForCausalLM.from_pretrained(
                self.BASE_MODEL_ID,
                load_in_8bit=True,
                device_map="auto",
                torch_dtype=torch.float16,
            )

            logger.info("Loading adapter weights...")
            self.model = PeftModel.from_pretrained(base_model, path)

            # The tokenizer reuses the EOS token for padding.
            self.tokenizer = AutoTokenizer.from_pretrained(self.BASE_MODEL_ID)
            self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Model loaded successfully!")

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Error initializing model: %s", e)
            raise

    def __call__(self, data: Dict) -> List[str]:
        """Generate an answer for the question carried by ``data``.

        Args:
            data: Request payload. The question is read from ``data["inputs"]``;
                when that key is absent the payload itself is used as the
                question. If the value is a list, only its first element is
                used. The payload is left unmodified.

        Returns:
            A one-element list with the decoded first generated sequence, or a
            one-element list with an ``"Error generating response: ..."``
            message when anything fails (the error is also logged).
        """
        try:
            # .get instead of .pop: do not mutate the caller's payload.
            question = data.get("inputs", data)
            if isinstance(question, list):
                if not question:
                    # Explicit message instead of a bare IndexError text.
                    raise ValueError("'inputs' list is empty")
                question = question[0]

            prompt = f"Question: {question}\nAnswer:"

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.MAX_LENGTH,
            ).to(self.model.device)

            # NOTE(review): generate()'s max_length counts prompt tokens too,
            # so a prompt truncated to MAX_LENGTH leaves no room for an answer.
            # Consider max_new_tokens once the intended budget is confirmed.
            #
            # torch.autocast(device_type="cuda") is the non-deprecated
            # equivalent of torch.cuda.amp.autocast().
            with torch.inference_mode(), torch.autocast(device_type="cuda"):
                outputs = self.model.generate(
                    **inputs,
                    max_length=self.MAX_LENGTH,
                    num_return_sequences=1,
                    temperature=self.TEMPERATURE,
                    do_sample=True,
                    use_cache=True,
                    # Explicit pad id avoids the per-call fallback warning.
                    pad_token_id=self.tokenizer.pad_token_id,
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            return [response]

        except Exception as e:
            # Deliberate best-effort: report the failure as the response
            # instead of raising, but keep the traceback in the logs.
            logger.exception("Error generating response: %s", e)
            return [f"Error generating response: {str(e)}"]

    def preprocess(self, request):
        """Return the JSON payload of a JSON request, else the request itself.

        Args:
            request: Object exposing ``content_type`` and, for
                ``application/json`` requests, a ``json`` attribute.

        Returns:
            ``request.json`` when the content type is ``application/json``;
            otherwise ``request`` unchanged.
        """
        if request.content_type == "application/json":
            return request.json
        return request

    def postprocess(self, response):
        """Return ``response`` unchanged (hook kept for API compatibility)."""
        return response