Spaces:
Runtime error
Runtime error
import gradio as gr | |
import json | |
import requests | |
import os | |
from model_inference import Inference | |
import time | |
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
question_selector_map = {}
model_list = ["llama2", "llama2-chat", "vicuna", "falcon", "falcon-instruct", "orca", "wizardlm"]

# Per-model endpoint settings. Each entry is read below through
# ["headers"]["Authorization"], and elsewhere in this file through ["API_URL"].
with open("src/inference_endpoint.json", "r", encoding="utf-8") as f:
    inference_endpoint = json.load(f)

# Appending None to a str would fail with an opaque TypeError, so report the
# actual problem instead.
if HF_TOKEN is None:
    raise RuntimeError(
        "The HF_TOKEN environment variable is not set; it is required to "
        "authorize requests to the inference endpoints."
    )

# Complete every model's Authorization header with the token, exactly once.
for model_name in model_list:
    inference_endpoint[model_name]["headers"]["Authorization"] += HF_TOKEN
def build_question_selector_map(questions):
    """Index question records by their dropdown preview label.

    The label for a record is ``"<question_id + 1>: "`` followed by the first
    128 characters of its ``"question"`` text and a trailing ``"..."``.

    Args:
        questions: Iterable of dicts providing ``"question_id"`` (a number)
            and ``"question"`` (a string).

    Returns:
        A dict mapping each preview label to its full record. When two records
        produce the same label, the later one wins.
    """
    return {
        f"{entry['question_id']+1}: " + entry["question"][:128] + "...": entry
        for entry in questions
    }
# Keys of the three agents inside a record's "agent_response" dict, in the
# order their outputs are shown in the UI.
_AGENT_KEYS = ("llama", "wizardlm", "orca")

# Number of values each display callback returns: 3 agents x 3 rounds plus
# the 2 intermediate summaries.
_DISPLAY_OUTPUT_COUNT = 11


def _debate_outputs(selector_map, question):
    """Return the 11 display values for ``question`` looked up in ``selector_map``.

    The values are ordered round by round: the three agent responses of
    round 0, summary 0, the three responses of round 1, summary 1, and the
    three responses of round 2.

    When ``question`` is not a key of ``selector_map`` (for example because
    nothing has been selected in the dropdown yet), eleven empty strings are
    returned so the UI shows blank boxes instead of raising ``KeyError``.
    """
    record = selector_map.get(question)
    if record is None:
        return ("",) * _DISPLAY_OUTPUT_COUNT

    responses = record["agent_response"]
    summaries = record["summarization"]
    values = []
    for round_index in range(3):
        values.extend(responses[agent][round_index] for agent in _AGENT_KEYS)
        if round_index < 2:  # there is no summary after the final round
            values.append(summaries[round_index])
    return tuple(values)


def math_display_question_answer(question, cot, request: gr.Request):
    """Return the stored Math debate outputs for the selected question.

    Args:
        question: Preview label chosen in the Math dropdown.
        cot: Truthy to read from the chain-of-thought results, falsy for the
            plain results.
        request: Gradio request object (unused).

    Returns:
        A tuple of 11 strings; all empty when ``question`` is unknown.
    """
    selector_map = math_cot_question_selector_map if cot else math_question_selector_map
    return _debate_outputs(selector_map, question)


def gsm_display_question_answer(question, cot, request: gr.Request):
    """Return the stored GSM8K debate outputs for the selected question.

    Args:
        question: Preview label chosen in the GSM8K dropdown.
        cot: Truthy to read from the chain-of-thought results, falsy for the
            plain results.
        request: Gradio request object (unused).

    Returns:
        A tuple of 11 strings; all empty when ``question`` is unknown.
    """
    selector_map = gsm_cot_question_selector_map if cot else gsm_question_selector_map
    return _debate_outputs(selector_map, question)


def mmlu_display_question_answer(question, cot, request: gr.Request):
    """Return the stored MMLU debate outputs for the selected question.

    Args:
        question: Preview label chosen in the MMLU dropdown.
        cot: Truthy to read from the chain-of-thought results, falsy for the
            plain results.
        request: Gradio request object (unused).

    Returns:
        A tuple of 11 strings; all empty when ``question`` is unknown.
    """
    selector_map = mmlu_cot_question_selector_map if cot else mmlu_question_selector_map
    return _debate_outputs(selector_map, question)
