henriceriocain committed
Commit e054408 · verified · 1 Parent(s): 8ffdae3

Update handler.py

Files changed (1)
  1. handler.py +68 -61
handler.py CHANGED
@@ -1,72 +1,79 @@
 from typing import Dict, List
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from peft import PeftModel


 class EndpointHandler:
     def __init__(self, path: str):
-        print("Loading base model...")
-        # Configure 4-bit quantization
-        self.bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-        )
-
-        # Load base model with 4-bit quantization
-        base_model = AutoModelForCausalLM.from_pretrained(
-            "EleutherAI/gpt-j-6B",
-            quantization_config=self.bnb_config,
-            device_map="auto",
-            torch_dtype=torch.float16
-        )
-
-        print("Loading adapter weights...")
-        # Load the adapter weights
-        self.model = PeftModel.from_pretrained(
-            base_model,
-            path
-        )
-
-        # Set up tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-        self.tokenizer.pad_token = self.tokenizer.eos_token

     def __call__(self, data: Dict) -> List[str]:
-        """Matches your generate_response function exactly"""
-        # Get the question from the input
-        question = data.pop("inputs", data)
-        if isinstance(question, list):
-            question = question[0]
-
-        # Format prompt
-        prompt = f"Question: {question}\nAnswer:"
-
-        # Tokenize
-        inputs = self.tokenizer(
-            prompt,
-            return_tensors="pt",
-            truncation=True,
-            max_length=512
-        ).to(self.model.device)
-
-        # Generate
-        with torch.inference_mode(), torch.cuda.amp.autocast():
-            outputs = self.model.generate(
-                **inputs,
-                max_length=512,
-                num_return_sequences=1,
-                temperature=0.7,
-                do_sample=True,
-                use_cache=True
-            )
-
-        # Decode exactly as in your test file
-        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Return as list for API compatibility
-        return [response]

     def preprocess(self, request):
         """Pre-process request for API compatibility"""
 
 from typing import Dict, List
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
+import os
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 class EndpointHandler:
     def __init__(self, path: str):
+        try:
+            logger.info("Loading base model...")
+            # Load base model with 8-bit quantization
+            base_model = AutoModelForCausalLM.from_pretrained(
+                "EleutherAI/gpt-j-6B",
+                load_in_8bit=True,
+                device_map="auto",
+                torch_dtype=torch.float16
+            )
+
+            logger.info("Loading adapter weights...")
+            # Load the adapter weights
+            self.model = PeftModel.from_pretrained(
+                base_model,
+                path
+            )
+
+            # Set up tokenizer
+            self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            logger.info("Model loaded successfully!")
+
+        except Exception as e:
+            logger.error(f"Error initializing model: {str(e)}")
+            raise

     def __call__(self, data: Dict) -> List[str]:
+        try:
+            # Get the question from the input
+            question = data.pop("inputs", data)
+            if isinstance(question, list):
+                question = question[0]
+
+            # Format prompt exactly as in your test file
+            prompt = f"Question: {question}\nAnswer:"
+
+            # Tokenize exactly as in your test file
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512
+            ).to(self.model.device)
+
+            # Generate with exact same parameters as your test file
+            with torch.inference_mode(), torch.cuda.amp.autocast():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_length=512,
+                    num_return_sequences=1,
+                    temperature=0.7,
+                    do_sample=True,
+                    use_cache=True
+                )
+
+            # Decode exactly as in your test file
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            return [response]
+
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            return [f"Error generating response: {str(e)}"]

     def preprocess(self, request):
         """Pre-process request for API compatibility"""