duzx16 committed
Commit · 1676f07
1 Parent(s): 591fa87

Implement new interface

Files changed:
- modeling_chatglm.py  +25 -17
- tokenization_chatglm.py  +15 -10
modeling_chatglm.py
CHANGED

@@ -996,18 +996,23 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
-    def process_response(self,
+    def process_response(self, output, history):
+        content = ""
+        for response in output.split("<|assistant|>"):
+            metadata, content = response.split("\n", maxsplit=1)
+            history.append({"role": "assistant", "metadata": metadata, "content": content})
+            if not metadata.strip():
+                content = content.strip()
+                content = content.replace("[[训练时间]]", "2023年")
+            else:
+                content = "\n".join(content.split("\n")[1:-1])
+                def tool_call(**kwargs):
+                    return kwargs
+                content = eval(content)
+        return content, history
 
     @torch.inference_mode()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None,
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = None,
              max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
              **kwargs):
         if history is None:
@@ -1017,17 +1022,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        inputs =
+        inputs = tokenizer.build_chat_input(query, history=history, role=role)
+        inputs = inputs.to(self.device)
+        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
+                        tokenizer.get_command("<|observation|>")]
         outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        history =
+        history.append({"role": role, "content": query})
+        response, history = self.process_response(response, history)
         return response, history
 
     @torch.inference_mode()
-    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None,
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = None,
                     past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                     logits_processor=None, return_past_key_values=False, **kwargs):
         if history is None:
@@ -1040,9 +1047,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
         if past_key_values is None:
-            inputs =
+            inputs = tokenizer.build_chat_input(query, history=history, role=role)
         else:
-            inputs =
+            inputs = tokenizer.build_chat_input(query, role=role)
+        input = inputs.to(self.device)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
             if self.transformer.pre_seq_len is not None:
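For context, a minimal usage sketch of the new dict-based chat interface introduced above. It is not part of the commit: the checkpoint id, the AutoModel/AutoTokenizer loading path with trust_remote_code=True, and the .half().cuda() placement are assumptions; only the chat() signature and the history format come from the diff.

# Minimal usage sketch (not from the commit); the repo id below is hypothetical.
from transformers import AutoModel, AutoTokenizer

model_path = "THUDM/chatglm3-6b"  # hypothetical checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()

# chat() now takes an explicit role and returns dict-style history entries
# ({"role": ..., "metadata": ..., "content": ...}) instead of (query, response) tuples.
response, history = model.chat(tokenizer, "你好", history=[], role="user")
print(response)

# Follow-up turns feed the accumulated dict-based history back in.
response, history = model.chat(tokenizer, "你能做什么？", history=history, role="user")
print(response)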
tokenization_chatglm.py
CHANGED

@@ -1,3 +1,4 @@
+import json
 import os
 import torch
 from typing import List, Optional, Union, Dict
@@ -173,19 +174,23 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         return prefix_tokens
 
+    def build_single_message(self, role, metadata, message):
+        assert role in ["system", "user", "assistant", "observation"], role
+        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
+        message_tokens = self.tokenizer.encode(message)
+        tokens = role_tokens + message_tokens
+        return tokens
+
+    def build_chat_input(self, query, history=None, role="user"):
         if history is None:
             history = []
         input_ids = []
-        input_ids.extend(
-        input_ids.extend(
-            [self.get_command("<|assistant|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(old_response))
-        input_ids.extend([self.get_command("<|user|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(query))
+        for item in history:
+            content = item["content"]
+            if item["role"] == "system" and "tools" in item:
+                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
+            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
+        input_ids.extend(self.build_single_message(role, "", query))
         input_ids.extend([self.get_command("<|assistant|>")])
         return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
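To make the prompt layout of the new tokenizer methods concrete, the following self-contained sketch mirrors build_single_message/build_chat_input from the diff with a stand-in encoder, so it runs without the real SentencePiece tokenizer. Everything except the special-token strings and the assembly logic is a simplification, not part of the commit.

# Stand-alone sketch of the token layout assembled by build_chat_input /
# build_single_message. encode() below is a whitespace splitter standing in
# for self.tokenizer.encode.
import json

def encode(text):
    return text.split()

def build_single_message(role, metadata, message):
    assert role in ["system", "user", "assistant", "observation"], role
    return [f"<|{role}|>"] + encode(f"{metadata}\n") + encode(message)

def build_chat_input(query, history=None, role="user"):
    if history is None:
        history = []
    input_ids = []
    for item in history:
        content = item["content"]
        if item["role"] == "system" and "tools" in item:
            # tool specs are appended to the system prompt as pretty-printed JSON
            content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
        input_ids.extend(build_single_message(item["role"], item.get("metadata", ""), content))
    input_ids.extend(build_single_message(role, "", query))
    input_ids.append("<|assistant|>")
    return input_ids

history = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "metadata": "", "content": "Hi there"},
]
print(build_chat_input("what time is it", history=history))
# ['<|user|>', 'hello', '<|assistant|>', 'Hi', 'there', '<|user|>', 'what', 'time', 'is', 'it', '<|assistant|>']

The real method returns a BatchEncoding from batch_encode_plus rather than a plain list, but the ordering of role markers, metadata, and message tokens is the same.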