Reemalsahli committed on
Commit 29df1c0 · verified · 1 Parent(s): de3552c

update handler.py to support chat completions
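The handler now accepts an OpenAI-style "messages" list in addition to the existing plain "inputs" string. The two payload shapes below are an illustrative sketch of what __call__ receives after this change; the message text is made up and not part of the commit.

# Chat-completions style request body: passed straight to apply_chat_template.
chat_request = {"messages": [{"role": "user", "content": "Hello!"}]}

# Plain request body (previous behaviour, still supported): wrapped into a
# single user turn before the chat template is applied.
plain_request = {"inputs": "Hello!"}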

Files changed (1)
  1. handler.py +26 -13
handler.py CHANGED
@@ -10,7 +10,10 @@ class EndpointHandler:
         self.peft_config = PeftConfig.from_pretrained(path)
 
         # Load tokenizer from base model
-        self.tokenizer = AutoTokenizer.from_pretrained(self.peft_config.base_model_name_or_path, use_fast=False)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.peft_config.base_model_name_or_path,
+            use_fast=False
+        )
 
         # Load base model
         base_model = AutoModelForCausalLM.from_pretrained(
@@ -24,20 +27,29 @@ class EndpointHandler:
         self.model.eval()
 
     def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
-        user_input = data.get("inputs", "")
-        if not user_input:
-            return {"error": "No input provided."}
-
-        # Format input as chat message using chat template
-        messages = [
-            {"role": "user", "content": user_input}
-        ]
-        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        # Handle both chat-style and plain input
+        if "messages" in data:
+            messages = data["messages"]
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+        else:
+            user_input = data.get("inputs", "")
+            if not user_input:
+                return {"error": "No input provided."}
+            messages = [{"role": "user", "content": user_input}]
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
 
-        # Tokenize
+        # Tokenize input
         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
 
-        # Generate
+        # Generate output
         with torch.no_grad():
             output_ids = self.model.generate(
                 **inputs,
@@ -48,6 +60,7 @@ class EndpointHandler:
                 pad_token_id=self.tokenizer.eos_token_id
             )
 
+        # Decode and return result
         output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
         return {"generated_text": output_text}
+
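For reference, a minimal sketch of exercising the updated handler locally with each payload shape. The EndpointHandler(path=...) constructor call and the example strings are assumptions for illustration; they are not shown in this diff.

from handler import EndpointHandler

# Assumption: the handler is built from a local checkpoint directory, as is
# typical for Inference Endpoints custom handlers.
handler = EndpointHandler(path=".")

# Chat-style call: the "messages" list goes directly through the chat template.
chat_out = handler({"messages": [{"role": "user", "content": "What does this adapter do?"}]})

# Plain call: "inputs" is wrapped into a one-turn user message internally.
plain_out = handler({"inputs": "What does this adapter do?"})

print(chat_out["generated_text"])
print(plain_out["generated_text"])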