henriceriocain committed
Commit a0939c3 · verified · 1 Parent(s): 2f8418c

Create handler.py

Files changed (1)
  1. handler.py +79 -0
handler.py ADDED
@@ -0,0 +1,79 @@
+ from typing import Dict, List
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from peft import PeftModel
+
+ class EndpointHandler:
+     def __init__(self, path: str):
+         print("Loading base model...")
+         # Configure 4-bit quantization (NF4 with double quantization, fp16 compute)
+         self.bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+         )
+
+         # Load the base model with 4-bit quantization
+         base_model = AutoModelForCausalLM.from_pretrained(
+             "EleutherAI/gpt-j-6B",
+             quantization_config=self.bnb_config,
+             device_map="auto",
+             torch_dtype=torch.float16
+         )
+
+         print("Loading adapter weights...")
+         # Load the PEFT adapter weights from the repository path
+         self.model = PeftModel.from_pretrained(
+             base_model,
+             path
+         )
+
+         # Set up the tokenizer; GPT-J has no pad token, so reuse EOS
+         self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+     def __call__(self, data: Dict) -> List[str]:
+         """Mirror the original generate_response function."""
+         # Get the question from the input payload
+         question = data.pop("inputs", data)
+         if isinstance(question, list):
+             question = question[0]
+
+         # Format the prompt
+         prompt = f"Question: {question}\nAnswer:"
+
+         # Tokenize (truncate long prompts to 512 tokens)
+         inputs = self.tokenizer(
+             prompt,
+             return_tensors="pt",
+             truncation=True,
+             max_length=512
+         ).to(self.model.device)
+
+         # Generate under inference mode with fp16 autocast
+         with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=512,  # cap on generated tokens; max_length would also count the prompt
+                 num_return_sequences=1,
+                 temperature=0.7,
+                 do_sample=True,
+                 use_cache=True
+             )
+
+         # Decode the full sequence (prompt plus completion)
+         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Return as a list for API compatibility
+         return [response]
+
+     def preprocess(self, request):
+         """Pre-process the request for API compatibility."""
+         if request.content_type == "application/json":
+             return request.json
+         return request
+
+     def postprocess(self, response):
+         """Post-process the response for API compatibility."""
+         return response
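
For reference, a minimal local smoke test of this handler might look like the sketch below. It is not part of the commit: it assumes a CUDA GPU with bitsandbytes and peft installed, and it assumes the adapter weights sit in the current directory (hence the hypothetical path="."). The payload mirrors the parsed {"inputs": ...} JSON body that the Inference Endpoints runtime passes to __call__.

# Hypothetical local smoke test; not part of the commit.
# Assumes a CUDA GPU, bitsandbytes/peft installed, and the PEFT adapter
# weights located in the current directory.
from handler import EndpointHandler

handler = EndpointHandler(path=".")  # "." is an assumed adapter location

# The endpoint runtime passes the parsed JSON body, so this call mirrors
# a POST request with body {"inputs": "What is GPT-J?"}.
response = handler({"inputs": "What is GPT-J?"})
print(response[0])

Because __call__ returns a list, response[0] holds the decoded text, which includes the "Question: ... Answer:" prompt prefix.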