LiKenun committed
Commit 1509884 · 1 Parent(s): 4c71b8b

AI-generated chat sample revision 1: support both seq2seq and causal LM models
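The revision hinges on two checks added in the diff below: use the tokenizer's chat template when one exists, otherwise fall back to a plain "User:/Assistant:" transcript. A minimal sketch of the difference between the two prompt formats, assuming an instruction-tuned checkpoint such as Qwen/Qwen2.5-0.5B-Instruct purely for illustration (it is not part of this repository):

from transformers import AutoTokenizer

# Illustrative checkpoint only; any tokenizer that ships a chat template behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

history = [
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]

if tokenizer.chat_template is not None:
    # Modern chat models: let the tokenizer render its own prompt format.
    prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
else:
    # Fallback for models without a template (e.g. BlenderBot-style checkpoints).
    prompt = "".join(
        f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}\n\n"
        for m in history
    ) + "Assistant:"

print(prompt)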

Files changed (1): chatbot.py (+89 -28)
chatbot.py CHANGED
@@ -1,26 +1,51 @@
 from os import getenv
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
 from utils import get_pytorch_device, spaces_gpu
 
 # Global chatbot instance (initialized once)
 _chatbot = None
 _tokenizer = None
+_is_seq2seq = None
 
 def get_chatbot():
-    global _chatbot, _tokenizer
+    """Get or create the chatbot model instance. Supports both causal LM and seq2seq models."""
+    global _chatbot, _tokenizer, _is_seq2seq
     if _chatbot is None:
         model_id = getenv("CHAT_MODEL")
         device = get_pytorch_device()
         _tokenizer = AutoTokenizer.from_pretrained(model_id)
-        _chatbot = AutoModelForSeq2SeqLM.from_pretrained(
-            model_id,
-            use_safetensors=True  # Use safetensors to avoid torch.load restriction
-        ).to(device)
-    return _chatbot, _tokenizer
+
+        # Try to determine model type and load accordingly
+        # Check tokenizer config or model config to see if it's seq2seq
+        try:
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(model_id)
+            # Seq2seq models have encoder/decoder, causal LMs don't
+            _is_seq2seq = hasattr(config, 'is_encoder_decoder') and config.is_encoder_decoder
+        except Exception:
+            # Default to causal LM (most modern chat models)
+            _is_seq2seq = False
+
+        if _is_seq2seq:
+            _chatbot = AutoModelForSeq2SeqLM.from_pretrained(
+                model_id,
+                use_safetensors=True
+            ).to(device)
+        else:
+            _chatbot = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                use_safetensors=True
+            ).to(device)
+
+        # Set pad token if not set
+        if _tokenizer.pad_token is None:
+            _tokenizer.pad_token = _tokenizer.eos_token
+
+    return _chatbot, _tokenizer, _is_seq2seq
 
 @spaces_gpu
 def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, list[dict]]:
-    model, tokenizer = get_chatbot()
+    model, tokenizer, is_seq2seq = get_chatbot()
 
     # Initialize conversation history if this is the first message
     if conversation_history is None:
@@ -29,36 +54,72 @@ def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, list[dict]]:
     # Add the user's message
     conversation_history.append({"role": "user", "content": message})
 
-    # For BlenderBot models, format conversation as dialogue history
-    # Build the full conversation context as a string
-    dialogue_text = ""
-    for msg in conversation_history:
-        if msg["role"] == "user":
-            dialogue_text += f"User: {msg['content']}\n"
-        elif msg["role"] == "assistant":
-            dialogue_text += f"Assistant: {msg['content']}\n"
-
-    # Tokenize the input
-    inputs = tokenizer([dialogue_text], return_tensors="pt", truncation=True, max_length=512)
     device = get_pytorch_device()
-    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    # Check if tokenizer has a chat template (modern chat models)
+    use_chat_template = hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None
+
+    if use_chat_template:
+        # Use chat template for modern chat models (Qwen, Mistral, etc.)
+        try:
+            formatted_input = tokenizer.apply_chat_template(
+                conversation_history,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+            inputs = tokenizer(formatted_input, return_tensors="pt", truncation=True).to(device)
+        except Exception:
+            use_chat_template = False
+
+    if not use_chat_template:
+        # For models without chat templates (BlenderBot, older models)
+        if is_seq2seq:
+            # Seq2seq format: "User: ...\nAssistant: ..."
+            dialogue_text = ""
+            for msg in conversation_history:
+                if msg["role"] == "user":
+                    dialogue_text += f"User: {msg['content']}\n"
+                elif msg["role"] == "assistant":
+                    dialogue_text += f"Assistant: {msg['content']}\n"
+            inputs = tokenizer([dialogue_text], return_tensors="pt", truncation=True, max_length=512).to(device)
+        else:
+            # Causal LM format: just concatenate messages
+            dialogue_text = ""
+            for msg in conversation_history:
+                if msg["role"] == "user":
+                    dialogue_text += f"User: {msg['content']}\n\n"
+                elif msg["role"] == "assistant":
+                    dialogue_text += f"Assistant: {msg['content']}\n\n"
+            dialogue_text += "Assistant:"
+            inputs = tokenizer(dialogue_text, return_tensors="pt", truncation=True, max_length=1024).to(device)
 
     # Generate response
     outputs = model.generate(
         **inputs,
-        max_new_tokens=128,
+        max_new_tokens=256,
         do_sample=True,
         temperature=0.7,
         pad_token_id=tokenizer.eos_token_id
     )
 
-    # Decode the generated tokens - for seq2seq models, this should be just the assistant's response
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Clean up the response - remove any "Assistant:" prefix if present
-    response = response.strip()
-    if response.startswith("Assistant:"):
-        response = response[len("Assistant:"):].strip()
+    # Decode the response
+    if is_seq2seq:
+        # For seq2seq, output is just the generated response
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Clean up any "Assistant:" prefix
+        if response.startswith("Assistant:"):
+            response = response[len("Assistant:"):].strip()
+    else:
+        # For causal LMs, extract only the newly generated part
+        if use_chat_template:
+            # Extract only new tokens (generated part)
+            input_length = inputs.input_ids.shape[1]
+            generated_tokens = outputs[0][input_length:]
+            response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        else:
+            # Extract text after the prompt
+            full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = full_text.split("Assistant:")[-1].strip()
 
     # Add the assistant's response to history
     conversation_history.append({"role": "assistant", "content": response})
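
For context, a minimal driver of the revised chat() function could look like the sketch below. It assumes CHAT_MODEL is already set in the environment (to either a seq2seq or a causal LM checkpoint) and that chat() returns the (response, updated_history) pair implied by its annotation; the Space's actual UI wiring is outside this diff.

# Hypothetical driver; the real Space wires chat() into its UI elsewhere.
from chatbot import chat

history = None
for turn in ("Hello!", "What did I just say to you?"):
    reply, history = chat(turn, history)
    print(f"User: {turn}")
    print(f"Assistant: {reply}")

# history now holds the alternating {"role": ..., "content": ...} dicts built up by chat().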