Spaces:
Sleeping
Sleeping
import gradio as gr | |
import time | |
import os | |
from spinoza_project.source.backend.llm_utils import ( | |
get_llm_api, | |
get_vectorstore_api, | |
) | |
from spinoza_project.source.frontend.utils import ( | |
init_env, | |
parse_output_llm_with_sources, | |
) | |
from spinoza_project.source.frontend.gradio_utils import ( | |
get_sources, | |
set_prompts, | |
get_config, | |
get_prompts, | |
get_assets, | |
get_theme, | |
get_init_prompt, | |
get_synthesis_prompt, | |
get_qdrants, | |
get_qdrants_public, | |
start_agents, | |
end_agents, | |
next_call, | |
zip_longest_fill, | |
reformulate, | |
answer, | |
) | |
from assets.utils_javascript import ( | |
accordion_trigger, | |
accordion_trigger_end, | |
accordion_trigger_spinoza, | |
accordion_trigger_spinoza_end, | |
update_footer, | |
) | |
init_env()
config = get_config()

# --- Prompts -----------------------------------------------------------------
print("Loading Prompts")
prompts = get_prompts(config)
chat_qa_prompts, chat_reformulation_prompts = set_prompts(prompts, config)
synthesis_prompt_template = get_synthesis_prompt(config)

# --- LLM ---------------------------------------------------------------------
# When EKI_OPENAI_LLM_DEPLOYMENT_NAME is set (non-empty), the Groq model name is
# blanked and the press/AFP vector stores are loaded through get_vectorstore_api;
# otherwise the public Qdrant collections are merged into the private ones.
openai_deployment = os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME")

print("Building LLM")
groq_model_name = "" if openai_deployment else config["groq_model_name"]
llm = get_llm_api(groq_model_name)

# --- Vector stores -----------------------------------------------------------
print("Loading Databases")
qdrants = get_qdrants(config)
if not openai_deployment:
    qdrants_public = get_qdrants_public(config)
    qdrants = {**qdrants, **qdrants_public}
    bdd_presse = bdd_afp = None
else:
    bdd_presse = get_vectorstore_api("presse")
    bdd_afp = get_vectorstore_api("afp")

# --- UI assets ---------------------------------------------------------------
css, source_information = get_assets()
theme = get_theme()
init_prompt = get_init_prompt()
def reformulate_questions(
    question,
    llm=llm,
    chat_reformulation_prompts=chat_reformulation_prompts,
    config=config,
):
    """Stream reformulations of ``question`` for every configured tab.

    One ``reformulate`` stream is opened per key of ``config["tabs"]`` (in config
    order). The streams are merged with ``zip_longest_fill``, and each merged
    element — presumably one entry per tab, TODO confirm against
    ``zip_longest_fill`` — is yielded after a 20 ms pause between updates.
    """
    per_tab_streams = [
        reformulate(llm, chat_reformulation_prompts, question, tab_name, config=config)
        for tab_name in config["tabs"]
    ]
    for merged in zip_longest_fill(*per_tab_streams):
        time.sleep(0.02)
        yield merged
def retrieve_sources(
    *questions,
    qdrants=qdrants,
    bdd_presse=bdd_presse,
    bdd_afp=bdd_afp,
    config=config,
):
    """Fetch sources for the per-tab questions.

    ``questions`` receives the reformulated questions positionally, in the order
    the event wiring lists them.

    Returns a flat tuple whose first item is the formatted sources returned by
    ``get_sources``. It is followed by every element of the text-sources
    sequence that ``get_sources`` returns alongside it.
    """
    html_sources, per_tab_texts = get_sources(
        questions, qdrants, bdd_presse, bdd_afp, config
    )
    return (html_sources,) + tuple(per_tab_texts)
def answer_questions(
    *questions_sources, llm=llm, chat_qa_prompts=chat_qa_prompts, config=config
):
    """Stream every agent's answer, one single-turn chat history per tab.

    ``questions_sources`` is a flat sequence: its first half holds the questions
    and its second half the matching source texts. Both halves are paired with
    ``config["tabs"]`` by position.

    For each element produced by ``zip_longest_fill`` over the per-tab ``answer``
    streams, yields a list of ``[(question, parsed_answer)]`` histories, one per
    question, after a 20 ms pause.
    """
    half = len(questions_sources) // 2
    questions = list(questions_sources[:half])
    sources = list(questions_sources[half:])

    answer_streams = [
        answer(llm, chat_qa_prompts, q, src, tab_name, config)
        for q, src, tab_name in zip(questions, sources, config["tabs"])
    ]
    for partial_answers in zip_longest_fill(*answer_streams):
        time.sleep(0.02)
        histories = []
        for q, partial in zip(questions, partial_answers):
            histories.append([(q, parse_output_llm_with_sources(partial))])
        yield histories
