# Prompt-construction helpers: wrap user/system queries with special control
# tokens for plain-text prompts, and build langchain chat message templates.
from langchain.prompts.chat import ChatMessagePromptTemplate
class SpecialTokens:
    """Holds the special control tokens read from a config mapping.

    Exposes ``user_token``, ``assistant_token``, ``system_token`` and
    ``stop_token`` as plain attributes.
    """

    def __init__(self, config):
        # Copy each required token from the config; a missing key raises
        # KeyError, same as direct subscript access would.
        for token_name in ("user_token", "assistant_token", "system_token", "stop_token"):
            setattr(self, token_name, config[token_name])
def to_instruction(query, special_tokens):
    """Wrap *query* as a user instruction: user token, query, stop token."""
    return "".join((special_tokens.user_token, query, special_tokens.stop_token))
def to_prompt(query, special_tokens):
    """Format *query* as a generation prompt: the user turn (query plus stop
    token) followed by the opening assistant token, so the model continues
    as the assistant."""
    parts = [
        special_tokens.user_token,
        query,
        special_tokens.stop_token,
        special_tokens.assistant_token,
    ]
    return "".join(parts)
def to_system(query, special_tokens):
    """Wrap *query* as a system message: system token, query, stop token."""
    return f"{special_tokens.system_token}{query}{special_tokens.stop_token}"
def make_prompt(prompt, special_tokens):
    """Build a plain-text prompt from a prompt spec.

    ``prompt`` is a mapping with a ``"type"`` key ("system", "instruction"
    or "prompt") and a ``"prompt"`` key holding a list of lines, which are
    joined with newlines before being wrapped with special tokens.
    """
    dispatch = {
        "system": to_system,
        "instruction": to_instruction,
        "prompt": to_prompt,
    }
    builder = dispatch.get(prompt["type"])
    if builder is None:
        # NOTE(review): returns an error *string* instead of raising —
        # callers appear to consume a string either way; confirm before
        # changing this to an exception.
        return "Invalid prompt type, please check your config"
    return builder("\n".join(prompt["prompt"]), special_tokens)
def to_chat_instruction(query, special_tokens):
    """Build a langchain chat message template for *query* with the user
    token as the message role."""
    role = special_tokens.user_token
    return ChatMessagePromptTemplate.from_template(query, role=role)
def to_chat_system(query, special_tokens):
    """Build a langchain chat message template for *query* with the system
    token as the message role."""
    role = special_tokens.system_token
    return ChatMessagePromptTemplate.from_template(query, role=role)
def to_chat_prompt(query, special_tokens):
    """Build a langchain chat message template for *query* with the user
    token as the message role.

    NOTE(review): this is currently identical to ``to_chat_instruction``
    (both use the user token) — confirm whether the "prompt" variant was
    meant to differ, e.g. by involving the assistant token.
    """
    role = special_tokens.user_token
    return ChatMessagePromptTemplate.from_template(query, role=role)
def make_chat_prompt(prompt, special_tokens):
    """Build a langchain chat message template from a prompt spec.

    ``prompt`` is a mapping with a ``"type"`` key ("system", "instruction"
    or "prompt") and a ``"prompt"`` key holding a list of lines, which are
    joined with newlines before being turned into a chat template.
    """
    dispatch = {
        "system": to_chat_system,
        "instruction": to_chat_instruction,
        "prompt": to_chat_prompt,
    }
    builder = dispatch.get(prompt["type"])
    if builder is None:
        # NOTE(review): returns an error *string* while the success path
        # returns a template object — callers must handle both; confirm
        # before changing this to an exception.
        return "Invalid prompt type, please check your config"
    return builder("\n".join(prompt["prompt"]), special_tokens)