# Models that receive a warm-up request when the warm-up button is clicked.
warmup_test = ["llama2", "wizardlm", "orca"]


def warmup(model_list=warmup_test, model_inference_endpoints=inference_endpoint):
    """Send one small request to each model endpoint, then reveal the input UI.

    Args:
        model_list: Names of the models to warm up; each must be a key of
            ``model_inference_endpoints``.
        model_inference_endpoints: Mapping of model name to a dict with
            ``"API_URL"`` and ``"headers"`` entries.

    Returns:
        A dict of Gradio updates that shows ``options``, ``inputbox``,
        ``submit`` and ``welcome_message`` and hides ``warmup_button``.
    """
    for model in model_list:
        endpoint = model_inference_endpoints[model]
        # The Authorization header is used as-is: the module-level setup has
        # already appended HF_TOKEN to it. Appending the token again here
        # (as was done before) corrupted the shared header a little more on
        # every click.
        requests.post(
            endpoint["API_URL"],
            headers=endpoint["headers"],
            json={"inputs": "Hello. "},
        )
        # NOTE(review): the pause is assumed to sit inside the loop (one
        # second per model) — the original indentation was not recoverable.
        time.sleep(1)

    return {
        options: gr.update(visible=True),
        inputbox: gr.update(visible=True),
        submit: gr.update(visible=True),
        warmup_button: gr.update(visible=False),
        welcome_message: gr.update(visible=True)
    }
def inference(model_list, question, API_KEY, cot, hf_token=HF_TOKEN):
    """Run a debate between exactly three models and map the results to the UI.

    Args:
        model_list: List of the selected model names; lower-cased in place
            before being handed to ``Inference``.
        question: The user's question.
        API_KEY: OpenAI API key entered in the UI, forwarded to ``Inference``.
        cot: Chain-of-thought flag, forwarded to ``Inference``.
        hf_token: Hugging Face token, forwarded to ``Inference``.

    Returns:
        A dict keyed by output components: visibility updates for
        ``output_msg`` and ``output_col``, the three responses of each model,
        and the two summaries.

    Raises:
        gr.Error: If the number of selected models is not three.
    """
    if len(model_list) != 3:
        raise gr.Error("Please choose just '3' models! Neither more nor less!")

    # Slice assignment keeps the lower-casing in place on the caller's list.
    model_list[:] = [name.lower() for name in model_list]

    model_response = Inference(model_list, question, API_KEY, cot, hf_token)
    replies = model_response["agent_response"]
    summaries = model_response["summarization"]
    first, second, third = model_list[0], model_list[1], model_list[2]

    return {
        output_msg: gr.update(visible=True),
        output_col: gr.update(visible=True),
        model1_output1: replies[first][0],
        model2_output1: replies[second][0],
        model3_output1: replies[third][0],
        summarization_text1: summaries[0],
        model1_output2: replies[first][1],
        model2_output2: replies[second][1],
        model3_output2: replies[third][1],
        summarization_text2: summaries[1],
        model1_output3: replies[first][2],
        model2_output3: replies[second][2],
        model3_output3: replies[third][2]
    }
def load_responses():
    """Read the six stored result files and return their parsed JSON content.

    Returns:
        A 6-tuple in this order: Math, Math CoT, GSM8K, GSM8K CoT, MMLU,
        MMLU CoT.
    """
    result_paths = (
        "result/Math/math_result.json",
        "result/Math/math_result_cot.json",
        "result/GSM8K/gsm_result.json",
        "result/GSM8K/gsm_result_cot.json",
        "result/MMLU/mmlu_result.json",
        "result/MMLU/mmlu_result_cot.json",
    )
    loaded = []
    for path in result_paths:
        with open(path, "r") as handle:
            loaded.append(json.load(handle))
    return tuple(loaded)