def get_synthesis(
    question,
    *answers,
    llm=llm,
    synthesis_prompt_template=synthesis_prompt_template,
    config=config,
):
    """Stream a synthesis of the per-agent answers into a single chat history.

    Args:
        question: Question shown in the resulting history. ``<p>`` and
            ``</p>\\n`` markup is stripped before it is sent to the prompt.
        *answers: One value per key of ``config["tabs"]``, in the same order
            (the event wiring passes each agent's chatbot value).
        llm: Object exposing ``stream(prompt_template, variables)``.
        synthesis_prompt_template: Template handed to ``llm.stream``.
        config: App configuration; only ``config["tabs"]`` is read.

    Yields:
        ``[(question, rendered_text)]`` histories. If no agent answer passes the
        length threshold, a single history carrying a French fallback message is
        yielded instead and no LLM call is made.
    """
    # Answers whose string form is shorter than this are left out of the
    # synthesis (presumably empty or failed agent outputs — heuristic value).
    min_answer_chars = 100

    # Named so it does not shadow the imported `answer` function.
    usable_answers = [
        f"{tab}\n{tab_answer}".replace("<p>", "").replace("</p>\n", "")
        for tab, tab_answer in zip(config["tabs"], answers)
        if len(str(tab_answer)) >= min_answer_chars
    ]

    if not usable_answers:
        # This function is a generator: `return <value>` would only set the
        # StopIteration value, which is discarded by whatever iterates it, so the
        # message would never be displayed. Yield it as a chat history instead.
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
        return

    prompt_variables = {
        "question": question.replace("<p>", "").replace("</p>\n", ""),
        "answers": "\n\n".join(usable_answers),
    }
    for chunk in llm.stream(synthesis_prompt_template, prompt_variables):
        time.sleep(0.01)
        yield [(question, parse_output_llm_with_sources(chunk))]
# Gradio UI. A "Q&A" tab holds one collapsible chatbot per configured agent plus a
# final "Spinoza" synthesis chatbot and a sources panel. Submitting the textbox
# runs: start_agents -> reformulate_questions -> retrieve_sources ->
# answer_questions -> get_synthesis -> end_agents. JS hooks between the steps
# toggle the accordions.
with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    # Tab names in config order. Every input/output list wired below is derived
    # from this single list so positions line up across all pipeline steps.
    tab_names = list(config["tabs"])

    chatbots = {}
    # NOTE(review): `question`, `docs_textbox`, `component_sources` and
    # `tab_states` are not referenced elsewhere in this module; kept in case
    # other code imports them.
    question = gr.State("")
    docs_textbox = gr.State([""])
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}

    with gr.Tab("Q&A", elem_id="main-component"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2, elem_id="center-panel"):
                with gr.Group(elem_id="chatbot-group"):
                    # One accordion per agent, then the Spinoza synthesis one last.
                    for tab in tab_names + ["Spinoza"]:
                        if tab == "Spinoza":
                            agent_name = "Spinoza"
                            elem_id = f"accordion-{tab}"
                            elem_classes = "accordion accordion-agent spinoza-agent"
                        else:
                            agent_name = f"Agent {config['source_mapping'][tab]}"
                            elem_id = f"accordion-{config['source_mapping'][tab]}"
                            elem_classes = "accordion accordion-agent"
                        is_spinoza = agent_name == "Spinoza"
                        # Only the Spinoza accordion starts expanded, pre-filled
                        # with the initial prompt and the Spinoza avatar.
                        with gr.Accordion(
                            agent_name,
                            open=is_spinoza,
                            elem_id=elem_id,
                            elem_classes=elem_classes,
                        ):
                            chatbots[tab] = gr.Chatbot(
                                value=[(None, init_prompt)] if is_spinoza else None,
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    "./assets/logos/spinoza.png" if is_spinoza else None,
                                ),
                            )
                with gr.Row(elem_id="input-message"):
                    ask = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )
            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                    sources_textbox = gr.HTML(
                        show_label=False, elem_id="sources-textbox"
                    )
    with gr.Tab("Source information", elem_id="source-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(source_information)
    with gr.Tab("Contact", elem_id="contact-component"):
        with gr.Row():
            with gr.Column(scale=1):
                # NOTE(review): "[email protected]" looks like a Cloudflare
                # email-obfuscation placeholder left by scraping; restore the real
                # contact address.
                gr.Markdown("For any issue contact **[email protected]**.")

    agent_question_states = [agent_questions[tab] for tab in tab_names]
    agent_text_sources = [text_sources[tab] for tab in tab_names]
    agent_chatbots = [chatbots[tab] for tab in tab_names]

    # NOTE(review): the synthesis step is fed the reformulated question of the
    # *second* configured tab — confirm this is intended. With a single tab,
    # indexing [1] alone would raise IndexError, so fall back to the first tab.
    synthesis_question_tab = tab_names[1] if len(tab_names) > 1 else tab_names[0]

    ask.submit(
        start_agents, inputs=[], outputs=[chatbots["Spinoza"]], js=accordion_trigger()
    ).then(
        fn=reformulate_questions,
        inputs=[ask],
        outputs=agent_question_states,
    ).then(
        fn=retrieve_sources,
        inputs=agent_question_states,
        outputs=[sources_textbox] + agent_text_sources,
    ).then(
        # answer_questions splits its positional inputs in half: questions first,
        # then the matching source texts.
        fn=answer_questions,
        inputs=agent_question_states + agent_text_sources,
        outputs=agent_chatbots,
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        fn=get_synthesis,
        inputs=[agent_questions[synthesis_question_tab]] + agent_chatbots,
        outputs=[chatbots["Spinoza"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[], outputs=[]
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve the app in debug mode.
    queued_app = demo.queue()
    queued_app.launch(debug=True)