from response_parser import *
import copy
import json
from tqdm import tqdm
import logging
import argparse
import os
from typing import Dict  # used in the type hints below

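# `response_parser` (imported above via a wildcard) is expected to provide the
# names used in this script: BotBackend, chat_completion and parse_response.
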
def initialization(state_dict: Dict) -> None:
    if not os.path.exists('cache'):
        os.mkdir('cache')
    if state_dict["bot_backend"] is None:
        state_dict["bot_backend"] = BotBackend()
        if 'OPENAI_API_KEY' in os.environ:
            # drop the key from the environment once the backend has been constructed
            del os.environ['OPENAI_API_KEY']

def get_bot_backend(state_dict: Dict) -> BotBackend:
    return state_dict["bot_backend"]

def switch_to_gpt4(state_dict: Dict, whether_switch: bool) -> None:
    bot_backend = get_bot_backend(state_dict)
    if whether_switch:
        bot_backend.update_gpt_model_choice("GPT-4")
    else:
        bot_backend.update_gpt_model_choice("GPT-3.5")

def add_text(state_dict, history, text):
    bot_backend = get_bot_backend(state_dict)
    bot_backend.add_text_message(user_text=text)
    history = history + [[text, None]]
    return history, state_dict

def bot(state_dict, history):
    bot_backend = get_bot_backend(state_dict)
    while bot_backend.finish_reason in ('new_input', 'function_call'):
        if history[-1][1]:
            history.append([None, ""])
        else:
            history[-1][1] = ""
        logging.info("Start chat completion")
        response = chat_completion(bot_backend=bot_backend)
        logging.info(f"End chat completion, response: {response}")
        logging.info("Start parse response")
        history, _ = parse_response(
            chunk=response,
            history=history,
            bot_backend=bot_backend
        )
        logging.info("End parse response")
    return history

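# `history` is a list of [user_message, assistant_message] pairs (the format used
# by Gradio chatbots), e.g. (hypothetical values)
#   [["Plot a sine wave", "Here is the code..."], [None, "Follow-up output"]]
# `bot` keeps requesting completions while the backend reports finish_reason
# 'new_input' or 'function_call', appending a fresh assistant slot whenever the
# previous one already holds text.
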
def main(state, history, user_input):
    history, state = add_text(state, history, user_input)
    last_history = copy.deepcopy(history)
    first_turn_flag = False
    # keep calling the bot until a full pass adds no new messages to the history;
    # as written, first_turn_flag starts as False, so GPT-4 is selected on every turn
    while True:
        if first_turn_flag:
            switch_to_gpt4(state, False)
            first_turn_flag = False
        else:
            switch_to_gpt4(state, True)
        logging.info("Start bot")
        history = bot(state, history)
        logging.info("End bot")
        print(state["bot_backend"].conversation)
        if last_history == copy.deepcopy(history):
            logging.info("No new response, end conversation")
            # keep only conversation items that carry content
            conversation = [item for item in state["bot_backend"].conversation if item["content"]]
            return conversation
        else:
            logging.info("New response, continue conversation")
            last_history = copy.deepcopy(history)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str)
    parser.add_argument('--output_path', type=str)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logging.info("Initialization")
    state = {"bot_backend": None}
    history = []
    initialization(state)
    switch_to_gpt4(state_dict=state, whether_switch=True)

    logging.info("Start")
    with open(args.input_path, "r") as f:
        instructions = [json.loads(line)["query"] for line in f.readlines()]
    all_history = []
    logging.info(f"{len(instructions)} remaining instructions for {args.input_path}")
    for user_input_index, user_input in enumerate(tqdm(instructions)):
        logging.info(f"Start conversation {user_input_index}")
        conversation = main(state, history, user_input)
        all_history.append(
            {
                "instruction": user_input,
                "conversation": conversation
            }
        )
        # rewrite the accumulated results after each conversation
        with open(args.output_path, "w") as f:
            json.dump(all_history, f, indent=4, ensure_ascii=False)
        # reset the backend before the next instruction
        state["bot_backend"].restart()