update handler.py to support chat completions
handler.py  +26 -13
handler.py
CHANGED
@@ -10,7 +10,10 @@ class EndpointHandler:
         self.peft_config = PeftConfig.from_pretrained(path)

         # Load tokenizer from base model
-        self.tokenizer = AutoTokenizer.from_pretrained(
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.peft_config.base_model_name_or_path,
+            use_fast=False
+        )

         # Load base model
         base_model = AutoModelForCausalLM.from_pretrained(
@@ -24,20 +27,29 @@ class EndpointHandler:
         self.model.eval()

     def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
-
-        if
-
-
-
-
-
-
-
+        # Handle both chat-style and plain input
+        if "messages" in data:
+            messages = data["messages"]
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+        else:
+            user_input = data.get("inputs", "")
+            if not user_input:
+                return {"error": "No input provided."}
+            messages = [{"role": "user", "content": user_input}]
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )

-        # Tokenize
+        # Tokenize input
         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

-        # Generate
+        # Generate output
         with torch.no_grad():
             output_ids = self.model.generate(
                 **inputs,
@@ -48,6 +60,7 @@ class EndpointHandler:
                 pad_token_id=self.tokenizer.eos_token_id
             )

+        # Decode and return result
         output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
         return {"generated_text": output_text}
+
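For reference, a minimal sketch of how the updated handler could be exercised locally, assuming EndpointHandler is importable from handler.py (the usual custom-handler convention) and that the adapter repository has been downloaded to ./my-adapter, a hypothetical path:

from handler import EndpointHandler

# Hypothetical local path to the PEFT adapter repo; adjust to your deployment.
handler = EndpointHandler(path="./my-adapter")

# Chat-style request: the "messages" list is passed through apply_chat_template.
chat_response = handler({
    "messages": [
        {"role": "user", "content": "Summarize what a PEFT adapter is in one sentence."}
    ]
})
print(chat_response["generated_text"])

# Plain request: a bare "inputs" string is wrapped into a single user message.
plain_response = handler({"inputs": "Write a haiku about tokenizers."})
print(plain_response["generated_text"])

# An empty payload returns an error dict instead of raising.
print(handler({}))  # {"error": "No input provided."}

Note the design choice in the else branch: a plain "inputs" string is wrapped into a single user message, so both request shapes are formatted by the same chat template before tokenization.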