V1.7 - Dynamic conversation & France local QA (#24)
Browse files- Small clean POC Local (0327516552f1b9af184361e8594b7377a4e0bcb2)
- Clean configs (ecab0c839ce7fc8100121a0174409451dbb5dd50)
- fix : fix gradio component (1e74d1402fb3c834dc9c0163aeb5028dbd6764bd)
- take the last question as history to understand the question (676d17bd25bb4cbd8cf9396ee353aac7a7028af0)
- Add follow up questions (aaa4bbedbbdee9951a33c8d7bbfd7583fa441bc5)
- Fix : Dynamic follow up examples (a984d58b3e8cf8458eec39ed29d0267590f8e75a)
- Merged in feature/dynamic_conversation (pull request #1) (f684aff68309e2f14164cef7fd02a6da804dcef9)
- Update style.css (e8258aea3313f558bb0afa650464d740fa4b2da4)
- Merge branch 'dev' of https://bitbucket.org/ekimetrics/climate_qa into dev (03a8baf608fd8a639c96a661725c181660122aa1)
- feat: implemented talk to drias v1 (4df74e4b75096daf90548ae7eefe6c32df002bcd)
- feat: added 2 new talk to data plots (170018666950c1b7433f6baba17ad36042282a2e)
- feat: added drias model choice and changed TTD UI (1eae86b617a8336ee3e4b7b11e39dc0769731812)
- fix: fixed bugs and errors (e6e652c91a221b56eb508cc98b13d27b5785e803)
- Merge branch 'dev' into feat/talk_to_data_graph (561caf15f2fd2954b43a49064ac11a085d343151)
- update css for follow up examples (54e2358ad0444473be0ac8d6c436ef73cfcceb3b)
- ensure correct output language (723be3263bfd1a048f671b7ef11219de2c10b736)
- feat: model filtering and UI upgrade for TTD (6155a631718e808fb6622083848a5fc3bf231937)
- feat: added list of tables and reduced execution time of TTD (0bdf2f6f90901ed5cd65fc99d4de1da8eab1788a)
- add documentation (161aa8c18cb4be2d0c5df0dc64f3c202e55c582d)
- Move hardcoded configuration in a config file (45a93206d1c3507a622ba4f40df17a339d5a8bad)
- make ask drias asynchronous (5fe15430ebbb47868615f715c00d6274fad0e7a7)
- Add drias indicators (989d387afc62cb38ae239e5cc97664f702e3fa54)
- Add examples (d4fa76b2cb0a8ed1ae0f482254acfbd0242fdc73)
- UI improvment (1d710d6428b17b25b42e14fa494c2823b7e74e81)
- split front element and event listening (ca2c42949b3de84aeaaabf5879f107cd829523a0)
- Merged in feat/talk_to_data_graph (pull request #3) (f2baf8741a4c6b9712f2cd9cec2f86ddb4ca4274)
- Merge branch 'dev' of https://bitbucket.org/ekimetrics/climate_qa into dev (5c04812e26e5faa6ee2a85867f0155c1e56b96a0)
- add logs of drias interactions (bc43b45669dfd4740dc90d6daa1128278857d5d3)
- Merged in dev (pull request #4) (6af9e984b3a852db3414a96307040d2bbc031871)
- rename tabs for prod (b09184847c18b284fd0d08532a1572722f9159d5)
Co-authored-by: Armand Demasson <[email protected]>
- app.py +416 -219
- climateqa/chat.py +32 -3
- climateqa/engine/chains/answer_rag.py +4 -2
- climateqa/engine/chains/follow_up.py +33 -0
- climateqa/engine/chains/intent_categorization.py +11 -4
- climateqa/engine/chains/retrieve_documents.py +2 -8
- climateqa/engine/chains/standalone_question.py +42 -0
- climateqa/engine/graph.py +36 -19
- climateqa/engine/talk_to_data/config.py +99 -0
- climateqa/engine/talk_to_data/main.py +100 -32
- climateqa/engine/talk_to_data/plot.py +402 -0
- climateqa/engine/talk_to_data/sql_query.py +113 -0
- climateqa/engine/talk_to_data/utils.py +232 -43
- climateqa/engine/talk_to_data/workflow.py +287 -0
- front/tabs/__init__.py +4 -1
- front/tabs/chat_interface.py +15 -12
- front/tabs/main_tab.py +59 -27
- front/tabs/tab_config.py +19 -29
- front/tabs/tab_drias.py +362 -0
- style.css +102 -10
@@ -9,13 +9,13 @@ from climateqa.engine.embeddings import get_embeddings_function
|
|
9 |
from climateqa.engine.llm import get_llm
|
10 |
from climateqa.engine.vectorstore import get_pinecone_vectorstore
|
11 |
from climateqa.engine.reranker import get_reranker
|
12 |
-
from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
|
13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
15 |
-
from climateqa.engine.talk_to_data.main import ask_vanna
|
16 |
-
from climateqa.engine.talk_to_data.myVanna import MyVanna
|
17 |
|
18 |
-
from front.tabs import
|
|
|
|
|
19 |
from front.utils import process_figures
|
20 |
from gradio_modal import Modal
|
21 |
|
@@ -24,14 +24,14 @@ from utils import create_user_id
|
|
24 |
import logging
|
25 |
|
26 |
logging.basicConfig(level=logging.WARNING)
|
27 |
-
os.environ[
|
28 |
logging.getLogger().setLevel(logging.WARNING)
|
29 |
|
30 |
|
31 |
-
|
32 |
# Load environment variables in local mode
|
33 |
try:
|
34 |
from dotenv import load_dotenv
|
|
|
35 |
load_dotenv()
|
36 |
except Exception as e:
|
37 |
pass
|
@@ -62,39 +62,94 @@ share_client = service.get_share_client(file_share_name)
|
|
62 |
user_id = create_user_id()
|
63 |
|
64 |
|
65 |
-
|
66 |
# Create vectorstore and retriever
|
67 |
embeddings_function = get_embeddings_function()
|
68 |
-
vectorstore = get_pinecone_vectorstore(
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
-
llm = get_llm(provider="openai",max_tokens
|
73 |
if os.environ["GRADIO_ENV"] == "local":
|
74 |
reranker = get_reranker("nano")
|
75 |
-
else
|
76 |
reranker = get_reranker("large")
|
77 |
|
78 |
-
agent = make_graph_agent(
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
print("chat cqa - message received")
|
92 |
-
async for event in chat_stream(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
yield event
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
print("chat poc - message received")
|
97 |
-
async for event in chat_stream(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
yield event
|
99 |
|
100 |
|
@@ -102,14 +157,17 @@ async def chat_poc(query, history, audience, sources, reports, relevant_content_
|
|
102 |
# Gradio
|
103 |
# --------------------------------------------------------------------
|
104 |
|
|
|
105 |
# Function to update modal visibility
|
106 |
def update_config_modal_visibility(config_open):
|
107 |
print(config_open)
|
108 |
new_config_visibility_status = not config_open
|
109 |
return Modal(visible=new_config_visibility_status), new_config_visibility_status
|
110 |
-
|
111 |
|
112 |
-
|
|
|
|
|
|
|
113 |
sources_number = sources_textbox.count("<h2>")
|
114 |
figures_number = figures_cards.count("<h2>")
|
115 |
graphs_number = current_graphs.count("<iframe")
|
@@ -118,229 +176,368 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
|
|
118 |
figures_notif_label = f"Figures ({figures_number})"
|
119 |
graphs_notif_label = f"Graphs ({graphs_number})"
|
120 |
papers_notif_label = f"Papers ({papers_number})"
|
121 |
-
recommended_content_notif_label =
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
with gr.Tab(tab_name):
|
145 |
-
with gr.Row(elem_id="chatbot-row"):
|
146 |
-
# Left column - Chat interface
|
147 |
-
with gr.Column(scale=2):
|
148 |
-
chatbot, textbox, config_button = create_chat_interface(tab_name)
|
149 |
-
|
150 |
-
# Right column - Content panels
|
151 |
-
with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
|
152 |
-
with gr.Tabs(elem_id="right_panel_tab") as tabs:
|
153 |
-
# Examples tab
|
154 |
-
with gr.TabItem("Examples", elem_id="tab-examples", id=0):
|
155 |
-
examples_hidden = create_examples_tab(tab_name)
|
156 |
-
|
157 |
-
# Sources tab
|
158 |
-
with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
|
159 |
-
sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
|
160 |
-
|
161 |
-
|
162 |
-
# Recommended content tab
|
163 |
-
with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
|
164 |
-
with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
|
165 |
-
# Figures subtab
|
166 |
-
with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
|
167 |
-
sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal = create_figures_tab()
|
168 |
-
|
169 |
-
# Papers subtab
|
170 |
-
with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
|
171 |
-
papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
|
172 |
-
|
173 |
-
# Graphs subtab
|
174 |
-
with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
|
175 |
-
graphs_container = gr.HTML(
|
176 |
-
"<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
|
177 |
-
elem_id="graphs-container"
|
178 |
-
)
|
179 |
-
|
180 |
-
|
181 |
-
return {
|
182 |
-
"chatbot": chatbot,
|
183 |
-
"textbox": textbox,
|
184 |
-
"tabs": tabs,
|
185 |
-
"sources_raw": sources_raw,
|
186 |
-
"new_figures": new_figures,
|
187 |
-
"current_graphs": current_graphs,
|
188 |
-
"examples_hidden": examples_hidden,
|
189 |
-
"sources_textbox": sources_textbox,
|
190 |
-
"figures_cards": figures_cards,
|
191 |
-
"gallery_component": gallery_component,
|
192 |
-
"config_button": config_button,
|
193 |
-
"papers_direct_search" : papers_direct_search,
|
194 |
-
"papers_html": papers_html,
|
195 |
-
"citations_network": citations_network,
|
196 |
-
"papers_summary": papers_summary,
|
197 |
-
"tab_recommended_content": tab_recommended_content,
|
198 |
-
"tab_sources": tab_sources,
|
199 |
-
"tab_figures": tab_figures,
|
200 |
-
"tab_graphs": tab_graphs,
|
201 |
-
"tab_papers": tab_papers,
|
202 |
-
"graph_container": graphs_container,
|
203 |
-
# "vanna_sql_query": vanna_sql_query,
|
204 |
-
# "vanna_table" : vanna_table,
|
205 |
-
# "vanna_display": vanna_display
|
206 |
-
}
|
207 |
-
|
208 |
-
def config_event_handling(main_tabs_components : list[dict], config_componenets : dict):
|
209 |
-
config_open = config_componenets["config_open"]
|
210 |
-
config_modal = config_componenets["config_modal"]
|
211 |
-
close_config_modal = config_componenets["close_config_modal_button"]
|
212 |
-
|
213 |
-
for button in [close_config_modal] + [main_tab_component["config_button"] for main_tab_component in main_tabs_components]:
|
214 |
button.click(
|
215 |
fn=update_config_modal_visibility,
|
216 |
inputs=[config_open],
|
217 |
-
outputs=[config_modal, config_open]
|
218 |
-
)
|
219 |
-
|
|
|
220 |
def event_handling(
|
221 |
-
main_tab_components,
|
222 |
-
config_components,
|
223 |
-
tab_name="ClimateQ&A"
|
224 |
):
|
225 |
-
chatbot = main_tab_components
|
226 |
-
textbox = main_tab_components
|
227 |
-
tabs = main_tab_components
|
228 |
-
sources_raw = main_tab_components
|
229 |
-
new_figures = main_tab_components
|
230 |
-
current_graphs = main_tab_components
|
231 |
-
examples_hidden = main_tab_components
|
232 |
-
sources_textbox = main_tab_components
|
233 |
-
figures_cards = main_tab_components
|
234 |
-
gallery_component = main_tab_components
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
dropdown_audience = config_components["dropdown_audience"]
|
258 |
-
after = config_components["after"]
|
259 |
-
output_query = config_components["output_query"]
|
260 |
-
output_language = config_components["output_language"]
|
261 |
-
# close_config_modal = config_components["close_config_modal_button"]
|
262 |
-
|
263 |
new_sources_hmtl = gr.State([])
|
264 |
ttd_data = gr.State([])
|
265 |
|
266 |
-
|
267 |
-
# for button in [config_button, close_config_modal]:
|
268 |
-
# button.click(
|
269 |
-
# fn=update_config_modal_visibility,
|
270 |
-
# inputs=[config_open],
|
271 |
-
# outputs=[config_modal, config_open]
|
272 |
-
# )
|
273 |
-
|
274 |
if tab_name == "ClimateQ&A":
|
275 |
print("chat cqa - message sent")
|
276 |
|
277 |
# Event for textbox
|
278 |
-
(
|
279 |
-
.submit(
|
280 |
-
|
281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
)
|
283 |
# Event for examples_hidden
|
284 |
-
(
|
285 |
-
.change(
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
)
|
289 |
-
|
290 |
-
elif tab_name == "
|
291 |
print("chat poc - message sent")
|
292 |
# Event for textbox
|
293 |
-
(
|
294 |
-
.submit(
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
)
|
298 |
# Event for examples_hidden
|
299 |
-
(
|
300 |
-
.change(
|
301 |
-
|
302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
)
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
# Update sources numbers
|
311 |
for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
|
312 |
-
component.change(
|
313 |
-
|
|
|
|
|
|
|
|
|
314 |
# Search for papers
|
315 |
for component in [textbox, examples_hidden, papers_direct_search]:
|
316 |
-
component.submit(
|
317 |
-
|
|
|
|
|
|
|
318 |
|
319 |
-
# if tab_name == "
|
320 |
# # Drias search
|
321 |
# textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
|
322 |
|
|
|
323 |
def main_ui():
|
324 |
# config_open = gr.State(True)
|
325 |
-
with gr.Blocks(
|
326 |
-
|
327 |
-
|
|
|
|
|
|
|
|
|
|
|
328 |
with gr.Tabs():
|
329 |
-
cqa_components = cqa_tab(tab_name
|
330 |
-
local_cqa_components = cqa_tab(tab_name
|
331 |
-
create_drias_tab()
|
332 |
-
|
333 |
create_about_tab()
|
334 |
-
|
335 |
-
event_handling(cqa_components, config_components, tab_name
|
336 |
-
event_handling(
|
337 |
-
|
338 |
-
|
339 |
-
|
|
|
|
|
340 |
demo.queue()
|
341 |
-
|
342 |
return demo
|
343 |
|
344 |
-
|
345 |
demo = main_ui()
|
346 |
demo.launch(ssr_mode=False)
|
|
|
9 |
from climateqa.engine.llm import get_llm
|
10 |
from climateqa.engine.vectorstore import get_pinecone_vectorstore
|
11 |
from climateqa.engine.reranker import get_reranker
|
12 |
+
from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
|
13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
|
|
|
|
15 |
|
16 |
+
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
+
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
+
from front.tabs.tab_drias import create_drias_tab
|
19 |
from front.utils import process_figures
|
20 |
from gradio_modal import Modal
|
21 |
|
|
|
24 |
import logging
|
25 |
|
26 |
logging.basicConfig(level=logging.WARNING)
|
27 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppresses INFO and WARNING logs
|
28 |
logging.getLogger().setLevel(logging.WARNING)
|
29 |
|
30 |
|
|
|
31 |
# Load environment variables in local mode
|
32 |
try:
|
33 |
from dotenv import load_dotenv
|
34 |
+
|
35 |
load_dotenv()
|
36 |
except Exception as e:
|
37 |
pass
|
|
|
62 |
user_id = create_user_id()
|
63 |
|
64 |
|
|
|
65 |
# Create vectorstore and retriever
|
66 |
embeddings_function = get_embeddings_function()
|
67 |
+
vectorstore = get_pinecone_vectorstore(
|
68 |
+
embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
|
69 |
+
)
|
70 |
+
vectorstore_graphs = get_pinecone_vectorstore(
|
71 |
+
embeddings_function,
|
72 |
+
index_name=os.getenv("PINECONE_API_INDEX_OWID"),
|
73 |
+
text_key="description",
|
74 |
+
)
|
75 |
+
vectorstore_region = get_pinecone_vectorstore(
|
76 |
+
embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
|
77 |
+
)
|
78 |
|
79 |
+
llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
|
80 |
if os.environ["GRADIO_ENV"] == "local":
|
81 |
reranker = get_reranker("nano")
|
82 |
+
else:
|
83 |
reranker = get_reranker("large")
|
84 |
|
85 |
+
agent = make_graph_agent(
|
86 |
+
llm=llm,
|
87 |
+
vectorstore_ipcc=vectorstore,
|
88 |
+
vectorstore_graphs=vectorstore_graphs,
|
89 |
+
vectorstore_region=vectorstore_region,
|
90 |
+
reranker=reranker,
|
91 |
+
threshold_docs=0.2,
|
92 |
+
)
|
93 |
+
agent_poc = make_graph_agent_poc(
|
94 |
+
llm=llm,
|
95 |
+
vectorstore_ipcc=vectorstore,
|
96 |
+
vectorstore_graphs=vectorstore_graphs,
|
97 |
+
vectorstore_region=vectorstore_region,
|
98 |
+
reranker=reranker,
|
99 |
+
threshold_docs=0,
|
100 |
+
version="v4",
|
101 |
+
) # TODO put back default 0.2
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
async def chat(
|
107 |
+
query,
|
108 |
+
history,
|
109 |
+
audience,
|
110 |
+
sources,
|
111 |
+
reports,
|
112 |
+
relevant_content_sources_selection,
|
113 |
+
search_only,
|
114 |
+
):
|
115 |
print("chat cqa - message received")
|
116 |
+
async for event in chat_stream(
|
117 |
+
agent,
|
118 |
+
query,
|
119 |
+
history,
|
120 |
+
audience,
|
121 |
+
sources,
|
122 |
+
reports,
|
123 |
+
relevant_content_sources_selection,
|
124 |
+
search_only,
|
125 |
+
share_client,
|
126 |
+
user_id,
|
127 |
+
):
|
128 |
yield event
|
129 |
+
|
130 |
+
|
131 |
+
async def chat_poc(
|
132 |
+
query,
|
133 |
+
history,
|
134 |
+
audience,
|
135 |
+
sources,
|
136 |
+
reports,
|
137 |
+
relevant_content_sources_selection,
|
138 |
+
search_only,
|
139 |
+
):
|
140 |
print("chat poc - message received")
|
141 |
+
async for event in chat_stream(
|
142 |
+
agent_poc,
|
143 |
+
query,
|
144 |
+
history,
|
145 |
+
audience,
|
146 |
+
sources,
|
147 |
+
reports,
|
148 |
+
relevant_content_sources_selection,
|
149 |
+
search_only,
|
150 |
+
share_client,
|
151 |
+
user_id,
|
152 |
+
):
|
153 |
yield event
|
154 |
|
155 |
|
|
|
157 |
# Gradio
|
158 |
# --------------------------------------------------------------------
|
159 |
|
160 |
+
|
161 |
# Function to update modal visibility
|
162 |
def update_config_modal_visibility(config_open):
|
163 |
print(config_open)
|
164 |
new_config_visibility_status = not config_open
|
165 |
return Modal(visible=new_config_visibility_status), new_config_visibility_status
|
|
|
166 |
|
167 |
+
|
168 |
+
def update_sources_number_display(
|
169 |
+
sources_textbox, figures_cards, current_graphs, papers_html
|
170 |
+
):
|
171 |
sources_number = sources_textbox.count("<h2>")
|
172 |
figures_number = figures_cards.count("<h2>")
|
173 |
graphs_number = current_graphs.count("<iframe")
|
|
|
176 |
figures_notif_label = f"Figures ({figures_number})"
|
177 |
graphs_notif_label = f"Graphs ({graphs_number})"
|
178 |
papers_notif_label = f"Papers ({papers_number})"
|
179 |
+
recommended_content_notif_label = (
|
180 |
+
f"Recommended content ({figures_number + graphs_number + papers_number})"
|
181 |
+
)
|
182 |
+
|
183 |
+
return (
|
184 |
+
gr.update(label=recommended_content_notif_label),
|
185 |
+
gr.update(label=sources_notif_label),
|
186 |
+
gr.update(label=figures_notif_label),
|
187 |
+
gr.update(label=graphs_notif_label),
|
188 |
+
gr.update(label=papers_notif_label),
|
189 |
+
)
|
190 |
+
|
191 |
+
|
192 |
+
def config_event_handling(
|
193 |
+
main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
|
194 |
+
):
|
195 |
+
config_open = config_componenets.config_open
|
196 |
+
config_modal = config_componenets.config_modal
|
197 |
+
close_config_modal = config_componenets.close_config_modal_button
|
198 |
+
|
199 |
+
for button in [close_config_modal] + [
|
200 |
+
main_tab_component.config_button for main_tab_component in main_tabs_components
|
201 |
+
]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
button.click(
|
203 |
fn=update_config_modal_visibility,
|
204 |
inputs=[config_open],
|
205 |
+
outputs=[config_modal, config_open],
|
206 |
+
)
|
207 |
+
|
208 |
+
|
209 |
def event_handling(
|
210 |
+
main_tab_components: MainTabPanel,
|
211 |
+
config_components: ConfigPanel,
|
212 |
+
tab_name="ClimateQ&A",
|
213 |
):
|
214 |
+
chatbot = main_tab_components.chatbot
|
215 |
+
textbox = main_tab_components.textbox
|
216 |
+
tabs = main_tab_components.tabs
|
217 |
+
sources_raw = main_tab_components.sources_raw
|
218 |
+
new_figures = main_tab_components.new_figures
|
219 |
+
current_graphs = main_tab_components.current_graphs
|
220 |
+
examples_hidden = main_tab_components.examples_hidden
|
221 |
+
sources_textbox = main_tab_components.sources_textbox
|
222 |
+
figures_cards = main_tab_components.figures_cards
|
223 |
+
gallery_component = main_tab_components.gallery_component
|
224 |
+
papers_direct_search = main_tab_components.papers_direct_search
|
225 |
+
papers_html = main_tab_components.papers_html
|
226 |
+
citations_network = main_tab_components.citations_network
|
227 |
+
papers_summary = main_tab_components.papers_summary
|
228 |
+
tab_recommended_content = main_tab_components.tab_recommended_content
|
229 |
+
tab_sources = main_tab_components.tab_sources
|
230 |
+
tab_figures = main_tab_components.tab_figures
|
231 |
+
tab_graphs = main_tab_components.tab_graphs
|
232 |
+
tab_papers = main_tab_components.tab_papers
|
233 |
+
graphs_container = main_tab_components.graph_container
|
234 |
+
follow_up_examples = main_tab_components.follow_up_examples
|
235 |
+
follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
|
236 |
+
|
237 |
+
dropdown_sources = config_components.dropdown_sources
|
238 |
+
dropdown_reports = config_components.dropdown_reports
|
239 |
+
dropdown_external_sources = config_components.dropdown_external_sources
|
240 |
+
search_only = config_components.search_only
|
241 |
+
dropdown_audience = config_components.dropdown_audience
|
242 |
+
after = config_components.after
|
243 |
+
output_query = config_components.output_query
|
244 |
+
output_language = config_components.output_language
|
245 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
new_sources_hmtl = gr.State([])
|
247 |
ttd_data = gr.State([])
|
248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
if tab_name == "ClimateQ&A":
|
250 |
print("chat cqa - message sent")
|
251 |
|
252 |
# Event for textbox
|
253 |
+
(
|
254 |
+
textbox.submit(
|
255 |
+
start_chat,
|
256 |
+
[textbox, chatbot, search_only],
|
257 |
+
[textbox, tabs, chatbot, sources_raw],
|
258 |
+
queue=False,
|
259 |
+
api_name=f"start_chat_{textbox.elem_id}",
|
260 |
+
)
|
261 |
+
.then(
|
262 |
+
chat,
|
263 |
+
[
|
264 |
+
textbox,
|
265 |
+
chatbot,
|
266 |
+
dropdown_audience,
|
267 |
+
dropdown_sources,
|
268 |
+
dropdown_reports,
|
269 |
+
dropdown_external_sources,
|
270 |
+
search_only,
|
271 |
+
],
|
272 |
+
[
|
273 |
+
chatbot,
|
274 |
+
new_sources_hmtl,
|
275 |
+
output_query,
|
276 |
+
output_language,
|
277 |
+
new_figures,
|
278 |
+
current_graphs,
|
279 |
+
follow_up_examples.dataset,
|
280 |
+
],
|
281 |
+
concurrency_limit=8,
|
282 |
+
api_name=f"chat_{textbox.elem_id}",
|
283 |
+
)
|
284 |
+
.then(
|
285 |
+
finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
|
286 |
+
)
|
287 |
)
|
288 |
# Event for examples_hidden
|
289 |
+
(
|
290 |
+
examples_hidden.change(
|
291 |
+
start_chat,
|
292 |
+
[examples_hidden, chatbot, search_only],
|
293 |
+
[examples_hidden, tabs, chatbot, sources_raw],
|
294 |
+
queue=False,
|
295 |
+
api_name=f"start_chat_{examples_hidden.elem_id}",
|
296 |
+
)
|
297 |
+
.then(
|
298 |
+
chat,
|
299 |
+
[
|
300 |
+
examples_hidden,
|
301 |
+
chatbot,
|
302 |
+
dropdown_audience,
|
303 |
+
dropdown_sources,
|
304 |
+
dropdown_reports,
|
305 |
+
dropdown_external_sources,
|
306 |
+
search_only,
|
307 |
+
],
|
308 |
+
[
|
309 |
+
chatbot,
|
310 |
+
new_sources_hmtl,
|
311 |
+
output_query,
|
312 |
+
output_language,
|
313 |
+
new_figures,
|
314 |
+
current_graphs,
|
315 |
+
follow_up_examples.dataset,
|
316 |
+
],
|
317 |
+
concurrency_limit=8,
|
318 |
+
api_name=f"chat_{examples_hidden.elem_id}",
|
319 |
+
)
|
320 |
+
.then(
|
321 |
+
finish_chat,
|
322 |
+
None,
|
323 |
+
[textbox],
|
324 |
+
api_name=f"finish_chat_{examples_hidden.elem_id}",
|
325 |
+
)
|
326 |
+
)
|
327 |
+
(
|
328 |
+
follow_up_examples_hidden.change(
|
329 |
+
start_chat,
|
330 |
+
[follow_up_examples_hidden, chatbot, search_only],
|
331 |
+
[follow_up_examples_hidden, tabs, chatbot, sources_raw],
|
332 |
+
queue=False,
|
333 |
+
api_name=f"start_chat_{examples_hidden.elem_id}",
|
334 |
+
)
|
335 |
+
.then(
|
336 |
+
chat,
|
337 |
+
[
|
338 |
+
follow_up_examples_hidden,
|
339 |
+
chatbot,
|
340 |
+
dropdown_audience,
|
341 |
+
dropdown_sources,
|
342 |
+
dropdown_reports,
|
343 |
+
dropdown_external_sources,
|
344 |
+
search_only,
|
345 |
+
],
|
346 |
+
[
|
347 |
+
chatbot,
|
348 |
+
new_sources_hmtl,
|
349 |
+
output_query,
|
350 |
+
output_language,
|
351 |
+
new_figures,
|
352 |
+
current_graphs,
|
353 |
+
follow_up_examples.dataset,
|
354 |
+
],
|
355 |
+
concurrency_limit=8,
|
356 |
+
api_name=f"chat_{examples_hidden.elem_id}",
|
357 |
+
)
|
358 |
+
.then(
|
359 |
+
finish_chat,
|
360 |
+
None,
|
361 |
+
[textbox],
|
362 |
+
api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
|
363 |
+
)
|
364 |
)
|
365 |
+
|
366 |
+
elif tab_name == "France - Local Q&A":
|
367 |
print("chat poc - message sent")
|
368 |
# Event for textbox
|
369 |
+
(
|
370 |
+
textbox.submit(
|
371 |
+
start_chat,
|
372 |
+
[textbox, chatbot, search_only],
|
373 |
+
[textbox, tabs, chatbot, sources_raw],
|
374 |
+
queue=False,
|
375 |
+
api_name=f"start_chat_{textbox.elem_id}",
|
376 |
+
)
|
377 |
+
.then(
|
378 |
+
chat_poc,
|
379 |
+
[
|
380 |
+
textbox,
|
381 |
+
chatbot,
|
382 |
+
dropdown_audience,
|
383 |
+
dropdown_sources,
|
384 |
+
dropdown_reports,
|
385 |
+
dropdown_external_sources,
|
386 |
+
search_only,
|
387 |
+
],
|
388 |
+
[
|
389 |
+
chatbot,
|
390 |
+
new_sources_hmtl,
|
391 |
+
output_query,
|
392 |
+
output_language,
|
393 |
+
new_figures,
|
394 |
+
current_graphs,
|
395 |
+
],
|
396 |
+
concurrency_limit=8,
|
397 |
+
api_name=f"chat_{textbox.elem_id}",
|
398 |
+
)
|
399 |
+
.then(
|
400 |
+
finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
|
401 |
+
)
|
402 |
)
|
403 |
# Event for examples_hidden
|
404 |
+
(
|
405 |
+
examples_hidden.change(
|
406 |
+
start_chat,
|
407 |
+
[examples_hidden, chatbot, search_only],
|
408 |
+
[examples_hidden, tabs, chatbot, sources_raw],
|
409 |
+
queue=False,
|
410 |
+
api_name=f"start_chat_{examples_hidden.elem_id}",
|
411 |
+
)
|
412 |
+
.then(
|
413 |
+
chat_poc,
|
414 |
+
[
|
415 |
+
examples_hidden,
|
416 |
+
chatbot,
|
417 |
+
dropdown_audience,
|
418 |
+
dropdown_sources,
|
419 |
+
dropdown_reports,
|
420 |
+
dropdown_external_sources,
|
421 |
+
search_only,
|
422 |
+
],
|
423 |
+
[
|
424 |
+
chatbot,
|
425 |
+
new_sources_hmtl,
|
426 |
+
output_query,
|
427 |
+
output_language,
|
428 |
+
new_figures,
|
429 |
+
current_graphs,
|
430 |
+
],
|
431 |
+
concurrency_limit=8,
|
432 |
+
api_name=f"chat_{examples_hidden.elem_id}",
|
433 |
+
)
|
434 |
+
.then(
|
435 |
+
finish_chat,
|
436 |
+
None,
|
437 |
+
[textbox],
|
438 |
+
api_name=f"finish_chat_{examples_hidden.elem_id}",
|
439 |
+
)
|
440 |
+
)
|
441 |
+
(
|
442 |
+
follow_up_examples_hidden.change(
|
443 |
+
start_chat,
|
444 |
+
[follow_up_examples_hidden, chatbot, search_only],
|
445 |
+
[follow_up_examples_hidden, tabs, chatbot, sources_raw],
|
446 |
+
queue=False,
|
447 |
+
api_name=f"start_chat_{examples_hidden.elem_id}",
|
448 |
+
)
|
449 |
+
.then(
|
450 |
+
chat,
|
451 |
+
[
|
452 |
+
follow_up_examples_hidden,
|
453 |
+
chatbot,
|
454 |
+
dropdown_audience,
|
455 |
+
dropdown_sources,
|
456 |
+
dropdown_reports,
|
457 |
+
dropdown_external_sources,
|
458 |
+
search_only,
|
459 |
+
],
|
460 |
+
[
|
461 |
+
chatbot,
|
462 |
+
new_sources_hmtl,
|
463 |
+
output_query,
|
464 |
+
output_language,
|
465 |
+
new_figures,
|
466 |
+
current_graphs,
|
467 |
+
follow_up_examples.dataset,
|
468 |
+
],
|
469 |
+
concurrency_limit=8,
|
470 |
+
api_name=f"chat_{examples_hidden.elem_id}",
|
471 |
+
)
|
472 |
+
.then(
|
473 |
+
finish_chat,
|
474 |
+
None,
|
475 |
+
[textbox],
|
476 |
+
api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
|
477 |
+
)
|
478 |
)
|
479 |
+
|
480 |
+
new_sources_hmtl.change(
|
481 |
+
lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
|
482 |
+
)
|
483 |
+
current_graphs.change(
|
484 |
+
lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
|
485 |
+
)
|
486 |
+
new_figures.change(
|
487 |
+
process_figures,
|
488 |
+
inputs=[sources_raw, new_figures],
|
489 |
+
outputs=[sources_raw, figures_cards, gallery_component],
|
490 |
+
)
|
491 |
|
492 |
# Update sources numbers
|
493 |
for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
|
494 |
+
component.change(
|
495 |
+
update_sources_number_display,
|
496 |
+
[sources_textbox, figures_cards, current_graphs, papers_html],
|
497 |
+
[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
|
498 |
+
)
|
499 |
+
|
500 |
# Search for papers
|
501 |
for component in [textbox, examples_hidden, papers_direct_search]:
|
502 |
+
component.submit(
|
503 |
+
find_papers,
|
504 |
+
[component, after, dropdown_external_sources],
|
505 |
+
[papers_html, citations_network, papers_summary],
|
506 |
+
)
|
507 |
|
508 |
+
# if tab_name == "France - Local Q&A": # Not untill results are good enough
|
509 |
# # Drias search
|
510 |
# textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
|
511 |
|
512 |
+
|
513 |
def main_ui():
|
514 |
# config_open = gr.State(True)
|
515 |
+
with gr.Blocks(
|
516 |
+
title="Climate Q&A",
|
517 |
+
css_paths=os.getcwd() + "/style.css",
|
518 |
+
theme=theme,
|
519 |
+
elem_id="main-component",
|
520 |
+
) as demo:
|
521 |
+
config_components = create_config_modal()
|
522 |
+
|
523 |
with gr.Tabs():
|
524 |
+
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
525 |
+
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
526 |
+
create_drias_tab(share_client=share_client, user_id=user_id)
|
527 |
+
|
528 |
create_about_tab()
|
529 |
+
|
530 |
+
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
531 |
+
event_handling(
|
532 |
+
local_cqa_components, config_components, tab_name="France - Local Q&A"
|
533 |
+
)
|
534 |
+
|
535 |
+
config_event_handling([cqa_components, local_cqa_components], config_components)
|
536 |
+
|
537 |
demo.queue()
|
538 |
+
|
539 |
return demo
|
540 |
|
541 |
+
|
542 |
demo = main_ui()
|
543 |
demo.launch(ssr_mode=False)
|
@@ -61,6 +61,27 @@ def handle_numerical_data(event):
|
|
61 |
return numerical_data, sql_query
|
62 |
return None, None
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
# Main chat function
|
65 |
async def chat_stream(
|
66 |
agent : CompiledStateGraph,
|
@@ -101,6 +122,7 @@ async def chat_stream(
|
|
101 |
audience_prompt = init_audience(audience)
|
102 |
sources = sources or ["IPCC", "IPBES"]
|
103 |
reports = reports or []
|
|
|
104 |
|
105 |
# Prepare inputs for agent
|
106 |
inputs = {
|
@@ -109,7 +131,8 @@ async def chat_stream(
|
|
109 |
"sources_input": sources,
|
110 |
"relevant_content_sources_selection": relevant_content_sources_selection,
|
111 |
"search_only": search_only,
|
112 |
-
"reports": reports
|
|
|
113 |
}
|
114 |
|
115 |
# Get streaming events from agent
|
@@ -129,6 +152,7 @@ async def chat_stream(
|
|
129 |
retrieved_contents = []
|
130 |
answer_message_content = ""
|
131 |
vanna_data = {}
|
|
|
132 |
|
133 |
# Define processing steps
|
134 |
steps_display = {
|
@@ -200,7 +224,12 @@ async def chat_stream(
|
|
200 |
sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
|
201 |
history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
|
202 |
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
except Exception as e:
|
206 |
print(f"Event {event} has failed")
|
@@ -211,4 +240,4 @@ async def chat_stream(
|
|
211 |
# Call the function to log interaction
|
212 |
log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
|
213 |
|
214 |
-
yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
|
|
|
61 |
return numerical_data, sql_query
|
62 |
return None, None
|
63 |
|
64 |
+
def log_drias_interaction_to_azure(query, sql_query, data, share_client, user_id):
|
65 |
+
try:
|
66 |
+
# Log interaction to Azure if not in local environment
|
67 |
+
if os.getenv("GRADIO_ENV") != "local":
|
68 |
+
timestamp = str(datetime.now().timestamp())
|
69 |
+
logs = {
|
70 |
+
"user_id": str(user_id),
|
71 |
+
"query": query,
|
72 |
+
"sql_query": sql_query,
|
73 |
+
# "data": data.to_dict() if data is not None else None,
|
74 |
+
"time": timestamp,
|
75 |
+
}
|
76 |
+
log_on_azure(f"drias_{timestamp}.json", logs, share_client)
|
77 |
+
print(f"Logged Drias interaction to Azure Blob Storage: {logs}")
|
78 |
+
else:
|
79 |
+
print("share_client or user_id is None, or GRADIO_ENV is local")
|
80 |
+
except Exception as e:
|
81 |
+
print(f"Error logging Drias interaction on Azure Blob Storage: {e}")
|
82 |
+
error_msg = f"Drias Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
|
83 |
+
raise gr.Error(error_msg)
|
84 |
+
|
85 |
# Main chat function
|
86 |
async def chat_stream(
|
87 |
agent : CompiledStateGraph,
|
|
|
122 |
audience_prompt = init_audience(audience)
|
123 |
sources = sources or ["IPCC", "IPBES"]
|
124 |
reports = reports or []
|
125 |
+
relevant_history_discussion = history[-2:] if len(history) > 1 else []
|
126 |
|
127 |
# Prepare inputs for agent
|
128 |
inputs = {
|
|
|
131 |
"sources_input": sources,
|
132 |
"relevant_content_sources_selection": relevant_content_sources_selection,
|
133 |
"search_only": search_only,
|
134 |
+
"reports": reports,
|
135 |
+
"chat_history": relevant_history_discussion,
|
136 |
}
|
137 |
|
138 |
# Get streaming events from agent
|
|
|
152 |
retrieved_contents = []
|
153 |
answer_message_content = ""
|
154 |
vanna_data = {}
|
155 |
+
follow_up_examples = gr.Dataset(samples=[])
|
156 |
|
157 |
# Define processing steps
|
158 |
steps_display = {
|
|
|
224 |
sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
|
225 |
history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
|
226 |
|
227 |
+
# Handle follow up questions
|
228 |
+
if event["name"] == "generate_follow_up" and event["event"] == "on_chain_end":
|
229 |
+
follow_up_examples = event["data"]["output"].get("follow_up_questions", [])
|
230 |
+
follow_up_examples = gr.Dataset(samples= [ [question] for question in follow_up_examples ])
|
231 |
+
|
232 |
+
yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
|
233 |
|
234 |
except Exception as e:
|
235 |
print(f"Event {event} has failed")
|
|
|
240 |
# Call the function to log interaction
|
241 |
log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
|
242 |
|
243 |
+
yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
|
@@ -65,6 +65,7 @@ def make_rag_node(llm,with_docs = True):
|
|
65 |
async def answer_rag(state,config):
|
66 |
print("---- Answer RAG ----")
|
67 |
start_time = time.time()
|
|
|
68 |
print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))
|
69 |
|
70 |
answer = await rag_chain.ainvoke(state,config)
|
@@ -73,9 +74,10 @@ def make_rag_node(llm,with_docs = True):
|
|
73 |
elapsed_time = end_time - start_time
|
74 |
print("RAG elapsed time: ", elapsed_time)
|
75 |
print("Answer size : ", len(answer))
|
76 |
-
# print(f"\n\nAnswer:\n{answer}")
|
77 |
|
78 |
-
|
|
|
|
|
79 |
|
80 |
return answer_rag
|
81 |
|
|
|
65 |
async def answer_rag(state,config):
|
66 |
print("---- Answer RAG ----")
|
67 |
start_time = time.time()
|
68 |
+
chat_history = state.get("chat_history",[])
|
69 |
print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))
|
70 |
|
71 |
answer = await rag_chain.ainvoke(state,config)
|
|
|
74 |
elapsed_time = end_time - start_time
|
75 |
print("RAG elapsed time: ", elapsed_time)
|
76 |
print("Answer size : ", len(answer))
|
|
|
77 |
|
78 |
+
chat_history.append({"question":state["query"],"answer":answer})
|
79 |
+
|
80 |
+
return {"answer":answer,"chat_history": chat_history}
|
81 |
|
82 |
return answer_rag
|
83 |
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from langchain.prompts import ChatPromptTemplate
|
3 |
+
|
4 |
+
|
5 |
+
FOLLOW_UP_TEMPLATE = """Based on the previous question and answer, generate 2-3 relevant follow-up questions that would help explore the topic further.
|
6 |
+
|
7 |
+
Previous Question: {user_input}
|
8 |
+
Previous Answer: {answer}
|
9 |
+
|
10 |
+
Generate short, concise, focused follow-up questions
|
11 |
+
You don't need a full question as it will be reformulated later as a standalone question with the context. Eg. "Details the first point"
|
12 |
+
"""
|
13 |
+
|
14 |
+
def make_follow_up_node(llm):
|
15 |
+
prompt = ChatPromptTemplate.from_template(FOLLOW_UP_TEMPLATE)
|
16 |
+
|
17 |
+
def generate_follow_up(state):
|
18 |
+
print("---- Generate_follow_up ----")
|
19 |
+
if not state.get("answer"):
|
20 |
+
return state
|
21 |
+
|
22 |
+
response = llm.invoke(prompt.format(
|
23 |
+
user_input=state["user_input"],
|
24 |
+
answer=state["answer"]
|
25 |
+
))
|
26 |
+
|
27 |
+
# Extract questions from response
|
28 |
+
follow_ups = [q.strip() for q in response.content.split("\n") if q.strip()]
|
29 |
+
state["follow_up_questions"] = follow_ups
|
30 |
+
|
31 |
+
return state
|
32 |
+
|
33 |
+
return generate_follow_up
|
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
from langchain_core.pydantic_v1 import BaseModel, Field
|
3 |
from typing import List
|
4 |
from typing import Literal
|
@@ -44,7 +43,7 @@ def make_intent_categorization_chain(llm):
|
|
44 |
llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
|
45 |
|
46 |
prompt = ChatPromptTemplate.from_messages([
|
47 |
-
("system", "You are a helpful assistant, you will analyze,
|
48 |
("user", "input: {input}")
|
49 |
])
|
50 |
|
@@ -58,11 +57,19 @@ def make_intent_categorization_node(llm):
|
|
58 |
|
59 |
def categorize_message(state):
|
60 |
print("---- Categorize_message ----")
|
|
|
61 |
|
62 |
output = categorization_chain.invoke({"input": state["user_input"]})
|
63 |
-
print(f"\n\
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
output["query"] = state["user_input"]
|
|
|
66 |
return output
|
67 |
|
68 |
return categorize_message
|
|
|
|
|
1 |
from langchain_core.pydantic_v1 import BaseModel, Field
|
2 |
from typing import List
|
3 |
from typing import Literal
|
|
|
43 |
llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
|
44 |
|
45 |
prompt = ChatPromptTemplate.from_messages([
|
46 |
+
("system", "You are a helpful assistant, you will analyze, detect the language, and categorize the user input message using the function provided. You MUST detect and return the language of the input message. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
|
47 |
("user", "input: {input}")
|
48 |
])
|
49 |
|
|
|
57 |
|
58 |
def categorize_message(state):
|
59 |
print("---- Categorize_message ----")
|
60 |
+
print(f"Input state: {state}")
|
61 |
|
62 |
output = categorization_chain.invoke({"input": state["user_input"]})
|
63 |
+
print(f"\n\nRaw output from categorization: {output}\n")
|
64 |
+
|
65 |
+
if "language" not in output:
|
66 |
+
print("WARNING: Language field missing from output, setting default to English")
|
67 |
+
output["language"] = "English"
|
68 |
+
else:
|
69 |
+
print(f"Language detected: {output['language']}")
|
70 |
+
|
71 |
output["query"] = state["user_input"]
|
72 |
+
print(f"Final output: {output}")
|
73 |
return output
|
74 |
|
75 |
return categorize_message
|
@@ -621,10 +621,7 @@ def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
|
|
621 |
|
622 |
def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
623 |
|
624 |
-
async def retrieve_POC_docs_node(state, config):
|
625 |
-
if "POC region" not in state["relevant_content_sources_selection"] :
|
626 |
-
return {}
|
627 |
-
|
628 |
source_type = "POC"
|
629 |
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
630 |
|
@@ -665,10 +662,7 @@ def make_POC_by_ToC_retriever_node(
|
|
665 |
k_summary=5,
|
666 |
):
|
667 |
|
668 |
-
async def retrieve_POC_docs_node(state, config):
|
669 |
-
if "POC region" not in state["relevant_content_sources_selection"] :
|
670 |
-
return {}
|
671 |
-
|
672 |
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
673 |
search_only = state["search_only"]
|
674 |
search_only = state["search_only"]
|
|
|
621 |
|
622 |
def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
623 |
|
624 |
+
async def retrieve_POC_docs_node(state, config):
|
|
|
|
|
|
|
625 |
source_type = "POC"
|
626 |
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
627 |
|
|
|
662 |
k_summary=5,
|
663 |
):
|
664 |
|
665 |
+
async def retrieve_POC_docs_node(state, config):
|
|
|
|
|
|
|
666 |
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
667 |
search_only = state["search_only"]
|
668 |
search_only = state["search_only"]
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import ChatPromptTemplate
|
2 |
+
|
3 |
+
def make_standalone_question_chain(llm):
|
4 |
+
prompt = ChatPromptTemplate.from_messages([
|
5 |
+
("system", """You are a helpful assistant that transforms user questions into standalone questions
|
6 |
+
by incorporating context from the chat history if needed. The output should be a self-contained
|
7 |
+
question that can be understood without any additional context.
|
8 |
+
|
9 |
+
Examples:
|
10 |
+
Chat History: "Let's talk about renewable energy"
|
11 |
+
User Input: "What about solar?"
|
12 |
+
Output: "What are the key aspects of solar energy as a renewable energy source?"
|
13 |
+
|
14 |
+
Chat History: "What causes global warming?"
|
15 |
+
User Input: "And what are its effects?"
|
16 |
+
Output: "What are the effects of global warming on the environment and society?"
|
17 |
+
"""),
|
18 |
+
("user", """Chat History: {chat_history}
|
19 |
+
User Question: {question}
|
20 |
+
|
21 |
+
Transform this into a standalone question:
|
22 |
+
Make sure to keep the original language of the question.""")
|
23 |
+
])
|
24 |
+
|
25 |
+
chain = prompt | llm
|
26 |
+
return chain
|
27 |
+
|
28 |
+
def make_standalone_question_node(llm):
|
29 |
+
standalone_chain = make_standalone_question_chain(llm)
|
30 |
+
|
31 |
+
def transform_to_standalone(state):
|
32 |
+
chat_history = state.get("chat_history", "")
|
33 |
+
if chat_history == "":
|
34 |
+
return {}
|
35 |
+
output = standalone_chain.invoke({
|
36 |
+
"chat_history": chat_history,
|
37 |
+
"question": state["user_input"]
|
38 |
+
})
|
39 |
+
state["user_input"] = output.content
|
40 |
+
return state
|
41 |
+
|
42 |
+
return transform_to_standalone
|
@@ -23,13 +23,15 @@ from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriev
|
|
23 |
from .chains.answer_rag import make_rag_node
|
24 |
from .chains.graph_retriever import make_graph_retriever_node
|
25 |
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
|
26 |
-
|
|
|
27 |
|
28 |
class GraphState(TypedDict):
|
29 |
"""
|
30 |
Represents the state of our graph.
|
31 |
"""
|
32 |
user_input : str
|
|
|
33 |
language : str
|
34 |
intent : str
|
35 |
search_graphs_chitchat : bool
|
@@ -49,6 +51,7 @@ class GraphState(TypedDict):
|
|
49 |
recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
|
50 |
search_only : bool = False
|
51 |
reports : List[str] = []
|
|
|
52 |
|
53 |
def dummy(state):
|
54 |
return
|
@@ -100,15 +103,6 @@ def route_continue_retrieve_documents(state):
|
|
100 |
else:
|
101 |
return "retrieve_documents"
|
102 |
|
103 |
-
def route_continue_retrieve_local_documents(state):
|
104 |
-
index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
105 |
-
questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
|
106 |
-
# if questions_poc_finished and state["search_only"]:
|
107 |
-
# return END
|
108 |
-
if questions_poc_finished or ("POC region" not in state["relevant_content_sources_selection"]):
|
109 |
-
return "end_retrieve_local_documents"
|
110 |
-
else:
|
111 |
-
return "retrieve_local_data"
|
112 |
|
113 |
def route_retrieve_documents(state):
|
114 |
sources_to_retrieve = []
|
@@ -120,6 +114,11 @@ def route_retrieve_documents(state):
|
|
120 |
return END
|
121 |
return sources_to_retrieve
|
122 |
|
|
|
|
|
|
|
|
|
|
|
123 |
def make_id_dict(values):
|
124 |
return {k:k for k in values}
|
125 |
|
@@ -128,6 +127,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
128 |
workflow = StateGraph(GraphState)
|
129 |
|
130 |
# Define the node functions
|
|
|
131 |
categorize_intent = make_intent_categorization_node(llm)
|
132 |
transform_query = make_query_transform_node(llm)
|
133 |
translate_query = make_translation_node(llm)
|
@@ -139,9 +139,11 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
139 |
answer_rag = make_rag_node(llm, with_docs=True)
|
140 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
141 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
|
|
142 |
|
143 |
# Define the nodes
|
144 |
# workflow.add_node("set_defaults", set_defaults)
|
|
|
145 |
workflow.add_node("categorize_intent", categorize_intent)
|
146 |
workflow.add_node("answer_climate", dummy)
|
147 |
workflow.add_node("answer_search", answer_search)
|
@@ -155,9 +157,11 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
155 |
workflow.add_node("retrieve_documents", retrieve_documents)
|
156 |
workflow.add_node("answer_rag", answer_rag)
|
157 |
workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
|
|
|
|
|
158 |
|
159 |
# Entry point
|
160 |
-
workflow.set_entry_point("
|
161 |
|
162 |
# CONDITIONAL EDGES
|
163 |
workflow.add_conditional_edges(
|
@@ -189,20 +193,29 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
189 |
make_id_dict(["retrieve_graphs", END])
|
190 |
)
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
# Define the edges
|
|
|
193 |
workflow.add_edge("translate_query", "transform_query")
|
194 |
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
|
195 |
# workflow.add_edge("transform_query", "retrieve_local_data")
|
196 |
# workflow.add_edge("transform_query", END) # TODO remove
|
197 |
|
198 |
workflow.add_edge("retrieve_graphs", END)
|
199 |
-
workflow.add_edge("answer_rag",
|
200 |
-
workflow.add_edge("answer_rag_no_docs",
|
201 |
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
|
202 |
workflow.add_edge("retrieve_graphs_chitchat", END)
|
203 |
|
204 |
# workflow.add_edge("retrieve_local_data", "answer_search")
|
205 |
workflow.add_edge("retrieve_documents", "answer_search")
|
|
|
|
|
206 |
|
207 |
# Compile
|
208 |
app = workflow.compile()
|
@@ -228,6 +241,8 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
|
|
228 |
workflow = StateGraph(GraphState)
|
229 |
|
230 |
# Define the node functions
|
|
|
|
|
231 |
categorize_intent = make_intent_categorization_node(llm)
|
232 |
transform_query = make_query_transform_node(llm)
|
233 |
translate_query = make_translation_node(llm)
|
@@ -240,9 +255,11 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
|
|
240 |
answer_rag = make_rag_node(llm, with_docs=True)
|
241 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
242 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
|
|
243 |
|
244 |
# Define the nodes
|
245 |
# workflow.add_node("set_defaults", set_defaults)
|
|
|
246 |
workflow.add_node("categorize_intent", categorize_intent)
|
247 |
workflow.add_node("answer_climate", dummy)
|
248 |
workflow.add_node("answer_search", answer_search)
|
@@ -258,9 +275,10 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
|
|
258 |
workflow.add_node("retrieve_documents", retrieve_documents)
|
259 |
workflow.add_node("answer_rag", answer_rag)
|
260 |
workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
|
|
|
261 |
|
262 |
# Entry point
|
263 |
-
workflow.set_entry_point("
|
264 |
|
265 |
# CONDITIONAL EDGES
|
266 |
workflow.add_conditional_edges(
|
@@ -293,22 +311,21 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
|
|
293 |
)
|
294 |
|
295 |
# Define the edges
|
|
|
296 |
workflow.add_edge("translate_query", "transform_query")
|
297 |
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
|
298 |
workflow.add_edge("transform_query", "retrieve_local_data")
|
299 |
# workflow.add_edge("transform_query", END) # TODO remove
|
300 |
|
301 |
workflow.add_edge("retrieve_graphs", END)
|
302 |
-
workflow.add_edge("answer_rag",
|
303 |
-
workflow.add_edge("answer_rag_no_docs",
|
304 |
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
|
305 |
workflow.add_edge("retrieve_graphs_chitchat", END)
|
306 |
|
307 |
workflow.add_edge("retrieve_local_data", "answer_search")
|
308 |
workflow.add_edge("retrieve_documents", "answer_search")
|
309 |
-
|
310 |
-
# workflow.add_edge("transform_query", "retrieve_drias_data")
|
311 |
-
# workflow.add_edge("retrieve_drias_data", END)
|
312 |
|
313 |
|
314 |
# Compile
|
|
|
23 |
from .chains.answer_rag import make_rag_node
|
24 |
from .chains.graph_retriever import make_graph_retriever_node
|
25 |
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
|
26 |
+
from .chains.standalone_question import make_standalone_question_node
|
27 |
+
from .chains.follow_up import make_follow_up_node # Add this import
|
28 |
|
29 |
class GraphState(TypedDict):
|
30 |
"""
|
31 |
Represents the state of our graph.
|
32 |
"""
|
33 |
user_input : str
|
34 |
+
chat_history : str
|
35 |
language : str
|
36 |
intent : str
|
37 |
search_graphs_chitchat : bool
|
|
|
51 |
recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
|
52 |
search_only : bool = False
|
53 |
reports : List[str] = []
|
54 |
+
follow_up_questions: List[str] = []
|
55 |
|
56 |
def dummy(state):
|
57 |
return
|
|
|
103 |
else:
|
104 |
return "retrieve_documents"
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
def route_retrieve_documents(state):
|
108 |
sources_to_retrieve = []
|
|
|
114 |
return END
|
115 |
return sources_to_retrieve
|
116 |
|
117 |
+
def route_follow_up(state):
|
118 |
+
if state["follow_up_questions"]:
|
119 |
+
return "process_follow_up"
|
120 |
+
return END
|
121 |
+
|
122 |
def make_id_dict(values):
|
123 |
return {k:k for k in values}
|
124 |
|
|
|
127 |
workflow = StateGraph(GraphState)
|
128 |
|
129 |
# Define the node functions
|
130 |
+
standalone_question_node = make_standalone_question_node(llm)
|
131 |
categorize_intent = make_intent_categorization_node(llm)
|
132 |
transform_query = make_query_transform_node(llm)
|
133 |
translate_query = make_translation_node(llm)
|
|
|
139 |
answer_rag = make_rag_node(llm, with_docs=True)
|
140 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
141 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
142 |
+
generate_follow_up = make_follow_up_node(llm)
|
143 |
|
144 |
# Define the nodes
|
145 |
# workflow.add_node("set_defaults", set_defaults)
|
146 |
+
workflow.add_node("standalone_question", standalone_question_node)
|
147 |
workflow.add_node("categorize_intent", categorize_intent)
|
148 |
workflow.add_node("answer_climate", dummy)
|
149 |
workflow.add_node("answer_search", answer_search)
|
|
|
157 |
workflow.add_node("retrieve_documents", retrieve_documents)
|
158 |
workflow.add_node("answer_rag", answer_rag)
|
159 |
workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
|
160 |
+
workflow.add_node("generate_follow_up", generate_follow_up)
|
161 |
+
# workflow.add_node("process_follow_up", standalone_question_node)
|
162 |
|
163 |
# Entry point
|
164 |
+
workflow.set_entry_point("standalone_question")
|
165 |
|
166 |
# CONDITIONAL EDGES
|
167 |
workflow.add_conditional_edges(
|
|
|
193 |
make_id_dict(["retrieve_graphs", END])
|
194 |
)
|
195 |
|
196 |
+
# workflow.add_conditional_edges(
|
197 |
+
# "generate_follow_up",
|
198 |
+
# route_follow_up,
|
199 |
+
# make_id_dict(["process_follow_up", END])
|
200 |
+
# )
|
201 |
+
|
202 |
# Define the edges
|
203 |
+
workflow.add_edge("standalone_question", "categorize_intent")
|
204 |
workflow.add_edge("translate_query", "transform_query")
|
205 |
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
|
206 |
# workflow.add_edge("transform_query", "retrieve_local_data")
|
207 |
# workflow.add_edge("transform_query", END) # TODO remove
|
208 |
|
209 |
workflow.add_edge("retrieve_graphs", END)
|
210 |
+
workflow.add_edge("answer_rag", "generate_follow_up")
|
211 |
+
workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
|
212 |
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
|
213 |
workflow.add_edge("retrieve_graphs_chitchat", END)
|
214 |
|
215 |
# workflow.add_edge("retrieve_local_data", "answer_search")
|
216 |
workflow.add_edge("retrieve_documents", "answer_search")
|
217 |
+
workflow.add_edge("generate_follow_up",END)
|
218 |
+
# workflow.add_edge("process_follow_up", "categorize_intent")
|
219 |
|
220 |
# Compile
|
221 |
app = workflow.compile()
|
|
|
241 |
workflow = StateGraph(GraphState)
|
242 |
|
243 |
# Define the node functions
|
244 |
+
standalone_question_node = make_standalone_question_node(llm)
|
245 |
+
|
246 |
categorize_intent = make_intent_categorization_node(llm)
|
247 |
transform_query = make_query_transform_node(llm)
|
248 |
translate_query = make_translation_node(llm)
|
|
|
255 |
answer_rag = make_rag_node(llm, with_docs=True)
|
256 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
257 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
258 |
+
generate_follow_up = make_follow_up_node(llm)
|
259 |
|
260 |
# Define the nodes
|
261 |
# workflow.add_node("set_defaults", set_defaults)
|
262 |
+
workflow.add_node("standalone_question", standalone_question_node)
|
263 |
workflow.add_node("categorize_intent", categorize_intent)
|
264 |
workflow.add_node("answer_climate", dummy)
|
265 |
workflow.add_node("answer_search", answer_search)
|
|
|
275 |
workflow.add_node("retrieve_documents", retrieve_documents)
|
276 |
workflow.add_node("answer_rag", answer_rag)
|
277 |
workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
|
278 |
+
workflow.add_node("generate_follow_up", generate_follow_up)
|
279 |
|
280 |
# Entry point
|
281 |
+
workflow.set_entry_point("standalone_question")
|
282 |
|
283 |
# CONDITIONAL EDGES
|
284 |
workflow.add_conditional_edges(
|
|
|
311 |
)
|
312 |
|
313 |
# Define the edges
|
314 |
+
workflow.add_edge("standalone_question", "categorize_intent")
|
315 |
workflow.add_edge("translate_query", "transform_query")
|
316 |
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
|
317 |
workflow.add_edge("transform_query", "retrieve_local_data")
|
318 |
# workflow.add_edge("transform_query", END) # TODO remove
|
319 |
|
320 |
workflow.add_edge("retrieve_graphs", END)
|
321 |
+
workflow.add_edge("answer_rag", "generate_follow_up")
|
322 |
+
workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
|
323 |
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
|
324 |
workflow.add_edge("retrieve_graphs_chitchat", END)
|
325 |
|
326 |
workflow.add_edge("retrieve_local_data", "answer_search")
|
327 |
workflow.add_edge("retrieve_documents", "answer_search")
|
328 |
+
workflow.add_edge("generate_follow_up",END)
|
|
|
|
|
329 |
|
330 |
|
331 |
# Compile
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DRIAS_TABLES = [
|
2 |
+
"total_winter_precipitation",
|
3 |
+
"total_summer_precipiation",
|
4 |
+
"total_annual_precipitation",
|
5 |
+
"total_remarkable_daily_precipitation",
|
6 |
+
"frequency_of_remarkable_daily_precipitation",
|
7 |
+
"extreme_precipitation_intensity",
|
8 |
+
"mean_winter_temperature",
|
9 |
+
"mean_summer_temperature",
|
10 |
+
"mean_annual_temperature",
|
11 |
+
"number_of_tropical_nights",
|
12 |
+
"maximum_summer_temperature",
|
13 |
+
"number_of_days_with_tx_above_30",
|
14 |
+
"number_of_days_with_tx_above_35",
|
15 |
+
"number_of_days_with_a_dry_ground",
|
16 |
+
]
|
17 |
+
|
18 |
+
INDICATOR_COLUMNS_PER_TABLE = {
|
19 |
+
"total_winter_precipitation": "total_winter_precipitation",
|
20 |
+
"total_summer_precipiation": "total_summer_precipitation",
|
21 |
+
"total_annual_precipitation": "total_annual_precipitation",
|
22 |
+
"total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
|
23 |
+
"frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
|
24 |
+
"extreme_precipitation_intensity": "extreme_precipitation_intensity",
|
25 |
+
"mean_winter_temperature": "mean_winter_temperature",
|
26 |
+
"mean_summer_temperature": "mean_summer_temperature",
|
27 |
+
"mean_annual_temperature": "mean_annual_temperature",
|
28 |
+
"number_of_tropical_nights": "number_tropical_nights",
|
29 |
+
"maximum_summer_temperature": "maximum_summer_temperature",
|
30 |
+
"number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
|
31 |
+
"number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
|
32 |
+
"number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
|
33 |
+
}
|
34 |
+
|
35 |
+
DRIAS_MODELS = [
|
36 |
+
'ALL',
|
37 |
+
'RegCM4-6_MPI-ESM-LR',
|
38 |
+
'RACMO22E_EC-EARTH',
|
39 |
+
'RegCM4-6_HadGEM2-ES',
|
40 |
+
'HadREM3-GA7_EC-EARTH',
|
41 |
+
'HadREM3-GA7_CNRM-CM5',
|
42 |
+
'REMO2015_NorESM1-M',
|
43 |
+
'SMHI-RCA4_EC-EARTH',
|
44 |
+
'WRF381P_NorESM1-M',
|
45 |
+
'ALADIN63_CNRM-CM5',
|
46 |
+
'CCLM4-8-17_MPI-ESM-LR',
|
47 |
+
'HIRHAM5_IPSL-CM5A-MR',
|
48 |
+
'HadREM3-GA7_HadGEM2-ES',
|
49 |
+
'SMHI-RCA4_IPSL-CM5A-MR',
|
50 |
+
'HIRHAM5_NorESM1-M',
|
51 |
+
'REMO2009_MPI-ESM-LR',
|
52 |
+
'CCLM4-8-17_HadGEM2-ES'
|
53 |
+
]
|
54 |
+
# Mapping between indicator columns and their units
|
55 |
+
INDICATOR_TO_UNIT = {
|
56 |
+
"total_winter_precipitation": "mm",
|
57 |
+
"total_summer_precipitation": "mm",
|
58 |
+
"total_annual_precipitation": "mm",
|
59 |
+
"total_remarkable_daily_precipitation": "mm",
|
60 |
+
"frequency_of_remarkable_daily_precipitation": "days",
|
61 |
+
"extreme_precipitation_intensity": "mm",
|
62 |
+
"mean_winter_temperature": "°C",
|
63 |
+
"mean_summer_temperature": "°C",
|
64 |
+
"mean_annual_temperature": "°C",
|
65 |
+
"number_tropical_nights": "days",
|
66 |
+
"maximum_summer_temperature": "°C",
|
67 |
+
"number_of_days_with_tx_above_30": "days",
|
68 |
+
"number_of_days_with_tx_above_35": "days",
|
69 |
+
"number_of_days_with_dry_ground": "days"
|
70 |
+
}
|
71 |
+
|
72 |
+
DRIAS_UI_TEXT = """
|
73 |
+
Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
|
74 |
+
I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
|
75 |
+
|
76 |
+
❓ **How to use?**
|
77 |
+
You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
|
78 |
+
You can specify **location** and/or **year**.
|
79 |
+
You can choose from a list of climate models. By default, we take the **average of each model**.
|
80 |
+
|
81 |
+
For example, you can ask:
|
82 |
+
- What will the temperature be like in Paris?
|
83 |
+
- What will be the total rainfall in France in 2030?
|
84 |
+
- How frequent will extreme events be in Lyon?
|
85 |
+
|
86 |
+
**Example of indicators in the data**:
|
87 |
+
- Mean temperature (annual, winter, summer)
|
88 |
+
- Total precipitation (annual, winter, summer)
|
89 |
+
- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
|
90 |
+
|
91 |
+
⚠️ **Limitations**:
|
92 |
+
- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
|
93 |
+
- You can only ask about **locations in France**.
|
94 |
+
- If you specify a year, there may be **no data for that year for some models**.
|
95 |
+
- You **cannot compare two models**.
|
96 |
+
|
97 |
+
🛈 **Information**
|
98 |
+
Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
|
99 |
+
"""
|
@@ -1,47 +1,115 @@
|
|
1 |
-
from climateqa.engine.talk_to_data.
|
2 |
-
from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates
|
3 |
-
import sqlite3
|
4 |
-
import os
|
5 |
-
import pandas as pd
|
6 |
from climateqa.engine.llm import get_llm
|
7 |
import ast
|
8 |
|
9 |
-
|
10 |
-
|
11 |
llm = get_llm(provider="openai")
|
12 |
|
13 |
-
def ask_llm_to_add_table_names(sql_query, llm):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
|
15 |
return sql_with_table_names
|
16 |
|
17 |
-
def ask_llm_column_names(sql_query, llm):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
|
19 |
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
20 |
return columns_list
|
21 |
|
22 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
location = detect_location_with_openai(query)
|
26 |
-
if location:
|
27 |
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
return "", empty_df, empty_fig
|
|
|
1 |
+
from climateqa.engine.talk_to_data.workflow import drias_workflow
|
|
|
|
|
|
|
|
|
2 |
from climateqa.engine.llm import get_llm
|
3 |
import ast
|
4 |
|
|
|
|
|
5 |
llm = get_llm(provider="openai")
|
6 |
|
7 |
+
def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
|
8 |
+
"""Adds table names to the SQL query result rows using LLM.
|
9 |
+
|
10 |
+
This function modifies the SQL query to include the source table name in each row
|
11 |
+
of the result set, making it easier to track which data comes from which table.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
sql_query (str): The original SQL query to modify
|
15 |
+
llm: The language model instance to use for generating the modified query
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
str: The modified SQL query with table names included in the result rows
|
19 |
+
"""
|
20 |
sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
|
21 |
return sql_with_table_names
|
22 |
|
23 |
+
def ask_llm_column_names(sql_query: str, llm) -> list[str]:
|
24 |
+
"""Extracts column names from a SQL query using LLM.
|
25 |
+
|
26 |
+
This function analyzes a SQL query to identify which columns are being selected
|
27 |
+
in the result set.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
sql_query (str): The SQL query to analyze
|
31 |
+
llm: The language model instance to use for column extraction
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
list[str]: A list of column names being selected in the query
|
35 |
+
"""
|
36 |
columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
|
37 |
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
38 |
return columns_list
|
39 |
|
40 |
+
async def ask_drias(query: str, index_state: int = 0) -> tuple:
|
41 |
+
"""Main function to process a DRIAS query and return results.
|
42 |
+
|
43 |
+
This function orchestrates the DRIAS workflow, processing a user query to generate
|
44 |
+
SQL queries, dataframes, and visualizations. It handles multiple results and allows
|
45 |
+
pagination through them.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
query (str): The user's question about climate data
|
49 |
+
index_state (int, optional): The index of the result to return. Defaults to 0.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
tuple: A tuple containing:
|
53 |
+
- sql_query (str): The SQL query used
|
54 |
+
- dataframe (pd.DataFrame): The resulting data
|
55 |
+
- figure (Callable): Function to generate the visualization
|
56 |
+
- sql_queries (list): All generated SQL queries
|
57 |
+
- result_dataframes (list): All resulting dataframes
|
58 |
+
- figures (list): All figure generation functions
|
59 |
+
- index_state (int): Current result index
|
60 |
+
- table_list (list): List of table names used
|
61 |
+
- error (str): Error message if any
|
62 |
+
"""
|
63 |
+
final_state = await drias_workflow(query)
|
64 |
+
sql_queries = []
|
65 |
+
result_dataframes = []
|
66 |
+
figures = []
|
67 |
+
table_list = []
|
68 |
+
|
69 |
+
for plot_state in final_state['plot_states'].values():
|
70 |
+
for table_state in plot_state['table_states'].values():
|
71 |
+
if table_state['status'] == 'OK':
|
72 |
+
if 'table_name' in table_state:
|
73 |
+
table_list.append(' '.join(table_state['table_name'].capitalize().split('_')))
|
74 |
+
if 'sql_query' in table_state and table_state['sql_query'] is not None:
|
75 |
+
sql_queries.append(table_state['sql_query'])
|
76 |
+
|
77 |
+
if 'dataframe' in table_state and table_state['dataframe'] is not None:
|
78 |
+
result_dataframes.append(table_state['dataframe'])
|
79 |
+
if 'figure' in table_state and table_state['figure'] is not None:
|
80 |
+
figures.append(table_state['figure'])
|
81 |
+
|
82 |
+
if "error" in final_state and final_state["error"] != "":
|
83 |
+
return None, None, None, [], [], [], 0, final_state["error"]
|
84 |
+
|
85 |
+
sql_query = sql_queries[index_state]
|
86 |
+
dataframe = result_dataframes[index_state]
|
87 |
+
figure = figures[index_state](dataframe)
|
88 |
|
89 |
+
return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
|
|
|
|
|
90 |
|
91 |
+
# def ask_vanna(vn,db_vanna_path, query):
|
92 |
+
|
93 |
+
# try :
|
94 |
+
# location = detect_location_with_openai(query)
|
95 |
+
# if location:
|
96 |
+
|
97 |
+
# coords = loc2coords(location)
|
98 |
+
# user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
|
99 |
|
100 |
+
# relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
|
101 |
+
# coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
|
102 |
+
# user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
|
103 |
+
|
104 |
+
# sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
|
105 |
+
|
106 |
+
# return sql_query, result_dataframe, figure
|
107 |
+
# else :
|
108 |
+
# empty_df = pd.DataFrame()
|
109 |
+
# empty_fig = None
|
110 |
+
# return "", empty_df, empty_fig
|
111 |
+
# except Exception as e:
|
112 |
+
# print(f"Error: {e}")
|
113 |
+
# empty_df = pd.DataFrame()
|
114 |
+
# empty_fig = None
|
115 |
+
# return "", empty_df, empty_fig
|
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, TypedDict
|
2 |
+
from matplotlib.figure import figaspect
|
3 |
+
import pandas as pd
|
4 |
+
from plotly.graph_objects import Figure
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
import plotly.express as px
|
7 |
+
|
8 |
+
from climateqa.engine.talk_to_data.sql_query import (
|
9 |
+
indicator_for_given_year_query,
|
10 |
+
indicator_per_year_at_location_query,
|
11 |
+
)
|
12 |
+
from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class Plot(TypedDict):
|
18 |
+
"""Represents a plot configuration in the DRIAS system.
|
19 |
+
|
20 |
+
This class defines the structure for configuring different types of plots
|
21 |
+
that can be generated from climate data.
|
22 |
+
|
23 |
+
Attributes:
|
24 |
+
name (str): The name of the plot type
|
25 |
+
description (str): A description of what the plot shows
|
26 |
+
params (list[str]): List of required parameters for the plot
|
27 |
+
plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
|
28 |
+
sql_query (Callable[..., str]): Function to generate the SQL query for the plot
|
29 |
+
"""
|
30 |
+
name: str
|
31 |
+
description: str
|
32 |
+
params: list[str]
|
33 |
+
plot_function: Callable[..., Callable[..., Figure]]
|
34 |
+
sql_query: Callable[..., str]
|
35 |
+
|
36 |
+
|
37 |
+
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
38 |
+
"""Generates a function to plot indicator evolution over time at a location.
|
39 |
+
|
40 |
+
This function creates a line plot showing how a climate indicator changes
|
41 |
+
over time at a specific location. It handles temperature, precipitation,
|
42 |
+
and other climate indicators.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
params (dict): Dictionary containing:
|
46 |
+
- indicator_column (str): The column name for the indicator
|
47 |
+
- location (str): The location to plot
|
48 |
+
- model (str): The climate model to use
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
52 |
+
|
53 |
+
Example:
|
54 |
+
>>> plot_func = plot_indicator_evolution_at_location({
|
55 |
+
... 'indicator_column': 'mean_temperature',
|
56 |
+
... 'location': 'Paris',
|
57 |
+
... 'model': 'ALL'
|
58 |
+
... })
|
59 |
+
>>> fig = plot_func(df)
|
60 |
+
"""
|
61 |
+
indicator = params["indicator_column"]
|
62 |
+
location = params["location"]
|
63 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
64 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
65 |
+
|
66 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
67 |
+
"""Generates the actual plot from the data.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
df (pd.DataFrame): DataFrame containing the data to plot
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
Figure: A plotly Figure object showing the indicator evolution
|
74 |
+
"""
|
75 |
+
fig = go.Figure()
|
76 |
+
if df['model'].nunique() != 1:
|
77 |
+
df_avg = df.groupby("year", as_index=False)[indicator].mean()
|
78 |
+
|
79 |
+
# Transform to list to avoid pandas encoding
|
80 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
81 |
+
years = df_avg["year"].astype(int).tolist()
|
82 |
+
|
83 |
+
# Compute the 10-year rolling average
|
84 |
+
sliding_averages = (
|
85 |
+
df_avg[indicator]
|
86 |
+
.rolling(window=10, min_periods=1)
|
87 |
+
.mean()
|
88 |
+
.astype(float)
|
89 |
+
.tolist()
|
90 |
+
)
|
91 |
+
model_label = "Model Average"
|
92 |
+
|
93 |
+
else:
|
94 |
+
df_model = df
|
95 |
+
|
96 |
+
# Transform to list to avoid pandas encoding
|
97 |
+
indicators = df_model[indicator].astype(float).tolist()
|
98 |
+
years = df_model["year"].astype(int).tolist()
|
99 |
+
|
100 |
+
# Compute the 10-year rolling average
|
101 |
+
sliding_averages = (
|
102 |
+
df_model[indicator]
|
103 |
+
.rolling(window=10, min_periods=1)
|
104 |
+
.mean()
|
105 |
+
.astype(float)
|
106 |
+
.tolist()
|
107 |
+
)
|
108 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
109 |
+
|
110 |
+
|
111 |
+
# Indicator per year plot
|
112 |
+
fig.add_scatter(
|
113 |
+
x=years,
|
114 |
+
y=indicators,
|
115 |
+
name=f"Yearly {indicator_label}",
|
116 |
+
mode="lines",
|
117 |
+
marker=dict(color="#1f77b4"),
|
118 |
+
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
119 |
+
)
|
120 |
+
|
121 |
+
# Sliding average dashed line
|
122 |
+
fig.add_scatter(
|
123 |
+
x=years,
|
124 |
+
y=sliding_averages,
|
125 |
+
mode="lines",
|
126 |
+
name="10 years rolling average",
|
127 |
+
line=dict(dash="dash"),
|
128 |
+
marker=dict(color="#d62728"),
|
129 |
+
hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
130 |
+
)
|
131 |
+
fig.update_layout(
|
132 |
+
title=f"Plot of {indicator_label} in {location} ({model_label})",
|
133 |
+
xaxis_title="Year",
|
134 |
+
yaxis_title=f"{indicator_label} ({unit})",
|
135 |
+
template="plotly_white",
|
136 |
+
)
|
137 |
+
return fig
|
138 |
+
|
139 |
+
return plot_data
|
140 |
+
|
141 |
+
|
142 |
+
indicator_evolution_at_location: Plot = {
|
143 |
+
"name": "Indicator evolution at location",
|
144 |
+
"description": "Plot an evolution of the indicator at a certain location",
|
145 |
+
"params": ["indicator_column", "location", "model"],
|
146 |
+
"plot_function": plot_indicator_evolution_at_location,
|
147 |
+
"sql_query": indicator_per_year_at_location_query,
|
148 |
+
}
|
149 |
+
|
150 |
+
|
151 |
+
def plot_indicator_number_of_days_per_year_at_location(
|
152 |
+
params: dict,
|
153 |
+
) -> Callable[..., Figure]:
|
154 |
+
"""Generates a function to plot the number of days per year for an indicator.
|
155 |
+
|
156 |
+
This function creates a bar chart showing the frequency of certain climate
|
157 |
+
events (like days above a temperature threshold) per year at a specific location.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
params (dict): Dictionary containing:
|
161 |
+
- indicator_column (str): The column name for the indicator
|
162 |
+
- location (str): The location to plot
|
163 |
+
- model (str): The climate model to use
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
167 |
+
"""
|
168 |
+
indicator = params["indicator_column"]
|
169 |
+
location = params["location"]
|
170 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
171 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
172 |
+
|
173 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
174 |
+
"""Generate the figure thanks to the dataframe
|
175 |
+
|
176 |
+
Args:
|
177 |
+
df (pd.DataFrame): pandas dataframe with the required data
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
Figure: Plotly figure
|
181 |
+
"""
|
182 |
+
fig = go.Figure()
|
183 |
+
if df['model'].nunique() != 1:
|
184 |
+
df_avg = df.groupby("year", as_index=False)[indicator].mean()
|
185 |
+
|
186 |
+
# Transform to list to avoid pandas encoding
|
187 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
188 |
+
years = df_avg["year"].astype(int).tolist()
|
189 |
+
model_label = "Model Average"
|
190 |
+
|
191 |
+
else:
|
192 |
+
df_model = df
|
193 |
+
# Transform to list to avoid pandas encoding
|
194 |
+
indicators = df_model[indicator].astype(float).tolist()
|
195 |
+
years = df_model["year"].astype(int).tolist()
|
196 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
197 |
+
|
198 |
+
|
199 |
+
# Bar plot
|
200 |
+
fig.add_trace(
|
201 |
+
go.Bar(
|
202 |
+
x=years,
|
203 |
+
y=indicators,
|
204 |
+
width=0.5,
|
205 |
+
marker=dict(color="#1f77b4"),
|
206 |
+
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
207 |
+
)
|
208 |
+
)
|
209 |
+
|
210 |
+
fig.update_layout(
|
211 |
+
title=f"{indicator_label} in {location} ({model_label})",
|
212 |
+
xaxis_title="Year",
|
213 |
+
yaxis_title=f"{indicator_label} ({unit})",
|
214 |
+
yaxis=dict(range=[0, max(indicators)]),
|
215 |
+
bargap=0.5,
|
216 |
+
template="plotly_white",
|
217 |
+
)
|
218 |
+
|
219 |
+
return fig
|
220 |
+
|
221 |
+
return plot_data
|
222 |
+
|
223 |
+
|
224 |
+
indicator_number_of_days_per_year_at_location: Plot = {
|
225 |
+
"name": "Indicator number of days per year at location",
|
226 |
+
"description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
|
227 |
+
"params": ["indicator_column", "location", "model"],
|
228 |
+
"plot_function": plot_indicator_number_of_days_per_year_at_location,
|
229 |
+
"sql_query": indicator_per_year_at_location_query,
|
230 |
+
}
|
231 |
+
|
232 |
+
|
233 |
+
def plot_distribution_of_indicator_for_given_year(
|
234 |
+
params: dict,
|
235 |
+
) -> Callable[..., Figure]:
|
236 |
+
"""Generates a function to plot the distribution of an indicator for a year.
|
237 |
+
|
238 |
+
This function creates a histogram showing the distribution of a climate
|
239 |
+
indicator across different locations for a specific year.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
params (dict): Dictionary containing:
|
243 |
+
- indicator_column (str): The column name for the indicator
|
244 |
+
- year (str): The year to plot
|
245 |
+
- model (str): The climate model to use
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
249 |
+
"""
|
250 |
+
indicator = params["indicator_column"]
|
251 |
+
year = params["year"]
|
252 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
253 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
254 |
+
|
255 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
256 |
+
"""Generate the figure thanks to the dataframe
|
257 |
+
|
258 |
+
Args:
|
259 |
+
df (pd.DataFrame): pandas dataframe with the required data
|
260 |
+
|
261 |
+
Returns:
|
262 |
+
Figure: Plotly figure
|
263 |
+
"""
|
264 |
+
fig = go.Figure()
|
265 |
+
if df['model'].nunique() != 1:
|
266 |
+
df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
|
267 |
+
indicator
|
268 |
+
].mean()
|
269 |
+
|
270 |
+
# Transform to list to avoid pandas encoding
|
271 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
272 |
+
model_label = "Model Average"
|
273 |
+
|
274 |
+
else:
|
275 |
+
df_model = df
|
276 |
+
|
277 |
+
# Transform to list to avoid pandas encoding
|
278 |
+
indicators = df_model[indicator].astype(float).tolist()
|
279 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
280 |
+
|
281 |
+
|
282 |
+
fig.add_trace(
|
283 |
+
go.Histogram(
|
284 |
+
x=indicators,
|
285 |
+
opacity=0.8,
|
286 |
+
histnorm="percent",
|
287 |
+
marker=dict(color="#1f77b4"),
|
288 |
+
hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
|
289 |
+
)
|
290 |
+
)
|
291 |
+
|
292 |
+
fig.update_layout(
|
293 |
+
title=f"Distribution of {indicator_label} in {year} ({model_label})",
|
294 |
+
xaxis_title=f"{indicator_label} ({unit})",
|
295 |
+
yaxis_title="Frequency (%)",
|
296 |
+
plot_bgcolor="rgba(0, 0, 0, 0)",
|
297 |
+
showlegend=False,
|
298 |
+
)
|
299 |
+
|
300 |
+
return fig
|
301 |
+
|
302 |
+
return plot_data
|
303 |
+
|
304 |
+
|
305 |
+
distribution_of_indicator_for_given_year: Plot = {
|
306 |
+
"name": "Distribution of an indicator for a given year",
|
307 |
+
"description": "Plot an histogram of the distribution for a given year of the values of an indicator",
|
308 |
+
"params": ["indicator_column", "model", "year"],
|
309 |
+
"plot_function": plot_distribution_of_indicator_for_given_year,
|
310 |
+
"sql_query": indicator_for_given_year_query,
|
311 |
+
}
|
312 |
+
|
313 |
+
|
314 |
+
def plot_map_of_france_of_indicator_for_given_year(
|
315 |
+
params: dict,
|
316 |
+
) -> Callable[..., Figure]:
|
317 |
+
"""Generates a function to plot a map of France for an indicator.
|
318 |
+
|
319 |
+
This function creates a choropleth map of France showing the spatial
|
320 |
+
distribution of a climate indicator for a specific year.
|
321 |
+
|
322 |
+
Args:
|
323 |
+
params (dict): Dictionary containing:
|
324 |
+
- indicator_column (str): The column name for the indicator
|
325 |
+
- year (str): The year to plot
|
326 |
+
- model (str): The climate model to use
|
327 |
+
|
328 |
+
Returns:
|
329 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
330 |
+
"""
|
331 |
+
indicator = params["indicator_column"]
|
332 |
+
year = params["year"]
|
333 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
334 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
335 |
+
|
336 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
337 |
+
fig = go.Figure()
|
338 |
+
if df['model'].nunique() != 1:
|
339 |
+
df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
|
340 |
+
indicator
|
341 |
+
].mean()
|
342 |
+
|
343 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
344 |
+
latitudes = df_avg["latitude"].astype(float).tolist()
|
345 |
+
longitudes = df_avg["longitude"].astype(float).tolist()
|
346 |
+
model_label = "Model Average"
|
347 |
+
|
348 |
+
else:
|
349 |
+
df_model = df
|
350 |
+
|
351 |
+
# Transform to list to avoid pandas encoding
|
352 |
+
indicators = df_model[indicator].astype(float).tolist()
|
353 |
+
latitudes = df_model["latitude"].astype(float).tolist()
|
354 |
+
longitudes = df_model["longitude"].astype(float).tolist()
|
355 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
356 |
+
|
357 |
+
|
358 |
+
fig.add_trace(
|
359 |
+
go.Scattermapbox(
|
360 |
+
lat=latitudes,
|
361 |
+
lon=longitudes,
|
362 |
+
mode="markers",
|
363 |
+
marker=dict(
|
364 |
+
size=10,
|
365 |
+
color=indicators, # Color mapped to values
|
366 |
+
colorscale="Turbo", # Color scale (can be 'Plasma', 'Jet', etc.)
|
367 |
+
cmin=min(indicators), # Minimum color range
|
368 |
+
cmax=max(indicators), # Maximum color range
|
369 |
+
showscale=True, # Show colorbar
|
370 |
+
),
|
371 |
+
text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
|
372 |
+
hoverinfo="text" # Only show the custom text on hover
|
373 |
+
)
|
374 |
+
)
|
375 |
+
|
376 |
+
fig.update_layout(
|
377 |
+
mapbox_style="open-street-map", # Use OpenStreetMap
|
378 |
+
mapbox_zoom=3,
|
379 |
+
mapbox_center={"lat": 46.6, "lon": 2.0},
|
380 |
+
coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
|
381 |
+
title=f"{indicator_label} in {year} in France ({model_label}) " # Title
|
382 |
+
)
|
383 |
+
return fig
|
384 |
+
|
385 |
+
return plot_data
|
386 |
+
|
387 |
+
|
388 |
+
map_of_france_of_indicator_for_given_year: Plot = {
|
389 |
+
"name": "Map of France of an indicator for a given year",
|
390 |
+
"description": "Heatmap on the map of France of the values of an in indicator for a given year",
|
391 |
+
"params": ["indicator_column", "year", "model"],
|
392 |
+
"plot_function": plot_map_of_france_of_indicator_for_given_year,
|
393 |
+
"sql_query": indicator_for_given_year_query,
|
394 |
+
}
|
395 |
+
|
396 |
+
|
397 |
+
PLOTS = [
|
398 |
+
indicator_evolution_at_location,
|
399 |
+
indicator_number_of_days_per_year_at_location,
|
400 |
+
distribution_of_indicator_for_given_year,
|
401 |
+
map_of_france_of_indicator_for_given_year,
|
402 |
+
]
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from concurrent.futures import ThreadPoolExecutor
|
3 |
+
from typing import TypedDict
|
4 |
+
import duckdb
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
|
8 |
+
"""Executes a SQL query on the DRIAS database and returns the results.
|
9 |
+
|
10 |
+
This function connects to the DuckDB database containing DRIAS climate data
|
11 |
+
and executes the provided SQL query. It handles the database connection and
|
12 |
+
returns the results as a pandas DataFrame.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
sql_query (str): The SQL query to execute
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
pd.DataFrame: A DataFrame containing the query results
|
19 |
+
|
20 |
+
Raises:
|
21 |
+
duckdb.Error: If there is an error executing the SQL query
|
22 |
+
"""
|
23 |
+
def _execute_query():
|
24 |
+
# Execute the query
|
25 |
+
results = duckdb.sql(sql_query)
|
26 |
+
# return fetched data
|
27 |
+
return results.fetchdf()
|
28 |
+
|
29 |
+
# Run the query in a thread pool to avoid blocking
|
30 |
+
loop = asyncio.get_event_loop()
|
31 |
+
with ThreadPoolExecutor() as executor:
|
32 |
+
return await loop.run_in_executor(executor, _execute_query)
|
33 |
+
|
34 |
+
|
35 |
+
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
|
36 |
+
"""Parameters for querying an indicator's values over time at a location.
|
37 |
+
|
38 |
+
This class defines the parameters needed to query climate indicator data
|
39 |
+
for a specific location over multiple years.
|
40 |
+
|
41 |
+
Attributes:
|
42 |
+
indicator_column (str): The column name for the climate indicator
|
43 |
+
latitude (str): The latitude coordinate of the location
|
44 |
+
longitude (str): The longitude coordinate of the location
|
45 |
+
model (str): The climate model to use (optional)
|
46 |
+
"""
|
47 |
+
indicator_column: str
|
48 |
+
latitude: str
|
49 |
+
longitude: str
|
50 |
+
model: str
|
51 |
+
|
52 |
+
|
53 |
+
def indicator_per_year_at_location_query(
|
54 |
+
table: str, params: IndicatorPerYearAtLocationQueryParams
|
55 |
+
) -> str:
|
56 |
+
"""SQL Query to get the evolution of an indicator per year at a certain location
|
57 |
+
|
58 |
+
Args:
|
59 |
+
table (str): sql table of the indicator
|
60 |
+
params (IndicatorPerYearAtLocationQueryParams) : dictionary with the required params for the query
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
str: the sql query
|
64 |
+
"""
|
65 |
+
indicator_column = params.get("indicator_column")
|
66 |
+
latitude = params.get("latitude")
|
67 |
+
longitude = params.get("longitude")
|
68 |
+
|
69 |
+
if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
|
70 |
+
return ""
|
71 |
+
|
72 |
+
table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
|
73 |
+
|
74 |
+
sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
|
75 |
+
|
76 |
+
return sql_query
|
77 |
+
|
78 |
+
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
|
79 |
+
"""Parameters for querying an indicator's values across locations for a year.
|
80 |
+
|
81 |
+
This class defines the parameters needed to query climate indicator data
|
82 |
+
across different locations for a specific year.
|
83 |
+
|
84 |
+
Attributes:
|
85 |
+
indicator_column (str): The column name for the climate indicator
|
86 |
+
year (str): The year to query
|
87 |
+
model (str): The climate model to use (optional)
|
88 |
+
"""
|
89 |
+
indicator_column: str
|
90 |
+
year: str
|
91 |
+
model: str
|
92 |
+
|
93 |
+
def indicator_for_given_year_query(
|
94 |
+
table:str, params: IndicatorForGivenYearQueryParams
|
95 |
+
) -> str:
|
96 |
+
"""SQL Query to get the values of an indicator with their latitudes, longitudes and models for a given year
|
97 |
+
|
98 |
+
Args:
|
99 |
+
table (str): sql table of the indicator
|
100 |
+
params (IndicatorForGivenYearQueryParams): dictionarry with the required params for the query
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
str: the sql query
|
104 |
+
"""
|
105 |
+
indicator_column = params.get("indicator_column")
|
106 |
+
year = params.get('year')
|
107 |
+
if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
|
108 |
+
return ""
|
109 |
+
|
110 |
+
table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
|
111 |
+
|
112 |
+
sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
|
113 |
+
return sql_query
|
@@ -1,12 +1,15 @@
|
|
1 |
import re
|
2 |
-
import
|
3 |
-
import
|
4 |
from geopy.geocoders import Nominatim
|
5 |
-
import sqlite3
|
6 |
import ast
|
7 |
from climateqa.engine.llm import get_llm
|
|
|
|
|
|
|
8 |
|
9 |
-
|
|
|
10 |
"""
|
11 |
Detects locations in a sentence using OpenAI's API via LangChain.
|
12 |
"""
|
@@ -19,74 +22,260 @@ def detect_location_with_openai(sentence):
|
|
19 |
Sentence: "{sentence}"
|
20 |
"""
|
21 |
|
22 |
-
response = llm.
|
23 |
location_list = ast.literal_eval(response.content.strip("```python\n").strip())
|
24 |
if location_list:
|
25 |
return location_list[0]
|
26 |
else:
|
27 |
return ""
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
|
31 |
matches = re.findall(pattern, sql_query)
|
32 |
return matches
|
33 |
|
34 |
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
geolocator = Nominatim(user_agent="city_to_latlong")
|
38 |
-
|
39 |
-
return (
|
40 |
|
41 |
|
42 |
-
def coords2loc(coords
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
geolocator = Nominatim(user_agent="coords_to_city")
|
44 |
try:
|
45 |
location = geolocator.reverse(coords)
|
46 |
return location.address
|
47 |
except Exception as e:
|
48 |
print(f"Error: {e}")
|
49 |
-
return "Unknown Location"
|
50 |
|
51 |
|
52 |
-
def nearestNeighbourSQL(
|
53 |
-
conn = sqlite3.connect(db)
|
54 |
long = round(location[1], 3)
|
55 |
lat = round(location[0], 3)
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
"
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
prompt = (
|
79 |
-
f"You are helping to build a
|
80 |
-
f"
|
81 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
)
|
83 |
-
table_names = ast.literal_eval(llm.invoke(prompt).content.strip("```python\n").strip())
|
84 |
return table_names
|
85 |
|
|
|
86 |
def replace_coordonates(coords, query, coords_tables):
|
87 |
n = query.count(str(coords[0]))
|
88 |
|
89 |
for i in range(n):
|
90 |
-
query = query.replace(str(coords[0]), str(coords_tables[i][0]),1)
|
91 |
-
query = query.replace(str(coords[1]), str(coords_tables[i][1]),1)
|
92 |
-
return query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import re
|
2 |
+
from typing import Annotated, TypedDict
|
3 |
+
import duckdb
|
4 |
from geopy.geocoders import Nominatim
|
|
|
5 |
import ast
|
6 |
from climateqa.engine.llm import get_llm
|
7 |
+
from climateqa.engine.talk_to_data.config import DRIAS_TABLES
|
8 |
+
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
9 |
+
from langchain_core.prompts import ChatPromptTemplate
|
10 |
|
11 |
+
|
12 |
+
async def detect_location_with_openai(sentence):
|
13 |
"""
|
14 |
Detects locations in a sentence using OpenAI's API via LangChain.
|
15 |
"""
|
|
|
22 |
Sentence: "{sentence}"
|
23 |
"""
|
24 |
|
25 |
+
response = await llm.ainvoke(prompt)
|
26 |
location_list = ast.literal_eval(response.content.strip("```python\n").strip())
|
27 |
if location_list:
|
28 |
return location_list[0]
|
29 |
else:
|
30 |
return ""
|
31 |
|
32 |
+
class ArrayOutput(TypedDict):
|
33 |
+
"""Represents the output of a function that returns an array.
|
34 |
+
|
35 |
+
This class is used to type-hint functions that return arrays,
|
36 |
+
ensuring consistent return types across the codebase.
|
37 |
+
|
38 |
+
Attributes:
|
39 |
+
array (str): A syntactically valid Python array string
|
40 |
+
"""
|
41 |
+
array: Annotated[str, "Syntactically valid python array."]
|
42 |
+
|
43 |
+
async def detect_year_with_openai(sentence: str) -> str:
|
44 |
+
"""
|
45 |
+
Detects years in a sentence using OpenAI's API via LangChain.
|
46 |
+
"""
|
47 |
+
llm = get_llm()
|
48 |
+
|
49 |
+
prompt = """
|
50 |
+
Extract all years mentioned in the following sentence.
|
51 |
+
Return the result as a Python list. If no year are mentioned, return an empty list.
|
52 |
+
|
53 |
+
Sentence: "{sentence}"
|
54 |
+
"""
|
55 |
+
|
56 |
+
prompt = ChatPromptTemplate.from_template(prompt)
|
57 |
+
structured_llm = llm.with_structured_output(ArrayOutput)
|
58 |
+
chain = prompt | structured_llm
|
59 |
+
response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
|
60 |
+
years_list = eval(response['array'])
|
61 |
+
if len(years_list) > 0:
|
62 |
+
return years_list[0]
|
63 |
+
else:
|
64 |
+
return ""
|
65 |
+
|
66 |
+
|
67 |
+
def detectTable(sql_query: str) -> list[str]:
|
68 |
+
"""Extracts table names from a SQL query.
|
69 |
+
|
70 |
+
This function uses regular expressions to find all table names
|
71 |
+
referenced in a SQL query's FROM clause.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
sql_query (str): The SQL query to analyze
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
list[str]: A list of table names found in the query
|
78 |
+
|
79 |
+
Example:
|
80 |
+
>>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
|
81 |
+
['temperature_data']
|
82 |
+
"""
|
83 |
pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
|
84 |
matches = re.findall(pattern, sql_query)
|
85 |
return matches
|
86 |
|
87 |
|
88 |
+
def loc2coords(location: str) -> tuple[float, float]:
|
89 |
+
"""Converts a location name to geographic coordinates.
|
90 |
+
|
91 |
+
This function uses the Nominatim geocoding service to convert
|
92 |
+
a location name (e.g., city name) to its latitude and longitude.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
location (str): The name of the location to geocode
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
tuple[float, float]: A tuple containing (latitude, longitude)
|
99 |
+
|
100 |
+
Raises:
|
101 |
+
AttributeError: If the location cannot be found
|
102 |
+
"""
|
103 |
geolocator = Nominatim(user_agent="city_to_latlong")
|
104 |
+
coords = geolocator.geocode(location)
|
105 |
+
return (coords.latitude, coords.longitude)
|
106 |
|
107 |
|
108 |
+
def coords2loc(coords: tuple[float, float]) -> str:
|
109 |
+
"""Converts geographic coordinates to a location name.
|
110 |
+
|
111 |
+
This function uses the Nominatim reverse geocoding service to convert
|
112 |
+
latitude and longitude coordinates to a human-readable location name.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
coords (tuple[float, float]): A tuple containing (latitude, longitude)
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
str: The address of the location, or "Unknown Location" if not found
|
119 |
+
|
120 |
+
Example:
|
121 |
+
>>> coords2loc((48.8566, 2.3522))
|
122 |
+
'Paris, France'
|
123 |
+
"""
|
124 |
geolocator = Nominatim(user_agent="coords_to_city")
|
125 |
try:
|
126 |
location = geolocator.reverse(coords)
|
127 |
return location.address
|
128 |
except Exception as e:
|
129 |
print(f"Error: {e}")
|
130 |
+
return "Unknown Location"
|
131 |
|
132 |
|
133 |
+
def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
|
|
|
134 |
long = round(location[1], 3)
|
135 |
lat = round(location[0], 3)
|
136 |
+
|
137 |
+
table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
|
138 |
+
|
139 |
+
results = duckdb.sql(
|
140 |
+
f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
|
141 |
+
).fetchdf()
|
142 |
+
|
143 |
+
if len(results) == 0:
|
144 |
+
return "", ""
|
145 |
+
# cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
|
146 |
+
return results['latitude'].iloc[0], results['longitude'].iloc[0]
|
147 |
+
|
148 |
+
|
149 |
+
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
|
150 |
+
"""Identifies relevant tables for a plot based on user input.
|
151 |
+
|
152 |
+
This function uses an LLM to analyze the user's question and the plot
|
153 |
+
description to determine which tables in the DRIAS database would be
|
154 |
+
most relevant for generating the requested visualization.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
user_question (str): The user's question about climate data
|
158 |
+
plot (Plot): The plot configuration object
|
159 |
+
llm: The language model instance to use for analysis
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
list[str]: A list of table names that are relevant for the plot
|
163 |
+
|
164 |
+
Example:
|
165 |
+
>>> detect_relevant_tables(
|
166 |
+
... "What will the temperature be like in Paris?",
|
167 |
+
... indicator_evolution_at_location,
|
168 |
+
... llm
|
169 |
+
... )
|
170 |
+
['mean_annual_temperature', 'mean_summer_temperature']
|
171 |
+
"""
|
172 |
+
# Get all table names
|
173 |
+
table_names_list = DRIAS_TABLES
|
174 |
+
|
175 |
prompt = (
|
176 |
+
f"You are helping to build a plot following this description : {plot['description']}."
|
177 |
+
f"You are given a list of tables and a user question."
|
178 |
+
f"Based on the description of the plot, which table are appropriate for that kind of plot."
|
179 |
+
f"Write the 3 most relevant tables to use. Answer only a python list of table name."
|
180 |
+
f"### List of tables : {table_names_list}"
|
181 |
+
f"### User question : {user_question}"
|
182 |
+
f"### List of table name : "
|
183 |
+
)
|
184 |
+
|
185 |
+
table_names = ast.literal_eval(
|
186 |
+
(await llm.ainvoke(prompt)).content.strip("```python\n").strip()
|
187 |
)
|
|
|
188 |
return table_names
|
189 |
|
190 |
+
|
191 |
def replace_coordonates(coords, query, coords_tables):
|
192 |
n = query.count(str(coords[0]))
|
193 |
|
194 |
for i in range(n):
|
195 |
+
query = query.replace(str(coords[0]), str(coords_tables[i][0]), 1)
|
196 |
+
query = query.replace(str(coords[1]), str(coords_tables[i][1]), 1)
|
197 |
+
return query
|
198 |
+
|
199 |
+
|
200 |
+
async def detect_relevant_plots(user_question: str, llm):
|
201 |
+
plots_description = ""
|
202 |
+
for plot in PLOTS:
|
203 |
+
plots_description += "Name: " + plot["name"]
|
204 |
+
plots_description += " - Description: " + plot["description"] + "\n"
|
205 |
+
|
206 |
+
prompt = (
|
207 |
+
f"You are helping to answer a quesiton with insightful visualizations."
|
208 |
+
f"You are given an user question and a list of plots with their name and description."
|
209 |
+
f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
|
210 |
+
f"Write the most relevant tables to use. Answer only a python list of plot name."
|
211 |
+
f"### Descriptions of the plots : {plots_description}"
|
212 |
+
f"### User question : {user_question}"
|
213 |
+
f"### Name of the plot : "
|
214 |
+
)
|
215 |
+
# prompt = (
|
216 |
+
# f"You are helping to answer a question with insightful visualizations. "
|
217 |
+
# f"Given a list of plots with their name and description: "
|
218 |
+
# f"{plots_description} "
|
219 |
+
# f"The user question is: {user_question}. "
|
220 |
+
# f"Choose the most relevant plots to answer the question. "
|
221 |
+
# f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
|
222 |
+
# f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
|
223 |
+
# )
|
224 |
+
|
225 |
+
plot_names = ast.literal_eval(
|
226 |
+
(await llm.ainvoke(prompt)).content.strip("```python\n").strip()
|
227 |
+
)
|
228 |
+
return plot_names
|
229 |
+
|
230 |
+
|
231 |
+
# Next Version
|
232 |
+
# class QueryOutput(TypedDict):
|
233 |
+
# """Generated SQL query."""
|
234 |
+
|
235 |
+
# query: Annotated[str, ..., "Syntactically valid SQL query."]
|
236 |
+
|
237 |
+
|
238 |
+
# class PlotlyCodeOutput(TypedDict):
|
239 |
+
# """Generated Plotly code"""
|
240 |
+
|
241 |
+
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
242 |
+
# def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
|
243 |
+
# """Generate SQL query to fetch information."""
|
244 |
+
# prompt_params = {
|
245 |
+
# "dialect": db.dialect,
|
246 |
+
# "table_info": db.get_table_info(),
|
247 |
+
# "input": user_input,
|
248 |
+
# "relevant_tables": relevant_tables,
|
249 |
+
# "model": "ALADIN63_CNRM-CM5",
|
250 |
+
# }
|
251 |
+
|
252 |
+
# prompt = ChatPromptTemplate.from_template(query_prompt_template)
|
253 |
+
# structured_llm = llm.with_structured_output(QueryOutput)
|
254 |
+
# chain = prompt | structured_llm
|
255 |
+
# result = chain.invoke(prompt_params)
|
256 |
+
|
257 |
+
# return result["query"]
|
258 |
+
|
259 |
+
|
260 |
+
# def fetch_data_from_sql_query(db: str, sql_query: str):
|
261 |
+
# conn = sqlite3.connect(db)
|
262 |
+
# cursor = conn.cursor()
|
263 |
+
# cursor.execute(sql_query)
|
264 |
+
# column_names = [desc[0] for desc in cursor.description]
|
265 |
+
# values = cursor.fetchall()
|
266 |
+
# return {"column_names": column_names, "data": values}
|
267 |
+
|
268 |
+
|
269 |
+
# def generate_chart_code(user_input: str, sql_query: list[str], llm):
|
270 |
+
# """ "Generate plotly python code for the chart based on the sql query and the user question"""
|
271 |
+
|
272 |
+
# class PlotlyCodeOutput(TypedDict):
|
273 |
+
# """Generated Plotly code"""
|
274 |
+
|
275 |
+
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
276 |
+
|
277 |
+
# prompt = ChatPromptTemplate.from_template(plot_prompt_template)
|
278 |
+
# structured_llm = llm.with_structured_output(PlotlyCodeOutput)
|
279 |
+
# chain = prompt | structured_llm
|
280 |
+
# result = chain.invoke({"input": user_input, "sql_query": sql_query})
|
281 |
+
# return result["code"]
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from typing import Any, Callable, NotRequired, TypedDict
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
from plotly.graph_objects import Figure
|
7 |
+
from climateqa.engine.llm import get_llm
|
8 |
+
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
|
9 |
+
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
10 |
+
from climateqa.engine.talk_to_data.sql_query import execute_sql_query
|
11 |
+
from climateqa.engine.talk_to_data.utils import (
|
12 |
+
detect_relevant_plots,
|
13 |
+
detect_year_with_openai,
|
14 |
+
loc2coords,
|
15 |
+
detect_location_with_openai,
|
16 |
+
nearestNeighbourSQL,
|
17 |
+
detect_relevant_tables,
|
18 |
+
)
|
19 |
+
|
20 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
21 |
+
|
22 |
+
class TableState(TypedDict):
|
23 |
+
"""Represents the state of a table in the DRIAS workflow.
|
24 |
+
|
25 |
+
This class defines the structure for tracking the state of a table during the
|
26 |
+
data processing workflow, including its name, parameters, SQL query, and results.
|
27 |
+
|
28 |
+
Attributes:
|
29 |
+
table_name (str): The name of the table in the database
|
30 |
+
params (dict[str, Any]): Parameters used for querying the table
|
31 |
+
sql_query (str, optional): The SQL query used to fetch data
|
32 |
+
dataframe (pd.DataFrame | None, optional): The resulting data
|
33 |
+
figure (Callable[..., Figure], optional): Function to generate visualization
|
34 |
+
status (str): The current status of the table processing ('OK' or 'ERROR')
|
35 |
+
"""
|
36 |
+
table_name: str
|
37 |
+
params: dict[str, Any]
|
38 |
+
sql_query: NotRequired[str]
|
39 |
+
dataframe: NotRequired[pd.DataFrame | None]
|
40 |
+
figure: NotRequired[Callable[..., Figure]]
|
41 |
+
status: str
|
42 |
+
|
43 |
+
class PlotState(TypedDict):
|
44 |
+
"""Represents the state of a plot in the DRIAS workflow.
|
45 |
+
|
46 |
+
This class defines the structure for tracking the state of a plot during the
|
47 |
+
data processing workflow, including its name and associated tables.
|
48 |
+
|
49 |
+
Attributes:
|
50 |
+
plot_name (str): The name of the plot
|
51 |
+
tables (list[str]): List of tables used in the plot
|
52 |
+
table_states (dict[str, TableState]): States of the tables used in the plot
|
53 |
+
"""
|
54 |
+
plot_name: str
|
55 |
+
tables: list[str]
|
56 |
+
table_states: dict[str, TableState]
|
57 |
+
|
58 |
+
class State(TypedDict):
|
59 |
+
user_input: str
|
60 |
+
plots: list[str]
|
61 |
+
plot_states: dict[str, PlotState]
|
62 |
+
error: NotRequired[str]
|
63 |
+
|
64 |
+
async def drias_workflow(user_input: str) -> State:
|
65 |
+
"""Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
|
66 |
+
|
67 |
+
Args:
|
68 |
+
user_input (str): initial user input
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
State: Final state with all the results
|
72 |
+
"""
|
73 |
+
state: State = {
|
74 |
+
'user_input': user_input,
|
75 |
+
'plots': [],
|
76 |
+
'plot_states': {}
|
77 |
+
}
|
78 |
+
|
79 |
+
llm = get_llm(provider="openai")
|
80 |
+
|
81 |
+
plots = await find_relevant_plots(state, llm)
|
82 |
+
state['plots'] = plots
|
83 |
+
|
84 |
+
if not state['plots']:
|
85 |
+
state['error'] = 'There is no plot to answer to the question'
|
86 |
+
return state
|
87 |
+
|
88 |
+
have_relevant_table = False
|
89 |
+
have_sql_query = False
|
90 |
+
have_dataframe = False
|
91 |
+
for plot_name in state['plots']:
|
92 |
+
|
93 |
+
plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
|
94 |
+
if plot is None:
|
95 |
+
continue
|
96 |
+
|
97 |
+
plot_state: PlotState = {
|
98 |
+
'plot_name': plot_name,
|
99 |
+
'tables': [],
|
100 |
+
'table_states': {}
|
101 |
+
}
|
102 |
+
|
103 |
+
plot_state['plot_name'] = plot_name
|
104 |
+
|
105 |
+
relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
|
106 |
+
if len(relevant_tables) > 0 :
|
107 |
+
have_relevant_table = True
|
108 |
+
|
109 |
+
plot_state['tables'] = relevant_tables
|
110 |
+
|
111 |
+
params = {}
|
112 |
+
for param_name in plot['params']:
|
113 |
+
param = await find_param(state, param_name, relevant_tables[0])
|
114 |
+
if param:
|
115 |
+
params.update(param)
|
116 |
+
|
117 |
+
for n, table in enumerate(plot_state['tables']):
|
118 |
+
if n > 2:
|
119 |
+
break
|
120 |
+
|
121 |
+
table_state: TableState = {
|
122 |
+
'table_name': table,
|
123 |
+
'params': params,
|
124 |
+
'status': 'OK'
|
125 |
+
}
|
126 |
+
|
127 |
+
table_state["params"]['indicator_column'] = find_indicator_column(table)
|
128 |
+
|
129 |
+
sql_query = plot['sql_query'](table, table_state['params'])
|
130 |
+
|
131 |
+
if sql_query == "":
|
132 |
+
table_state['status'] = 'ERROR'
|
133 |
+
continue
|
134 |
+
else :
|
135 |
+
have_sql_query = True
|
136 |
+
|
137 |
+
table_state['sql_query'] = sql_query
|
138 |
+
df = await execute_sql_query(sql_query)
|
139 |
+
|
140 |
+
if len(df) > 0:
|
141 |
+
have_dataframe = True
|
142 |
+
|
143 |
+
figure = plot['plot_function'](table_state['params'])
|
144 |
+
table_state['dataframe'] = df
|
145 |
+
table_state['figure'] = figure
|
146 |
+
plot_state['table_states'][table] = table_state
|
147 |
+
|
148 |
+
state['plot_states'][plot_name] = plot_state
|
149 |
+
|
150 |
+
if not have_relevant_table:
|
151 |
+
state['error'] = "There is no relevant table in the our database to answer your question"
|
152 |
+
elif not have_sql_query:
|
153 |
+
state['error'] = "There is no relevant sql query on our database that can help to answer your question"
|
154 |
+
elif not have_dataframe:
|
155 |
+
state['error'] = "There is no data in our table that can answer to your question"
|
156 |
+
|
157 |
+
return state
|
158 |
+
|
159 |
+
async def find_relevant_plots(state: State, llm) -> list[str]:
|
160 |
+
print("---- Find relevant plots ----")
|
161 |
+
relevant_plots = await detect_relevant_plots(state['user_input'], llm)
|
162 |
+
return relevant_plots
|
163 |
+
|
164 |
+
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
|
165 |
+
print(f"---- Find relevant tables for {plot['name']} ----")
|
166 |
+
relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
|
167 |
+
return relevant_tables
|
168 |
+
|
169 |
+
async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
|
170 |
+
"""Perform the good method to retrieve the desired parameter
|
171 |
+
|
172 |
+
Args:
|
173 |
+
state (State): state of the workflow
|
174 |
+
param_name (str): name of the desired parameter
|
175 |
+
table (str): name of the table
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
dict[str, Any] | None:
|
179 |
+
"""
|
180 |
+
if param_name == 'location':
|
181 |
+
location = await find_location(state['user_input'], table)
|
182 |
+
return location
|
183 |
+
if param_name == 'year':
|
184 |
+
year = await find_year(state['user_input'])
|
185 |
+
return {'year': year}
|
186 |
+
return None
|
187 |
+
|
188 |
+
class Location(TypedDict):
|
189 |
+
location: str
|
190 |
+
latitude: NotRequired[str]
|
191 |
+
longitude: NotRequired[str]
|
192 |
+
|
193 |
+
async def find_location(user_input: str, table: str) -> Location:
|
194 |
+
print(f"---- Find location in table {table} ----")
|
195 |
+
location = await detect_location_with_openai(user_input)
|
196 |
+
output: Location = {'location' : location}
|
197 |
+
if location:
|
198 |
+
coords = loc2coords(location)
|
199 |
+
neighbour = nearestNeighbourSQL(coords, table)
|
200 |
+
output.update({
|
201 |
+
"latitude": neighbour[0],
|
202 |
+
"longitude": neighbour[1],
|
203 |
+
})
|
204 |
+
return output
|
205 |
+
|
206 |
+
async def find_year(user_input: str) -> str:
|
207 |
+
"""Extracts year information from user input using LLM.
|
208 |
+
|
209 |
+
This function uses an LLM to identify and extract year information from the
|
210 |
+
user's query, which is used to filter data in subsequent queries.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
user_input (str): The user's query text
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
str: The extracted year, or empty string if no year found
|
217 |
+
"""
|
218 |
+
print(f"---- Find year ---")
|
219 |
+
year = await detect_year_with_openai(user_input)
|
220 |
+
return year
|
221 |
+
|
222 |
+
def find_indicator_column(table: str) -> str:
|
223 |
+
"""Retrieves the name of the indicator column within a table.
|
224 |
+
|
225 |
+
This function maps table names to their corresponding indicator columns
|
226 |
+
using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
table (str): Name of the table in the database
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
str: Name of the indicator column for the specified table
|
233 |
+
|
234 |
+
Raises:
|
235 |
+
KeyError: If the table name is not found in the mapping
|
236 |
+
"""
|
237 |
+
print(f"---- Find indicator column in table {table} ----")
|
238 |
+
return INDICATOR_COLUMNS_PER_TABLE[table]
|
239 |
+
|
240 |
+
|
241 |
+
# def make_write_query_node():
|
242 |
+
|
243 |
+
# def write_query(state):
|
244 |
+
# print("---- Write query ----")
|
245 |
+
# for table in state["tables"]:
|
246 |
+
# sql_query = QUERIES[state[table]['query_type']](
|
247 |
+
# table=table,
|
248 |
+
# indicator_column=state[table]["columns"],
|
249 |
+
# longitude=state[table]["longitude"],
|
250 |
+
# latitude=state[table]["latitude"],
|
251 |
+
# )
|
252 |
+
# state[table].update({"sql_query": sql_query})
|
253 |
+
|
254 |
+
# return state
|
255 |
+
|
256 |
+
# return write_query
|
257 |
+
|
258 |
+
# def make_fetch_data_node(db_path):
|
259 |
+
|
260 |
+
# def fetch_data(state):
|
261 |
+
# print("---- Fetch data ----")
|
262 |
+
# for table in state["tables"]:
|
263 |
+
# results = execute_sql_query(db_path, state[table]['sql_query'])
|
264 |
+
# state[table].update(results)
|
265 |
+
|
266 |
+
# return state
|
267 |
+
|
268 |
+
# return fetch_data
|
269 |
+
|
270 |
+
|
271 |
+
|
272 |
+
## V2
|
273 |
+
|
274 |
+
|
275 |
+
# def make_fetch_data_node(db_path: str, llm):
|
276 |
+
# def fetch_data(state):
|
277 |
+
# print("---- Fetch data ----")
|
278 |
+
# db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
|
279 |
+
# output = {}
|
280 |
+
# sql_query = write_sql_query(state["query"], db, state["tables"], llm)
|
281 |
+
# # TO DO : Add query checker
|
282 |
+
# print(f"SQL query : {sql_query}")
|
283 |
+
# output["sql_query"] = sql_query
|
284 |
+
# output.update(fetch_data_from_sql_query(db_path, sql_query))
|
285 |
+
# return output
|
286 |
+
|
287 |
+
# return fetch_data
|
@@ -3,4 +3,7 @@ from .tab_examples import create_examples_tab
|
|
3 |
from .tab_papers import create_papers_tab
|
4 |
from .tab_figures import create_figures_tab
|
5 |
from .chat_interface import create_chat_interface
|
6 |
-
from .tab_about import create_about_tab
|
|
|
|
|
|
|
|
3 |
from .tab_papers import create_papers_tab
|
4 |
from .tab_figures import create_figures_tab
|
5 |
from .chat_interface import create_chat_interface
|
6 |
+
from .tab_about import create_about_tab
|
7 |
+
from .main_tab import MainTabPanel
|
8 |
+
from .tab_config import ConfigPanel
|
9 |
+
from .main_tab import cqa_tab
|
@@ -21,21 +21,21 @@ What do you want to learn ?
|
|
21 |
"""
|
22 |
|
23 |
init_prompt_poc = """
|
24 |
-
|
25 |
|
26 |
-
❓
|
27 |
-
- **Language
|
28 |
-
- **Audience
|
29 |
-
- **Sources
|
30 |
-
- **Relevant content sources
|
31 |
|
32 |
⚠️ Limitations
|
33 |
-
*
|
34 |
|
35 |
-
🛈
|
36 |
-
|
37 |
|
38 |
-
|
39 |
"""
|
40 |
|
41 |
|
@@ -54,7 +54,10 @@ def create_chat_interface(tab):
|
|
54 |
max_height="80vh",
|
55 |
height="100vh"
|
56 |
)
|
57 |
-
|
|
|
|
|
|
|
58 |
with gr.Row(elem_id="input-message"):
|
59 |
|
60 |
textbox = gr.Textbox(
|
@@ -68,7 +71,7 @@ def create_chat_interface(tab):
|
|
68 |
|
69 |
config_button = gr.Button("", elem_id="config-button")
|
70 |
|
71 |
-
return chatbot, textbox, config_button
|
72 |
|
73 |
|
74 |
|
|
|
21 |
"""
|
22 |
|
23 |
init_prompt_poc = """
|
24 |
+
Bonjour, je suis ClimateQ&A, un assistant conversationnel conçu pour vous aider à comprendre le changement climatique et la perte de biodiversité. Je réponds à vos questions en **parcourant les rapports scientifiques du GIEC et de l'IPBES, le PCAET de Paris, le Plan Biodiversité 2018-2024, et les rapports Acclimaterra de la Région Nouvelle-Aquitaine**.
|
25 |
|
26 |
+
❓ Mode d'emploi
|
27 |
+
- **Language** : Vous pouvez me poser vos questions dans n'importe quelle langue.
|
28 |
+
- **Audience** : Vous pouvez préciser votre public (enfants, grand public, experts) pour obtenir une réponse plus adaptée.
|
29 |
+
- **Sources** : Vous pouvez choisir de chercher dans les rapports du GIEC ou de l'IPBES, et dans les sources POC pour les documents locaux (PCAET, Plan Biodiversité, Acclimaterra).
|
30 |
+
- **Relevant content sources** : Vous pouvez choisir de rechercher des images, des papiers scientifiques ou des graphiques qui peuvent être pertinents pour votre question.
|
31 |
|
32 |
⚠️ Limitations
|
33 |
+
*Veuillez noter que l'IA n'est pas parfaite et peut parfois donner des réponses non pertinentes. Si vous n'êtes pas satisfait de la réponse, veuillez poser une question plus précise ou nous faire part de vos commentaires pour nous aider à améliorer le système.*
|
34 |
|
35 |
+
🛈 Informations
|
36 |
+
Veuillez noter que nous enregistrons vos questions à des fins de méta-analyse, évitez donc de partager toute information sensible ou personnelle.
|
37 |
|
38 |
+
Que voulez-vous apprendre ?
|
39 |
"""
|
40 |
|
41 |
|
|
|
54 |
max_height="80vh",
|
55 |
height="100vh"
|
56 |
)
|
57 |
+
with gr.Accordion("Click here for follow up questions examples", elem_id="follow-up-examples",open = False):
|
58 |
+
follow_up_examples_hidden = gr.Textbox(visible=False, elem_id="follow-up-hidden")
|
59 |
+
follow_up_examples = gr.Examples(examples=["What evidence do we have of climate change ?"], label="", inputs= [follow_up_examples_hidden], elem_id="follow-up-button", run_on_click=False)
|
60 |
+
|
61 |
with gr.Row(elem_id="input-message"):
|
62 |
|
63 |
textbox = gr.Textbox(
|
|
|
71 |
|
72 |
config_button = gr.Button("", elem_id="config-button")
|
73 |
|
74 |
+
return chatbot, textbox, config_button, follow_up_examples, follow_up_examples_hidden
|
75 |
|
76 |
|
77 |
|
@@ -1,8 +1,37 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
from .chat_interface import create_chat_interface
|
3 |
from .tab_examples import create_examples_tab
|
4 |
from .tab_papers import create_papers_tab
|
5 |
from .tab_figures import create_figures_tab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def cqa_tab(tab_name):
|
8 |
# State variables
|
@@ -11,14 +40,14 @@ def cqa_tab(tab_name):
|
|
11 |
with gr.Row(elem_id="chatbot-row"):
|
12 |
# Left column - Chat interface
|
13 |
with gr.Column(scale=2):
|
14 |
-
chatbot, textbox, config_button = create_chat_interface(tab_name)
|
15 |
|
16 |
# Right column - Content panels
|
17 |
with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
|
18 |
with gr.Tabs(elem_id="right_panel_tab") as tabs:
|
19 |
# Examples tab
|
20 |
with gr.TabItem("Examples", elem_id="tab-examples", id=0):
|
21 |
-
examples_hidden
|
22 |
|
23 |
# Sources tab
|
24 |
with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
|
@@ -34,7 +63,7 @@ def cqa_tab(tab_name):
|
|
34 |
|
35 |
# Papers subtab
|
36 |
with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
|
37 |
-
papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
|
38 |
|
39 |
# Graphs subtab
|
40 |
with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
|
@@ -42,27 +71,30 @@ def cqa_tab(tab_name):
|
|
42 |
"<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
|
43 |
elem_id="graphs-container"
|
44 |
)
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from gradio.helpers import Examples
|
3 |
+
from typing import TypedDict
|
4 |
from .chat_interface import create_chat_interface
|
5 |
from .tab_examples import create_examples_tab
|
6 |
from .tab_papers import create_papers_tab
|
7 |
from .tab_figures import create_figures_tab
|
8 |
+
from dataclasses import dataclass
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class MainTabPanel:
|
12 |
+
chatbot: gr.Chatbot
|
13 |
+
textbox: gr.Textbox
|
14 |
+
tabs: gr.Tabs
|
15 |
+
sources_raw: gr.State
|
16 |
+
new_figures: gr.State
|
17 |
+
current_graphs: gr.State
|
18 |
+
examples_hidden: gr.State
|
19 |
+
sources_textbox: gr.HTML
|
20 |
+
figures_cards: gr.HTML
|
21 |
+
gallery_component: gr.Gallery
|
22 |
+
config_button: gr.Button
|
23 |
+
papers_direct_search: gr.TextArea
|
24 |
+
papers_html: gr.HTML
|
25 |
+
citations_network: gr.Plot
|
26 |
+
papers_summary: gr.Textbox
|
27 |
+
tab_recommended_content: gr.Tab
|
28 |
+
tab_sources: gr.Tab
|
29 |
+
tab_figures: gr.Tab
|
30 |
+
tab_graphs: gr.Tab
|
31 |
+
tab_papers: gr.Tab
|
32 |
+
graph_container: gr.HTML
|
33 |
+
follow_up_examples : Examples
|
34 |
+
follow_up_examples_hidden : gr.Textbox
|
35 |
|
36 |
def cqa_tab(tab_name):
|
37 |
# State variables
|
|
|
40 |
with gr.Row(elem_id="chatbot-row"):
|
41 |
# Left column - Chat interface
|
42 |
with gr.Column(scale=2):
|
43 |
+
chatbot, textbox, config_button, follow_up_examples, follow_up_examples_hidden = create_chat_interface(tab_name)
|
44 |
|
45 |
# Right column - Content panels
|
46 |
with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
|
47 |
with gr.Tabs(elem_id="right_panel_tab") as tabs:
|
48 |
# Examples tab
|
49 |
with gr.TabItem("Examples", elem_id="tab-examples", id=0):
|
50 |
+
examples_hidden = create_examples_tab(tab_name)
|
51 |
|
52 |
# Sources tab
|
53 |
with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
|
|
|
63 |
|
64 |
# Papers subtab
|
65 |
with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
|
66 |
+
papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
|
67 |
|
68 |
# Graphs subtab
|
69 |
with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
|
|
|
71 |
"<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
|
72 |
elem_id="graphs-container"
|
73 |
)
|
74 |
+
|
75 |
+
|
76 |
+
return MainTabPanel(
|
77 |
+
chatbot=chatbot,
|
78 |
+
textbox=textbox,
|
79 |
+
tabs=tabs,
|
80 |
+
sources_raw=sources_raw,
|
81 |
+
new_figures=new_figures,
|
82 |
+
current_graphs=current_graphs,
|
83 |
+
examples_hidden=examples_hidden,
|
84 |
+
sources_textbox=sources_textbox,
|
85 |
+
figures_cards=figures_cards,
|
86 |
+
gallery_component=gallery_component,
|
87 |
+
config_button=config_button,
|
88 |
+
papers_direct_search=papers_direct_search,
|
89 |
+
papers_html=papers_html,
|
90 |
+
citations_network=citations_network,
|
91 |
+
papers_summary=papers_summary,
|
92 |
+
tab_recommended_content=tab_recommended_content,
|
93 |
+
tab_sources=tab_sources,
|
94 |
+
tab_figures=tab_figures,
|
95 |
+
tab_graphs=tab_graphs,
|
96 |
+
tab_papers=tab_papers,
|
97 |
+
graph_container=graphs_container,
|
98 |
+
follow_up_examples= follow_up_examples,
|
99 |
+
follow_up_examples_hidden = follow_up_examples_hidden
|
100 |
+
)
|
@@ -2,8 +2,10 @@ import gradio as gr
|
|
2 |
from gradio_modal import Modal
|
3 |
from climateqa.constants import POSSIBLE_REPORTS
|
4 |
from typing import TypedDict
|
|
|
5 |
|
6 |
-
|
|
|
7 |
config_open: gr.State
|
8 |
config_modal: Modal
|
9 |
dropdown_sources: gr.CheckboxGroup
|
@@ -14,6 +16,7 @@ class ConfigPanel(TypedDict):
|
|
14 |
after: gr.Slider
|
15 |
output_query: gr.Textbox
|
16 |
output_language: gr.Textbox
|
|
|
17 |
|
18 |
|
19 |
def create_config_modal():
|
@@ -37,9 +40,9 @@ def create_config_modal():
|
|
37 |
)
|
38 |
|
39 |
dropdown_external_sources = gr.CheckboxGroup(
|
40 |
-
choices=["Figures (IPCC/IPBES)", "Papers (OpenAlex)", "Graphs (OurWorldInData)"
|
41 |
label="Select database to search for relevant content",
|
42 |
-
value=["Figures (IPCC/IPBES)"
|
43 |
interactive=True
|
44 |
)
|
45 |
|
@@ -95,29 +98,16 @@ def create_config_modal():
|
|
95 |
|
96 |
close_config_modal_button = gr.Button("Validate and Close", elem_id="close-config-modal")
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
return {
|
112 |
-
"config_open" : config_open,
|
113 |
-
"config_modal": config_modal,
|
114 |
-
"dropdown_sources": dropdown_sources,
|
115 |
-
"dropdown_reports": dropdown_reports,
|
116 |
-
"dropdown_external_sources": dropdown_external_sources,
|
117 |
-
"search_only": search_only,
|
118 |
-
"dropdown_audience": dropdown_audience,
|
119 |
-
"after": after,
|
120 |
-
"output_query": output_query,
|
121 |
-
"output_language": output_language,
|
122 |
-
"close_config_modal_button": close_config_modal_button
|
123 |
-
}
|
|
|
2 |
from gradio_modal import Modal
|
3 |
from climateqa.constants import POSSIBLE_REPORTS
|
4 |
from typing import TypedDict
|
5 |
+
from dataclasses import dataclass
|
6 |
|
7 |
+
@dataclass
|
8 |
+
class ConfigPanel:
|
9 |
config_open: gr.State
|
10 |
config_modal: Modal
|
11 |
dropdown_sources: gr.CheckboxGroup
|
|
|
16 |
after: gr.Slider
|
17 |
output_query: gr.Textbox
|
18 |
output_language: gr.Textbox
|
19 |
+
close_config_modal_button: gr.Button
|
20 |
|
21 |
|
22 |
def create_config_modal():
|
|
|
40 |
)
|
41 |
|
42 |
dropdown_external_sources = gr.CheckboxGroup(
|
43 |
+
choices=["Figures (IPCC/IPBES)", "Papers (OpenAlex)", "Graphs (OurWorldInData)"],
|
44 |
label="Select database to search for relevant content",
|
45 |
+
value=["Figures (IPCC/IPBES)"],
|
46 |
interactive=True
|
47 |
)
|
48 |
|
|
|
98 |
|
99 |
close_config_modal_button = gr.Button("Validate and Close", elem_id="close-config-modal")
|
100 |
|
101 |
+
return ConfigPanel(
|
102 |
+
config_open=config_open,
|
103 |
+
config_modal=config_modal,
|
104 |
+
dropdown_sources=dropdown_sources,
|
105 |
+
dropdown_reports=dropdown_reports,
|
106 |
+
dropdown_external_sources=dropdown_external_sources,
|
107 |
+
search_only=search_only,
|
108 |
+
dropdown_audience=dropdown_audience,
|
109 |
+
after=after,
|
110 |
+
output_query=output_query,
|
111 |
+
output_language=output_language,
|
112 |
+
close_config_modal_button=close_config_modal_button
|
113 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from typing import TypedDict, List, Optional
|
3 |
+
import os
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
from climateqa.engine.talk_to_data.main import ask_drias
|
7 |
+
from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
|
8 |
+
from climateqa.chat import log_drias_interaction_to_azure
|
9 |
+
|
10 |
+
|
11 |
+
class DriasUIElements(TypedDict):
|
12 |
+
tab: gr.Tab
|
13 |
+
details_accordion: gr.Accordion
|
14 |
+
examples_hidden: gr.Textbox
|
15 |
+
examples: gr.Examples
|
16 |
+
drias_direct_question: gr.Textbox
|
17 |
+
result_text: gr.Textbox
|
18 |
+
table_names_display: gr.DataFrame
|
19 |
+
query_accordion: gr.Accordion
|
20 |
+
drias_sql_query: gr.Textbox
|
21 |
+
chart_accordion: gr.Accordion
|
22 |
+
model_selection: gr.Dropdown
|
23 |
+
drias_display: gr.Plot
|
24 |
+
table_accordion: gr.Accordion
|
25 |
+
drias_table: gr.DataFrame
|
26 |
+
pagination_display: gr.Markdown
|
27 |
+
prev_button: gr.Button
|
28 |
+
next_button: gr.Button
|
29 |
+
|
30 |
+
|
31 |
+
async def ask_drias_query(query: str, index_state: int):
|
32 |
+
result = await ask_drias(query, index_state)
|
33 |
+
return result
|
34 |
+
|
35 |
+
|
36 |
+
def show_results(sql_queries_state, dataframes_state, plots_state):
|
37 |
+
if not sql_queries_state or not dataframes_state or not plots_state:
|
38 |
+
# If all results are empty, show "No result"
|
39 |
+
return (
|
40 |
+
gr.update(visible=True),
|
41 |
+
gr.update(visible=False),
|
42 |
+
gr.update(visible=False),
|
43 |
+
gr.update(visible=False),
|
44 |
+
gr.update(visible=False),
|
45 |
+
gr.update(visible=False),
|
46 |
+
gr.update(visible=False),
|
47 |
+
gr.update(visible=False),
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
# Show the appropriate components with their data
|
51 |
+
return (
|
52 |
+
gr.update(visible=False),
|
53 |
+
gr.update(visible=True),
|
54 |
+
gr.update(visible=True),
|
55 |
+
gr.update(visible=True),
|
56 |
+
gr.update(visible=True),
|
57 |
+
gr.update(visible=True),
|
58 |
+
gr.update(visible=True),
|
59 |
+
gr.update(visible=True),
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
def filter_by_model(dataframes, figures, index_state, model_selection):
|
64 |
+
df = dataframes[index_state]
|
65 |
+
if df.empty:
|
66 |
+
return df, None
|
67 |
+
if "model" not in df.columns:
|
68 |
+
return df, figures[index_state](df)
|
69 |
+
if model_selection != "ALL":
|
70 |
+
df = df[df["model"] == model_selection]
|
71 |
+
if df.empty:
|
72 |
+
return df, None
|
73 |
+
figure = figures[index_state](df)
|
74 |
+
return df, figure
|
75 |
+
|
76 |
+
|
77 |
+
def update_pagination(index, sql_queries):
|
78 |
+
pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
|
79 |
+
return pagination
|
80 |
+
|
81 |
+
|
82 |
+
def show_previous(index, sql_queries, dataframes, plots):
|
83 |
+
if index > 0:
|
84 |
+
index -= 1
|
85 |
+
return (
|
86 |
+
sql_queries[index],
|
87 |
+
dataframes[index],
|
88 |
+
plots[index](dataframes[index]),
|
89 |
+
index,
|
90 |
+
)
|
91 |
+
|
92 |
+
|
93 |
+
def show_next(index, sql_queries, dataframes, plots):
|
94 |
+
if index < len(sql_queries) - 1:
|
95 |
+
index += 1
|
96 |
+
return (
|
97 |
+
sql_queries[index],
|
98 |
+
dataframes[index],
|
99 |
+
plots[index](dataframes[index]),
|
100 |
+
index,
|
101 |
+
)
|
102 |
+
|
103 |
+
|
104 |
+
def display_table_names(table_names):
|
105 |
+
return [table_names]
|
106 |
+
|
107 |
+
|
108 |
+
def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
|
109 |
+
index = evt.index[1]
|
110 |
+
figure = plots[index](dataframes[index])
|
111 |
+
return (
|
112 |
+
sql_queries[index],
|
113 |
+
dataframes[index],
|
114 |
+
figure,
|
115 |
+
index,
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
def create_drias_ui() -> DriasUIElements:
|
120 |
+
"""Create and return all UI elements for the DRIAS tab."""
|
121 |
+
with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
|
122 |
+
with gr.Accordion(label="Details") as details_accordion:
|
123 |
+
gr.Markdown(DRIAS_UI_TEXT)
|
124 |
+
|
125 |
+
# Add examples for common questions
|
126 |
+
examples_hidden = gr.Textbox(visible=False, elem_id="drias-examples-hidden")
|
127 |
+
examples = gr.Examples(
|
128 |
+
examples=[
|
129 |
+
["What will the temperature be like in Paris?"],
|
130 |
+
["What will be the total rainfall in France in 2030?"],
|
131 |
+
["How frequent will extreme events be in Lyon?"],
|
132 |
+
["Comment va évoluer la température en France entre 2030 et 2050 ?"]
|
133 |
+
],
|
134 |
+
label="Example Questions",
|
135 |
+
inputs=[examples_hidden],
|
136 |
+
outputs=[examples_hidden],
|
137 |
+
)
|
138 |
+
|
139 |
+
with gr.Row():
|
140 |
+
drias_direct_question = gr.Textbox(
|
141 |
+
label="Direct Question",
|
142 |
+
placeholder="You can write direct question here",
|
143 |
+
elem_id="direct-question",
|
144 |
+
interactive=True,
|
145 |
+
)
|
146 |
+
|
147 |
+
result_text = gr.Textbox(
|
148 |
+
label="", elem_id="no-result-label", interactive=False, visible=True
|
149 |
+
)
|
150 |
+
|
151 |
+
table_names_display = gr.DataFrame(
|
152 |
+
[], label="List of relevant indicators", headers=None, interactive=False, elem_id="table-names", visible=False
|
153 |
+
)
|
154 |
+
|
155 |
+
with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
|
156 |
+
drias_sql_query = gr.Textbox(
|
157 |
+
label="", elem_id="sql-query", interactive=False
|
158 |
+
)
|
159 |
+
|
160 |
+
with gr.Accordion(label="Chart", visible=False) as chart_accordion:
|
161 |
+
model_selection = gr.Dropdown(
|
162 |
+
label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
|
163 |
+
)
|
164 |
+
drias_display = gr.Plot(elem_id="vanna-plot")
|
165 |
+
|
166 |
+
with gr.Accordion(
|
167 |
+
label="Data used", open=False, visible=False
|
168 |
+
) as table_accordion:
|
169 |
+
drias_table = gr.DataFrame([], elem_id="vanna-table")
|
170 |
+
|
171 |
+
pagination_display = gr.Markdown(
|
172 |
+
value="", visible=False, elem_id="pagination-display"
|
173 |
+
)
|
174 |
+
|
175 |
+
with gr.Row():
|
176 |
+
prev_button = gr.Button("Previous", visible=False)
|
177 |
+
next_button = gr.Button("Next", visible=False)
|
178 |
+
|
179 |
+
return DriasUIElements(
|
180 |
+
tab=tab,
|
181 |
+
details_accordion=details_accordion,
|
182 |
+
examples_hidden=examples_hidden,
|
183 |
+
examples=examples,
|
184 |
+
drias_direct_question=drias_direct_question,
|
185 |
+
result_text=result_text,
|
186 |
+
table_names_display=table_names_display,
|
187 |
+
query_accordion=query_accordion,
|
188 |
+
drias_sql_query=drias_sql_query,
|
189 |
+
chart_accordion=chart_accordion,
|
190 |
+
model_selection=model_selection,
|
191 |
+
drias_display=drias_display,
|
192 |
+
table_accordion=table_accordion,
|
193 |
+
drias_table=drias_table,
|
194 |
+
pagination_display=pagination_display,
|
195 |
+
prev_button=prev_button,
|
196 |
+
next_button=next_button
|
197 |
+
)
|
198 |
+
|
199 |
+
def log_drias_to_azure(query: str, sql_query: str, data, share_client, user_id):
|
200 |
+
"""Log Drias interaction to Azure storage."""
|
201 |
+
print("log_drias_to_azure")
|
202 |
+
if share_client is not None and user_id is not None:
|
203 |
+
log_drias_interaction_to_azure(
|
204 |
+
query=query,
|
205 |
+
sql_query=sql_query,
|
206 |
+
data=data,
|
207 |
+
share_client=share_client,
|
208 |
+
user_id=user_id
|
209 |
+
)
|
210 |
+
else:
|
211 |
+
print("share_client or user_id is None")
|
212 |
+
|
213 |
+
def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
|
214 |
+
"""Set up all event handlers for the DRIAS tab."""
|
215 |
+
# Create state variables
|
216 |
+
sql_queries_state = gr.State([])
|
217 |
+
dataframes_state = gr.State([])
|
218 |
+
plots_state = gr.State([])
|
219 |
+
index_state = gr.State(0)
|
220 |
+
table_names_list = gr.State([])
|
221 |
+
|
222 |
+
def log_drias_interaction(query: str, sql_query: str, data: pd.DataFrame):
|
223 |
+
log_drias_to_azure(query, sql_query, data, share_client, user_id)
|
224 |
+
|
225 |
+
|
226 |
+
# Handle example selection
|
227 |
+
ui_elements["examples_hidden"].change(
|
228 |
+
lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
|
229 |
+
inputs=[ui_elements["examples_hidden"]],
|
230 |
+
outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
|
231 |
+
).then(
|
232 |
+
ask_drias_query,
|
233 |
+
inputs=[ui_elements["examples_hidden"], index_state],
|
234 |
+
outputs=[
|
235 |
+
ui_elements["drias_sql_query"],
|
236 |
+
ui_elements["drias_table"],
|
237 |
+
ui_elements["drias_display"],
|
238 |
+
sql_queries_state,
|
239 |
+
dataframes_state,
|
240 |
+
plots_state,
|
241 |
+
index_state,
|
242 |
+
table_names_list,
|
243 |
+
ui_elements["result_text"],
|
244 |
+
],
|
245 |
+
).then(
|
246 |
+
log_drias_interaction,
|
247 |
+
inputs=[ui_elements["examples_hidden"], ui_elements["drias_sql_query"], ui_elements["drias_table"]],
|
248 |
+
outputs=[],
|
249 |
+
).then(
|
250 |
+
show_results,
|
251 |
+
inputs=[sql_queries_state, dataframes_state, plots_state],
|
252 |
+
outputs=[
|
253 |
+
ui_elements["result_text"],
|
254 |
+
ui_elements["query_accordion"],
|
255 |
+
ui_elements["table_accordion"],
|
256 |
+
ui_elements["chart_accordion"],
|
257 |
+
ui_elements["prev_button"],
|
258 |
+
ui_elements["next_button"],
|
259 |
+
ui_elements["pagination_display"],
|
260 |
+
ui_elements["table_names_display"],
|
261 |
+
],
|
262 |
+
).then(
|
263 |
+
update_pagination,
|
264 |
+
inputs=[index_state, sql_queries_state],
|
265 |
+
outputs=[ui_elements["pagination_display"]],
|
266 |
+
).then(
|
267 |
+
display_table_names,
|
268 |
+
inputs=[table_names_list],
|
269 |
+
outputs=[ui_elements["table_names_display"]],
|
270 |
+
)
|
271 |
+
|
272 |
+
# Handle direct question submission
|
273 |
+
ui_elements["drias_direct_question"].submit(
|
274 |
+
lambda: gr.Accordion(open=False),
|
275 |
+
inputs=None,
|
276 |
+
outputs=[ui_elements["details_accordion"]]
|
277 |
+
).then(
|
278 |
+
ask_drias_query,
|
279 |
+
inputs=[ui_elements["drias_direct_question"], index_state],
|
280 |
+
outputs=[
|
281 |
+
ui_elements["drias_sql_query"],
|
282 |
+
ui_elements["drias_table"],
|
283 |
+
ui_elements["drias_display"],
|
284 |
+
sql_queries_state,
|
285 |
+
dataframes_state,
|
286 |
+
plots_state,
|
287 |
+
index_state,
|
288 |
+
table_names_list,
|
289 |
+
ui_elements["result_text"],
|
290 |
+
],
|
291 |
+
).then(
|
292 |
+
log_drias_interaction,
|
293 |
+
inputs=[ui_elements["drias_direct_question"], ui_elements["drias_sql_query"], ui_elements["drias_table"]],
|
294 |
+
outputs=[],
|
295 |
+
).then(
|
296 |
+
show_results,
|
297 |
+
inputs=[sql_queries_state, dataframes_state, plots_state],
|
298 |
+
outputs=[
|
299 |
+
ui_elements["result_text"],
|
300 |
+
ui_elements["query_accordion"],
|
301 |
+
ui_elements["table_accordion"],
|
302 |
+
ui_elements["chart_accordion"],
|
303 |
+
ui_elements["prev_button"],
|
304 |
+
ui_elements["next_button"],
|
305 |
+
ui_elements["pagination_display"],
|
306 |
+
ui_elements["table_names_display"],
|
307 |
+
],
|
308 |
+
).then(
|
309 |
+
update_pagination,
|
310 |
+
inputs=[index_state, sql_queries_state],
|
311 |
+
outputs=[ui_elements["pagination_display"]],
|
312 |
+
).then(
|
313 |
+
display_table_names,
|
314 |
+
inputs=[table_names_list],
|
315 |
+
outputs=[ui_elements["table_names_display"]],
|
316 |
+
)
|
317 |
+
|
318 |
+
# Handle model selection change
|
319 |
+
ui_elements["model_selection"].change(
|
320 |
+
filter_by_model,
|
321 |
+
inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
|
322 |
+
outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
|
323 |
+
)
|
324 |
+
|
325 |
+
# Handle pagination buttons
|
326 |
+
ui_elements["prev_button"].click(
|
327 |
+
show_previous,
|
328 |
+
inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
|
329 |
+
outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
|
330 |
+
).then(
|
331 |
+
update_pagination,
|
332 |
+
inputs=[index_state, sql_queries_state],
|
333 |
+
outputs=[ui_elements["pagination_display"]],
|
334 |
+
)
|
335 |
+
|
336 |
+
ui_elements["next_button"].click(
|
337 |
+
show_next,
|
338 |
+
inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
|
339 |
+
outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
|
340 |
+
).then(
|
341 |
+
update_pagination,
|
342 |
+
inputs=[index_state, sql_queries_state],
|
343 |
+
outputs=[ui_elements["pagination_display"]],
|
344 |
+
)
|
345 |
+
|
346 |
+
# Handle table selection
|
347 |
+
ui_elements["table_names_display"].select(
|
348 |
+
fn=on_table_click,
|
349 |
+
inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
|
350 |
+
outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
|
351 |
+
).then(
|
352 |
+
update_pagination,
|
353 |
+
inputs=[index_state, sql_queries_state],
|
354 |
+
outputs=[ui_elements["pagination_display"]],
|
355 |
+
)
|
356 |
+
|
357 |
+
def create_drias_tab(share_client=None, user_id=None):
|
358 |
+
"""Create the DRIAS tab with all its components and event handlers."""
|
359 |
+
ui_elements = create_drias_ui()
|
360 |
+
setup_drias_events(ui_elements, share_client=share_client, user_id=user_id)
|
361 |
+
|
362 |
+
|
@@ -29,8 +29,6 @@ main.flex.flex-1.flex-col {
|
|
29 |
}
|
30 |
|
31 |
|
32 |
-
}
|
33 |
-
|
34 |
.tab-nav {
|
35 |
border: none !important;
|
36 |
}
|
@@ -111,10 +109,18 @@ main.flex.flex-1.flex-col {
|
|
111 |
border: none;
|
112 |
}
|
113 |
|
114 |
-
#input-textbox > label > textarea {
|
115 |
border-radius: 40px;
|
116 |
padding-left: 30px;
|
117 |
resize: none;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
}
|
119 |
|
120 |
#input-message > div {
|
@@ -474,6 +480,33 @@ a {
|
|
474 |
text-decoration: none !important;
|
475 |
}
|
476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
/* Media Queries */
|
478 |
/* Desktop Media Query */
|
479 |
@media screen and (min-width: 1024px) {
|
@@ -487,7 +520,6 @@ a {
|
|
487 |
height: calc(100vh - 190px) !important;
|
488 |
overflow-y: scroll !important;
|
489 |
}
|
490 |
-
div#tab-vanna,
|
491 |
div#sources-figures,
|
492 |
div#graphs-container,
|
493 |
div#tab-citations {
|
@@ -496,6 +528,15 @@ a {
|
|
496 |
overflow-y: scroll !important;
|
497 |
}
|
498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
div#chatbot-row {
|
500 |
max-height: calc(100vh - 90px) !important;
|
501 |
}
|
@@ -514,7 +555,11 @@ a {
|
|
514 |
/* Mobile Media Query */
|
515 |
@media screen and (max-width: 767px) {
|
516 |
div#chatbot {
|
517 |
-
height:
|
|
|
|
|
|
|
|
|
518 |
}
|
519 |
|
520 |
#submit-button {
|
@@ -607,14 +652,61 @@ a {
|
|
607 |
}
|
608 |
|
609 |
#vanna-display {
|
610 |
-
max-height:
|
611 |
/* overflow-y: scroll; */
|
612 |
}
|
613 |
#sql-query{
|
614 |
-
max-height:
|
615 |
overflow-y:scroll;
|
616 |
}
|
617 |
-
|
618 |
-
|
619 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
620 |
}
|
|
|
29 |
}
|
30 |
|
31 |
|
|
|
|
|
32 |
.tab-nav {
|
33 |
border: none !important;
|
34 |
}
|
|
|
109 |
border: none;
|
110 |
}
|
111 |
|
112 |
+
#input-textbox > label > div > textarea {
|
113 |
border-radius: 40px;
|
114 |
padding-left: 30px;
|
115 |
resize: none;
|
116 |
+
background-color: #d7e2ed; /* Light blue background */
|
117 |
+
border: 2px solid #4b8ec3; /* Blue border */
|
118 |
+
font-size: 16px; /* Increase font size */
|
119 |
+
color: #333; /* Text color */
|
120 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Add shadow */
|
121 |
+
::placeholder {
|
122 |
+
color: #4b4747; /* Darker placeholder color */
|
123 |
+
}
|
124 |
}
|
125 |
|
126 |
#input-message > div {
|
|
|
480 |
text-decoration: none !important;
|
481 |
}
|
482 |
|
483 |
+
/* Follow-up Examples Styles */
|
484 |
+
#follow-up-examples {
|
485 |
+
max-height: 20vh;
|
486 |
+
overflow-y: auto;
|
487 |
+
gap: 8px;
|
488 |
+
display: flex;
|
489 |
+
flex-direction: column;
|
490 |
+
overflow-y: hidden;
|
491 |
+
background: rgb(229, 235, 237);
|
492 |
+
}
|
493 |
+
|
494 |
+
#follow-up-button {
|
495 |
+
overflow-y: visible;
|
496 |
+
display: block;
|
497 |
+
padding: 8px 12px;
|
498 |
+
margin: 4px 0;
|
499 |
+
border-radius: 8px;
|
500 |
+
background-color: #f0f8ff;
|
501 |
+
transition: background-color 0.2s;
|
502 |
+
background: rgb(240, 240, 236);
|
503 |
+
|
504 |
+
}
|
505 |
+
|
506 |
+
#follow-up-button:hover {
|
507 |
+
background-color: #e0f0ff;
|
508 |
+
}
|
509 |
+
|
510 |
/* Media Queries */
|
511 |
/* Desktop Media Query */
|
512 |
@media screen and (min-width: 1024px) {
|
|
|
520 |
height: calc(100vh - 190px) !important;
|
521 |
overflow-y: scroll !important;
|
522 |
}
|
|
|
523 |
div#sources-figures,
|
524 |
div#graphs-container,
|
525 |
div#tab-citations {
|
|
|
528 |
overflow-y: scroll !important;
|
529 |
}
|
530 |
|
531 |
+
div#chatbot-row {
|
532 |
+
max-height: calc(100vh - 200px) !important;
|
533 |
+
}
|
534 |
+
|
535 |
+
div#chatbot {
|
536 |
+
height: 70vh !important;
|
537 |
+
max-height: 70vh !important;
|
538 |
+
}
|
539 |
+
|
540 |
div#chatbot-row {
|
541 |
max-height: calc(100vh - 90px) !important;
|
542 |
}
|
|
|
555 |
/* Mobile Media Query */
|
556 |
@media screen and (max-width: 767px) {
|
557 |
div#chatbot {
|
558 |
+
height: 400px !important; /* Reduced from 500px */
|
559 |
+
}
|
560 |
+
|
561 |
+
#follow-up-examples {
|
562 |
+
max-height: 150px;
|
563 |
}
|
564 |
|
565 |
#submit-button {
|
|
|
652 |
}
|
653 |
|
654 |
#vanna-display {
|
655 |
+
max-height: 200px;
|
656 |
/* overflow-y: scroll; */
|
657 |
}
|
658 |
#sql-query{
|
659 |
+
max-height: 300px;
|
660 |
overflow-y:scroll;
|
661 |
}
|
662 |
+
|
663 |
+
#sql-query textarea{
|
664 |
+
min-height: 100px !important;
|
665 |
+
}
|
666 |
+
|
667 |
+
#sql-query span{
|
668 |
+
display: none;
|
669 |
+
}
|
670 |
+
div#tab-vanna{
|
671 |
+
max-height: 100¨vh;
|
672 |
+
overflow-y: hidden;
|
673 |
+
}
|
674 |
+
#vanna-plot{
|
675 |
+
max-height:500px
|
676 |
+
}
|
677 |
+
|
678 |
+
#pagination-display{
|
679 |
+
text-align: center;
|
680 |
+
font-weight: bold;
|
681 |
+
font-size: 16px;
|
682 |
+
}
|
683 |
+
|
684 |
+
#table-names table{
|
685 |
+
overflow: hidden;
|
686 |
+
}
|
687 |
+
#table-names thead{
|
688 |
+
display: none;
|
689 |
+
}
|
690 |
+
|
691 |
+
/* DRIAS Data Table Styles */
|
692 |
+
#vanna-table {
|
693 |
+
height: 400px !important;
|
694 |
+
overflow-y: auto !important;
|
695 |
+
}
|
696 |
+
|
697 |
+
#vanna-table > div[class*="table"] {
|
698 |
+
height: 400px !important;
|
699 |
+
overflow-y: None !important;
|
700 |
+
}
|
701 |
+
|
702 |
+
#vanna-table .table-wrap {
|
703 |
+
height: 400px !important;
|
704 |
+
overflow-y: None !important;
|
705 |
+
}
|
706 |
+
|
707 |
+
#vanna-table thead {
|
708 |
+
position: sticky;
|
709 |
+
top: 0;
|
710 |
+
background: white;
|
711 |
+
z-index: 1;
|
712 |
}
|