Spaces:
Sleeping
Sleeping
File size: 2,149 Bytes
45c901d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from langchain.prompts.chat import ChatMessagePromptTemplate
class SpecialTokens:
    """Hold the special delimiter tokens taken from a config mapping.

    Attributes:
        user_token: Token that opens a user turn.
        assistant_token: Token that opens an assistant turn.
        system_token: Token that opens a system message.
        stop_token: Token that closes a turn.

    Raises:
        KeyError: If *config* lacks any of the four token keys.
    """

    def __init__(self, config):
        # Looked up in a fixed order, so a missing key reports the first
        # absent name among user/assistant/system/stop.
        user, assistant, system, stop = (
            config["user_token"],
            config["assistant_token"],
            config["system_token"],
            config["stop_token"],
        )
        self.user_token = user
        self.assistant_token = assistant
        self.system_token = system
        self.stop_token = stop
def to_instruction(query, special_tokens):
    """Wrap *query* as a user turn: ``user_token + query + stop_token``."""
    pieces = (special_tokens.user_token, query, special_tokens.stop_token)
    return "".join(pieces)
def to_prompt(query, special_tokens):
    """Wrap *query* as a user turn, then append the assistant token.

    Produces ``user_token + query + stop_token + assistant_token``.
    Presumably the trailing assistant token is meant to cue the model to
    continue as the assistant.
    """
    pieces = (
        special_tokens.user_token,
        query,
        special_tokens.stop_token,
        special_tokens.assistant_token,
    )
    return "".join(pieces)
def to_system(query, special_tokens):
    """Wrap *query* as a system message: ``system_token + query + stop_token``."""
    pieces = (special_tokens.system_token, query, special_tokens.stop_token)
    return "".join(pieces)
def make_prompt(prompt, special_tokens):
    """Render one prompt config entry into a token-delimited string.

    Args:
        prompt: Mapping with a ``"type"`` key (``"system"``,
            ``"instruction"`` or ``"prompt"``) and a ``"prompt"`` key holding
            either an iterable of lines (joined with ``"\\n"``) or a single
            string (used as-is).
        special_tokens: ``SpecialTokens``-like object supplying the
            delimiter tokens.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If ``prompt`` lacks the ``"type"`` or ``"prompt"`` key.
        ValueError: If ``prompt["type"]`` is not a supported type.
    """
    prompt_type = prompt["type"]
    if prompt_type == "system":
        formatter = to_system
    elif prompt_type == "instruction":
        formatter = to_instruction
    elif prompt_type == "prompt":
        formatter = to_prompt
    else:
        # Raise instead of returning an error string: a returned string is
        # indistinguishable from a real prompt and would silently end up in
        # the model input.
        raise ValueError(
            f"Invalid prompt type {prompt_type!r}, please check your config "
            "(expected 'system', 'instruction' or 'prompt')"
        )
    lines = prompt["prompt"]
    # "\n".join on a bare string would insert a newline between every
    # character, so a single string is passed through unchanged.
    text = lines if isinstance(lines, str) else "\n".join(lines)
    return formatter(text, special_tokens)
def to_chat_instruction(query, special_tokens):
    """Build a LangChain chat message template for *query*.

    The message role is ``special_tokens.user_token``. Per LangChain's docs,
    ``{...}`` placeholders in *query* become template input variables.
    """
    user_role = special_tokens.user_token
    return ChatMessagePromptTemplate.from_template(query, role=user_role)
def to_chat_system(query, special_tokens):
    """Build a LangChain chat message template for *query*.

    The message role is ``special_tokens.system_token``. Per LangChain's
    docs, ``{...}`` placeholders in *query* become template input variables.
    """
    system_role = special_tokens.system_token
    return ChatMessagePromptTemplate.from_template(query, role=system_role)
def to_chat_prompt(query, special_tokens):
    """Build a LangChain chat message template for *query*.

    The message role is ``special_tokens.user_token``.

    NOTE(review): this is currently identical to ``to_chat_instruction``.
    Unlike the string-based ``to_prompt``, nothing here marks the start of
    an assistant turn — confirm that is intended.
    """
    user_role = special_tokens.user_token
    return ChatMessagePromptTemplate.from_template(query, role=user_role)
def make_chat_prompt(prompt, special_tokens):
    """Render one prompt config entry into a LangChain chat message template.

    Args:
        prompt: Mapping with a ``"type"`` key (``"system"``,
            ``"instruction"`` or ``"prompt"``) and a ``"prompt"`` key holding
            either an iterable of lines (joined with ``"\\n"``) or a single
            string (used as-is).
        special_tokens: ``SpecialTokens``-like object whose tokens are used
            as message roles.

    Returns:
        The ``ChatMessagePromptTemplate`` built by the matching
        ``to_chat_*`` helper.

    Raises:
        KeyError: If ``prompt`` lacks the ``"type"`` or ``"prompt"`` key.
        ValueError: If ``prompt["type"]`` is not a supported type.
    """
    prompt_type = prompt["type"]
    if prompt_type == "system":
        builder = to_chat_system
    elif prompt_type == "instruction":
        builder = to_chat_instruction
    elif prompt_type == "prompt":
        builder = to_chat_prompt
    else:
        # Raise instead of returning an error string: callers expect a
        # message template, and a plain str would fail later and far from
        # the actual config mistake.
        raise ValueError(
            f"Invalid prompt type {prompt_type!r}, please check your config "
            "(expected 'system', 'instruction' or 'prompt')"
        )
    lines = prompt["prompt"]
    # "\n".join on a bare string would insert a newline between every
    # character, so a single string is passed through unchanged.
    text = lines if isinstance(lines, str) else "\n".join(lines)
    return builder(text, special_tokens)
|