def load_questions(math, gsm, mmlu, limit=100):
    """Build the dropdown preview labels for the three benchmark result lists.

    Each label is ``"<position>: "`` (1-based position in the list) followed
    by the first 128 characters of the record's ``"question"`` text and a
    trailing ``"..."``.

    NOTE(review): these labels are later used as keys into the maps built by
    ``build_question_selector_map``, which numbers by ``question_id + 1``;
    this presumably relies on ``question_id`` matching the list position —
    confirm against the result files.

    Args:
        math: List of Math result records (dicts with a ``"question"`` key).
        gsm: List of GSM8K result records.
        mmlu: List of MMLU result records.
        limit: Maximum number of labels per benchmark (default 100). A list
            with fewer records yields fewer labels instead of raising
            ``IndexError``.

    Returns:
        A 3-tuple ``(math_questions, gsm_questions, mmlu_questions)`` of
        label lists.
    """
    def previews(records):
        # Slicing caps at `limit` and tolerates shorter lists.
        return [
            f"{position}: " + record["question"][:128] + "..."
            for position, record in enumerate(records[:limit], start=1)
        ]

    return previews(math), previews(gsm), previews(mmlu)
# Load every stored result set once at import time.
(
    math_result,
    math_cot_result,
    gsm_result,
    gsm_cot_result,
    mmlu_result,
    mmlu_cot_result,
) = load_responses()

# Dropdown labels come from the plain (non-CoT) results only.
math_questions, gsm_questions, mmlu_questions = load_questions(math_result, gsm_result, mmlu_result)

# One label -> record lookup table per result set, built in the same order as
# the result sets above.
(
    math_question_selector_map,
    math_cot_question_selector_map,
    gsm_question_selector_map,
    gsm_cot_question_selector_map,
    mmlu_question_selector_map,
    mmlu_cot_question_selector_map,
) = map(
    build_question_selector_map,
    (
        math_result,
        math_cot_result,
        gsm_result,
        gsm_cot_result,
        mmlu_result,
        mmlu_cot_result,
    ),
)
# Page title, rendered with gr.HTML at the top of the app.
TITLE = """<h1 align="center">LLM Agora 🗣️🏦</h1>"""
# Introductory Markdown shown under the title: project summary and usage steps.
INTRODUCTION_TEXT = """
The **LLM Agora** 🗣️🏦 aims to improve the quality of open-source LMs' responses through debate & revision introduced in [Improving Factuality and Reasoning in Language Models through Multiagent Debate](https://arxiv.org/abs/2305.14325).
Thank you to the authors of this paper for suggesting a great idea!
Do you know that? 🤔 **LLMs can also improve their responses by debating with other LLMs**! 😮 We applied this concept to several open-source LMs to verify that the open-source model, not the proprietary one, can sufficiently improve the response through discussion. 🤗
For more details, please refer to the [GitHub Repository](https://github.com/gauss5930/LLM-Agora).
You can also check the results in this Space!
You can use LLM Agora with your own questions if the response of open-source LM is not satisfactory and you want to improve the quality!
The Math, GSM8K, and MMLU Tabs show the results of the experiment(Llama2, WizardLM2, Orca2), and for inference, please use the 'Inference' tab.
Here's how to use LLM Agora!
1. Before starting, click the 'Warm-up LLM Agora 🔥' button and wait until 'LLM Agora Ready!!' appears. (Suggest to go grab a coffee☕ since it takes 5 minutes!)
2. Choose just 3 models! Neither more nor less!
3. Check the CoT box if you want to utilize the Chain-of-Thought while inferencing.
4. Please fill in your OpenAI API KEY, it will be used to use ChatGPT to summarize the responses.
5. Type your question in the Question box and click the 'Submit' button! If you do so, LLM Agora will show you improved answers! 🤗 (It will take roughly a minute! Please wait for an answer!)
For more detailed information, please check '※ Specific information about LLM Agora' at the bottom of the page.
"""
# Banner HTML for the Inference tab; created hidden and made visible by warmup().
WELCOME_TEXT = """<h1 align="center">🤗🔥 Welcome to LLM Agora 🔥🤗</h1>"""
# Heading HTML placed in the output_msg row, which inference() makes visible.
RESPONSE_TEXT = """<h1 align="center">🤗 Here are the responses to each model!! 🤗</h1>"""
# Markdown for the collapsible "Specific information" accordion: tasks, model
# sizes, debate settings, repository link and citation.
SPECIFIC_INFORMATION = """
This is the specific information about LLM Agora!
**Tasks**
- Math: The problem of arithmetic operations on six randomly selected numbers. The format is '{}+{}*{}+{}-{}*{}=?'
- GSM8K: GSM8K is a dataset of 8.5K high quality linguistically diverse grade school math word problems created by human problem writers.
- MMLU: MMLU (Massive Multitask Language Understanding) is a new benchmark designed to measure knowledge acquired during pretraining by evaluating models exclusively in zero-shot and few-shot settings.
**Model size**
Besides Falcon, all other models are based on Llama2.
|Model name|Model size|
|---|---|
|Llama2|13B|
|Llama2-Chat|13B|
|Vicuna|13B|
|Falcon|7B|
|Falcon-Instruct|7B|
|WizardLM|13B|
|Orca|13B|
**Agent numbers & Debate rounds**
- We limit the number of agents and debate rounds because of the limitation of resources. As a result, we decided to use 3 agents and 2 rounds of debate!
**GitHub Repository**
- If you want to see more specific information, please check the [GitHub Repository](https://github.com/gauss5930/LLM-Agora) of LLM Agora!
**Citation**
```
@article{du2023improving,
title={Improving Factuality and Reasoning in Language Models through Multiagent Debate},
author={Du, Yilun and Li, Shuang and Torralba, Antonio and Tenenbaum, Joshua B and Mordatch, Igor},
journal={arXiv preprint arXiv:2305.14325},
year={2023}
}
```
"""
# Build the Gradio UI: one interactive "Inference" tab plus three read-only
# tabs (Math, GSM8K, MMLU) that browse the stored debate results.
# NOTE(review): the container nesting below is reconstructed — confirm the
# intended layout (in particular which components sit inside which Row/Column).
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)

    with gr.Column():
        with gr.Tab("Inference"):
            # Everything except the warm-up button starts hidden; warmup()
            # reveals the inputs and hides the button.
            warmup_button = gr.Button("Warm-up LLM Agora 🔥", visible=True)
            welcome_message = gr.HTML(WELCOME_TEXT, visible=False)

            with gr.Row(visible=False) as options:
                with gr.Column():
                    # From here on `model_list` names this component, not the
                    # list of model names defined at the top of the module.
                    model_list = gr.CheckboxGroup(["Llama2", "Llama2-Chat", "Vicuna", "Falcon", "Falcon-Instruct", "WizardLM", "Orca"], label="Model Selection", info="Choose 3 LMs to participate in LLM Agora.", type="value")
                    cot = gr.Checkbox(label="CoT", info="Do you want to use CoT for inference?")
                # This column was previously also bound to `inputbox`, but the
                # name was immediately rebound to the question column below,
                # so the first binding was dead and has been dropped.
                with gr.Column():
                    API_KEY = gr.Textbox(label="OpenAI API Key", value="", info="Please fill in your OpenAI API token.", placeholder="sk..", type="password")

            with gr.Column(visible=False) as inputbox:
                question = gr.Textbox(label="Question", value="", info="Please type your question!", placeholder="")
            submit = gr.Button("Submit", visible=False)

            with gr.Row(visible=False) as output_msg:
                gr.HTML(RESPONSE_TEXT)

            with gr.Column(visible=False) as output_col:
                with gr.Row(elem_id="model1_response"):
                    model1_output1 = gr.Textbox(label="1️⃣ model's initial response")
                    model2_output1 = gr.Textbox(label="2️⃣ model's initial response")
                    model3_output1 = gr.Textbox(label="3️⃣ model's initial response")
                    # Fixed keyword typo: this was `lebel=`, so the box had no
                    # label (or the call failed on Gradio versions that reject
                    # unknown keyword arguments).
                    summarization_text1 = gr.Textbox(label="Summarization 1")
                with gr.Row(elem_id="model2_response"):
                    model1_output2 = gr.Textbox(label="1️⃣ model's revised response")
                    model2_output2 = gr.Textbox(label="2️⃣ model's revised response")
                    model3_output2 = gr.Textbox(label="3️⃣ model's revised response")
                    summarization_text2 = gr.Textbox(label="Summarization 2")
                with gr.Row(elem_id="model3_response"):
                    model1_output3 = gr.Textbox(label="1️⃣ model's final response")
                    model2_output3 = gr.Textbox(label="2️⃣ model's final response")
                    model3_output3 = gr.Textbox(label="3️⃣ model's final response")

        with gr.Tab("Math"):
            math_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            # NOTE(review): callable(<list>) is always False, so the dropdown
            # starts with value False and `every` presumably has no effect —
            # confirm the intended default selection.
            math_question_list = gr.Dropdown(math_questions, value=callable(math_questions), label="Math Question", every=0.1)
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    math_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    math_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s 1️⃣st response")
                    math_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                    math_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    math_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    math_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s 2️⃣nd response")
                    math_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                    math_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    math_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    math_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s 3️⃣rd response")
                    math_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of Math </h1>""")
            gr.HTML("""<p align="center"><img src='https://github.com/gauss5930/LLM-Agora/assets/80087878/4fc22896-1306-4a93-bd54-a7a2ff184c98'></p>""")

            # Both events refresh the same eleven boxes, in the order the
            # display callback returns its values.
            math_inputs = [math_question_list, math_cot]
            math_outputs = [math_model1_output1, math_model2_output1, math_model3_output1, math_summarization_text1, math_model1_output2, math_model2_output2, math_model3_output2, math_summarization_text2, math_model1_output3, math_model2_output3, math_model3_output3]
            math_cot.select(math_display_question_answer, math_inputs, math_outputs)
            math_question_list.change(math_display_question_answer, math_inputs, math_outputs)

        with gr.Tab("GSM8K"):
            gsm_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            # NOTE(review): same always-False `value` as the Math dropdown.
            gsm_question_list = gr.Dropdown(gsm_questions, value=callable(gsm_questions), label="GSM8K Question", every=0.1)
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    gsm_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    gsm_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s 1️⃣st response")
                    gsm_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                    gsm_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    gsm_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    gsm_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s 2️⃣nd response")
                    gsm_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                    gsm_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    gsm_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    gsm_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s 3️⃣rd response")
                    gsm_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of GSM8K </h1>""")
            gr.HTML("""<p align="center"><img src="https://github.com/gauss5930/LLM-Agora/assets/80087878/64f05ea4-5bec-41e4-83d7-d8855e753290"></p>""")

            gsm_inputs = [gsm_question_list, gsm_cot]
            gsm_outputs = [gsm_model1_output1, gsm_model2_output1, gsm_model3_output1, gsm_summarization_text1, gsm_model1_output2, gsm_model2_output2, gsm_model3_output2, gsm_summarization_text2, gsm_model1_output3, gsm_model2_output3, gsm_model3_output3]
            gsm_cot.select(gsm_display_question_answer, gsm_inputs, gsm_outputs)
            gsm_question_list.change(gsm_display_question_answer, gsm_inputs, gsm_outputs)

        with gr.Tab("MMLU"):
            mmlu_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            # NOTE(review): same always-False `value` as the Math dropdown.
            mmlu_question_list = gr.Dropdown(mmlu_questions, value=callable(mmlu_questions), label="MMLU Question", every=0.1)
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    mmlu_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    mmlu_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s 1️⃣st response")
                    mmlu_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                    mmlu_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    mmlu_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    mmlu_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s 2️⃣nd response")
                    mmlu_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                    mmlu_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    mmlu_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    mmlu_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s 3️⃣rd response")
                    mmlu_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of MMLU </h1>""")
            gr.HTML("""<p align="center"><img src="https://github.com/composable-models/llm_multiagent_debate/assets/80087878/963571aa-228b-4d73-9082-5f528552383e"></p>""")

            mmlu_inputs = [mmlu_question_list, mmlu_cot]
            mmlu_outputs = [mmlu_model1_output1, mmlu_model2_output1, mmlu_model3_output1, mmlu_summarization_text1, mmlu_model1_output2, mmlu_model2_output2, mmlu_model3_output2, mmlu_summarization_text2, mmlu_model1_output3, mmlu_model2_output3, mmlu_model3_output3]
            mmlu_cot.select(mmlu_display_question_answer, mmlu_inputs, mmlu_outputs)
            mmlu_question_list.change(mmlu_display_question_answer, mmlu_inputs, mmlu_outputs)

    with gr.Accordion("※ Specific information about LLM Agora", open=False):
        gr.Markdown(SPECIFIC_INFORMATION)

    # Inference-tab wiring. The output lists name the same components that
    # warmup() and inference() use as keys of their returned dicts.
    warmup_button.click(warmup, [], [options, inputbox, submit, warmup_button, welcome_message])
    submit.click(inference, [model_list, question, API_KEY, cot], [output_msg, output_col, model1_output1, model2_output1, model3_output1, summarization_text1, model1_output2, model2_output2, model3_output2, summarization_text2, model1_output3, model2_output3, model3_output3])

demo.launch()