timeki committed on
Commit 6af9e98 · 2 Parent(s): 3e75ed8 bc43b45

Merged in dev (pull request #4)

app.py CHANGED
@@ -9,13 +9,13 @@ from climateqa.engine.embeddings import get_embeddings_function
 from climateqa.engine.llm import get_llm
 from climateqa.engine.vectorstore import get_pinecone_vectorstore
 from climateqa.engine.reranker import get_reranker
-from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
+from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
 from climateqa.engine.chains.retrieve_papers import find_papers
 from climateqa.chat import start_chat, chat_stream, finish_chat
-from climateqa.engine.talk_to_data.main import ask_vanna
-from climateqa.engine.talk_to_data.myVanna import MyVanna
 
-from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
+from front.tabs import create_config_modal, cqa_tab, create_about_tab
+from front.tabs import MainTabPanel, ConfigPanel
+from front.tabs.tab_drias import create_drias_tab
 from front.utils import process_figures
 from gradio_modal import Modal
 
@@ -24,14 +24,14 @@ from utils import create_user_id
 import logging
 
 logging.basicConfig(level=logging.WARNING)
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppresses INFO and WARNING logs
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppresses INFO and WARNING logs
 logging.getLogger().setLevel(logging.WARNING)
 
 
-
 # Load environment variables in local mode
 try:
     from dotenv import load_dotenv
+
    load_dotenv()
 except Exception as e:
     pass
@@ -62,39 +62,94 @@ share_client = service.get_share_client(file_share_name)
 user_id = create_user_id()
 
 
-
 # Create vectorstore and retriever
 embeddings_function = get_embeddings_function()
-vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
-vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
-vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
+vectorstore = get_pinecone_vectorstore(
+    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
+)
+vectorstore_graphs = get_pinecone_vectorstore(
+    embeddings_function,
+    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
+    text_key="description",
+)
+vectorstore_region = get_pinecone_vectorstore(
+    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
+)
 
-llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
+llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
 if os.environ["GRADIO_ENV"] == "local":
     reranker = get_reranker("nano")
-else :
+else:
     reranker = get_reranker("large")
 
-agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
-agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
-
-#Vanna object
-
-vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
-db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
-vn.connect_to_sqlite(db_vanna_path)
-
-def ask_vanna_query(query):
-    return ask_vanna(vn, db_vanna_path, query)
-
-async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
+agent = make_graph_agent(
+    llm=llm,
+    vectorstore_ipcc=vectorstore,
+    vectorstore_graphs=vectorstore_graphs,
+    vectorstore_region=vectorstore_region,
+    reranker=reranker,
+    threshold_docs=0.2,
+)
+agent_poc = make_graph_agent_poc(
+    llm=llm,
+    vectorstore_ipcc=vectorstore,
+    vectorstore_graphs=vectorstore_graphs,
+    vectorstore_region=vectorstore_region,
+    reranker=reranker,
+    threshold_docs=0,
+    version="v4",
+)  # TODO put back default 0.2
+
+
+
+
+async def chat(
+    query,
+    history,
+    audience,
+    sources,
+    reports,
+    relevant_content_sources_selection,
+    search_only,
+):
     print("chat cqa - message received")
-    async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
+    async for event in chat_stream(
+        agent,
+        query,
+        history,
+        audience,
+        sources,
+        reports,
+        relevant_content_sources_selection,
+        search_only,
+        share_client,
+        user_id,
+    ):
         yield event
-
-async def chat_poc(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
+
+
+async def chat_poc(
+    query,
+    history,
+    audience,
+    sources,
+    reports,
+    relevant_content_sources_selection,
+    search_only,
+):
     print("chat poc - message received")
-    async for event in chat_stream(agent_poc, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
+    async for event in chat_stream(
+        agent_poc,
+        query,
+        history,
+        audience,
+        sources,
+        reports,
+        relevant_content_sources_selection,
+        search_only,
+        share_client,
+        user_id,
+    ):
         yield event
 
 
@@ -102,14 +157,17 @@ async def chat_poc(query, history, audience, sources, reports, relevant_content_
 # Gradio
 # --------------------------------------------------------------------
 
+
 # Function to update modal visibility
 def update_config_modal_visibility(config_open):
     print(config_open)
     new_config_visibility_status = not config_open
     return Modal(visible=new_config_visibility_status), new_config_visibility_status
-
 
-def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
+
+def update_sources_number_display(
+    sources_textbox, figures_cards, current_graphs, papers_html
+):
     sources_number = sources_textbox.count("<h2>")
     figures_number = figures_cards.count("<h2>")
     graphs_number = current_graphs.count("<iframe")
@@ -118,229 +176,368 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
     figures_notif_label = f"Figures ({figures_number})"
     graphs_notif_label = f"Graphs ({graphs_number})"
     papers_notif_label = f"Papers ({papers_number})"
-    recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"
-
-    return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
-
-def create_drias_tab():
-    with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
-        vanna_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here",elem_id="direct-question", interactive=True)
-        with gr.Accordion("Details",elem_id = 'vanna-details', open=False) as vanna_details :
-            vanna_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
-            show_vanna_table = gr.Button("Show Table", elem_id="show-table")
-            with Modal(visible=False) as vanna_table_modal:
-                vanna_table = gr.DataFrame([], elem_id="vanna-table")
-                close_vanna_modal = gr.Button("Close", elem_id="close-vanna-modal")
-                close_vanna_modal.click(lambda: Modal(visible=False),None, [vanna_table_modal])
-            show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
-
-        vanna_display = gr.Plot()
-        vanna_direct_question.submit(ask_vanna_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
-
-# # UI Layout Components
-def cqa_tab(tab_name):
-    # State variables
-    current_graphs = gr.State([])
-    with gr.Tab(tab_name):
-        with gr.Row(elem_id="chatbot-row"):
-            # Left column - Chat interface
-            with gr.Column(scale=2):
-                chatbot, textbox, config_button = create_chat_interface(tab_name)
-
-            # Right column - Content panels
-            with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
-                with gr.Tabs(elem_id="right_panel_tab") as tabs:
-                    # Examples tab
-                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):
-                        examples_hidden = create_examples_tab(tab_name)
-
-                    # Sources tab
-                    with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
-                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
-
-
-                    # Recommended content tab
-                    with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
-                        with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
-                            # Figures subtab
-                            with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
-                                sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal = create_figures_tab()
-
-                            # Papers subtab
-                            with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
-                                papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
-
-                            # Graphs subtab
-                            with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
-                                graphs_container = gr.HTML(
-                                    "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
-                                    elem_id="graphs-container"
-                                )
-
-
-    return {
-        "chatbot": chatbot,
-        "textbox": textbox,
-        "tabs": tabs,
-        "sources_raw": sources_raw,
-        "new_figures": new_figures,
-        "current_graphs": current_graphs,
-        "examples_hidden": examples_hidden,
-        "sources_textbox": sources_textbox,
-        "figures_cards": figures_cards,
-        "gallery_component": gallery_component,
-        "config_button": config_button,
-        "papers_direct_search" : papers_direct_search,
-        "papers_html": papers_html,
-        "citations_network": citations_network,
-        "papers_summary": papers_summary,
-        "tab_recommended_content": tab_recommended_content,
-        "tab_sources": tab_sources,
-        "tab_figures": tab_figures,
-        "tab_graphs": tab_graphs,
-        "tab_papers": tab_papers,
-        "graph_container": graphs_container,
-        # "vanna_sql_query": vanna_sql_query,
-        # "vanna_table" : vanna_table,
-        # "vanna_display": vanna_display
-    }
-
-def config_event_handling(main_tabs_components : list[dict], config_componenets : dict):
-    config_open = config_componenets["config_open"]
-    config_modal = config_componenets["config_modal"]
-    close_config_modal = config_componenets["close_config_modal_button"]
-
-    for button in [close_config_modal] + [main_tab_component["config_button"] for main_tab_component in main_tabs_components]:
+    recommended_content_notif_label = (
+        f"Recommended content ({figures_number + graphs_number + papers_number})"
+    )
+
+    return (
+        gr.update(label=recommended_content_notif_label),
+        gr.update(label=sources_notif_label),
+        gr.update(label=figures_notif_label),
+        gr.update(label=graphs_notif_label),
+        gr.update(label=papers_notif_label),
+    )
+
+
+def config_event_handling(
+    main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
+):
+    config_open = config_componenets.config_open
+    config_modal = config_componenets.config_modal
+    close_config_modal = config_componenets.close_config_modal_button
+
+    for button in [close_config_modal] + [
+        main_tab_component.config_button for main_tab_component in main_tabs_components
+    ]:
        button.click(
            fn=update_config_modal_visibility,
            inputs=[config_open],
-            outputs=[config_modal, config_open]
-        )
-
+            outputs=[config_modal, config_open],
+        )
+
+
 def event_handling(
-    main_tab_components,
-    config_components,
-    tab_name="ClimateQ&A"
+    main_tab_components: MainTabPanel,
+    config_components: ConfigPanel,
+    tab_name="ClimateQ&A",
 ):
-    chatbot = main_tab_components["chatbot"]
-    textbox = main_tab_components["textbox"]
-    tabs = main_tab_components["tabs"]
-    sources_raw = main_tab_components["sources_raw"]
-    new_figures = main_tab_components["new_figures"]
-    current_graphs = main_tab_components["current_graphs"]
-    examples_hidden = main_tab_components["examples_hidden"]
-    sources_textbox = main_tab_components["sources_textbox"]
-    figures_cards = main_tab_components["figures_cards"]
-    gallery_component = main_tab_components["gallery_component"]
-    # config_button = main_tab_components["config_button"]
-    papers_direct_search = main_tab_components["papers_direct_search"]
-    papers_html = main_tab_components["papers_html"]
-    citations_network = main_tab_components["citations_network"]
-    papers_summary = main_tab_components["papers_summary"]
-    tab_recommended_content = main_tab_components["tab_recommended_content"]
-    tab_sources = main_tab_components["tab_sources"]
-    tab_figures = main_tab_components["tab_figures"]
-    tab_graphs = main_tab_components["tab_graphs"]
-    tab_papers = main_tab_components["tab_papers"]
-    graphs_container = main_tab_components["graph_container"]
-    # vanna_sql_query = main_tab_components["vanna_sql_query"]
-    # vanna_table = main_tab_components["vanna_table"]
-    # vanna_display = main_tab_components["vanna_display"]
-
-
-    # config_open = config_components["config_open"]
-    # config_modal = config_components["config_modal"]
-    dropdown_sources = config_components["dropdown_sources"]
-    dropdown_reports = config_components["dropdown_reports"]
-    dropdown_external_sources = config_components["dropdown_external_sources"]
-    search_only = config_components["search_only"]
-    dropdown_audience = config_components["dropdown_audience"]
-    after = config_components["after"]
-    output_query = config_components["output_query"]
-    output_language = config_components["output_language"]
-    # close_config_modal = config_components["close_config_modal_button"]
-
+    chatbot = main_tab_components.chatbot
+    textbox = main_tab_components.textbox
+    tabs = main_tab_components.tabs
+    sources_raw = main_tab_components.sources_raw
+    new_figures = main_tab_components.new_figures
+    current_graphs = main_tab_components.current_graphs
+    examples_hidden = main_tab_components.examples_hidden
+    sources_textbox = main_tab_components.sources_textbox
+    figures_cards = main_tab_components.figures_cards
+    gallery_component = main_tab_components.gallery_component
+    papers_direct_search = main_tab_components.papers_direct_search
+    papers_html = main_tab_components.papers_html
+    citations_network = main_tab_components.citations_network
+    papers_summary = main_tab_components.papers_summary
+    tab_recommended_content = main_tab_components.tab_recommended_content
+    tab_sources = main_tab_components.tab_sources
+    tab_figures = main_tab_components.tab_figures
+    tab_graphs = main_tab_components.tab_graphs
+    tab_papers = main_tab_components.tab_papers
+    graphs_container = main_tab_components.graph_container
+    follow_up_examples = main_tab_components.follow_up_examples
+    follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
+
+    dropdown_sources = config_components.dropdown_sources
+    dropdown_reports = config_components.dropdown_reports
+    dropdown_external_sources = config_components.dropdown_external_sources
+    search_only = config_components.search_only
+    dropdown_audience = config_components.dropdown_audience
+    after = config_components.after
+    output_query = config_components.output_query
+    output_language = config_components.output_language
+
     new_sources_hmtl = gr.State([])
     ttd_data = gr.State([])
 
-
-    # for button in [config_button, close_config_modal]:
-    #     button.click(
-    #         fn=update_config_modal_visibility,
-    #         inputs=[config_open],
-    #         outputs=[config_modal, config_open]
-    #     )
-
    if tab_name == "ClimateQ&A":
        print("chat cqa - message sent")
 
        # Event for textbox
-        (textbox
-            .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
-            .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
-            .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
+        (
+            textbox.submit(
+                start_chat,
+                [textbox, chatbot, search_only],
+                [textbox, tabs, chatbot, sources_raw],
+                queue=False,
+                api_name=f"start_chat_{textbox.elem_id}",
+            )
+            .then(
+                chat,
+                [
+                    textbox,
+                    chatbot,
+                    dropdown_audience,
+                    dropdown_sources,
+                    dropdown_reports,
+                    dropdown_external_sources,
+                    search_only,
+                ],
+                [
+                    chatbot,
+                    new_sources_hmtl,
+                    output_query,
+                    output_language,
+                    new_figures,
+                    current_graphs,
+                    follow_up_examples.dataset,
+                ],
+                concurrency_limit=8,
+                api_name=f"chat_{textbox.elem_id}",
+            )
+            .then(
+                finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
+            )
        )
        # Event for examples_hidden
-        (examples_hidden
-            .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
-            .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
-            .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
+        (
+            examples_hidden.change(
+                start_chat,
+                [examples_hidden, chatbot, search_only],
+                [examples_hidden, tabs, chatbot, sources_raw],
+                queue=False,
+                api_name=f"start_chat_{examples_hidden.elem_id}",
+            )
+            .then(
+                chat,
+                [
+                    examples_hidden,
+                    chatbot,
+                    dropdown_audience,
+                    dropdown_sources,
+                    dropdown_reports,
+                    dropdown_external_sources,
+                    search_only,
+                ],
+                [
+                    chatbot,
+                    new_sources_hmtl,
+                    output_query,
+                    output_language,
+                    new_figures,
+                    current_graphs,
+                    follow_up_examples.dataset,
+                ],
+                concurrency_limit=8,
+                api_name=f"chat_{examples_hidden.elem_id}",
+            )
+            .then(
+                finish_chat,
+                None,
+                [textbox],
+                api_name=f"finish_chat_{examples_hidden.elem_id}",
+            )
+        )
+        (
+            follow_up_examples_hidden.change(
+                start_chat,
+                [follow_up_examples_hidden, chatbot, search_only],
+                [follow_up_examples_hidden, tabs, chatbot, sources_raw],
+                queue=False,
+                api_name=f"start_chat_{examples_hidden.elem_id}",
+            )
+            .then(
+                chat,
+                [
+                    follow_up_examples_hidden,
+                    chatbot,
+                    dropdown_audience,
+                    dropdown_sources,
+                    dropdown_reports,
+                    dropdown_external_sources,
+                    search_only,
+                ],
+                [
+                    chatbot,
+                    new_sources_hmtl,
+                    output_query,
+                    output_language,
+                    new_figures,
+                    current_graphs,
+                    follow_up_examples.dataset,
+                ],
+                concurrency_limit=8,
+                api_name=f"chat_{examples_hidden.elem_id}",
+            )
+            .then(
+                finish_chat,
+                None,
+                [textbox],
+                api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
+            )
        )
-
+
    elif tab_name == "Beta - POC Adapt'Action":
        print("chat poc - message sent")
        # Event for textbox
-        (textbox
-            .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
-            .then(chat_poc, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
-            .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
+        (
+            textbox.submit(
+                start_chat,
+                [textbox, chatbot, search_only],
+                [textbox, tabs, chatbot, sources_raw],
+                queue=False,
+                api_name=f"start_chat_{textbox.elem_id}",
+            )
+            .then(
+                chat_poc,
+                [
+                    textbox,
+                    chatbot,
+                    dropdown_audience,
+                    dropdown_sources,
+                    dropdown_reports,
+                    dropdown_external_sources,
+                    search_only,
+                ],
+                [
+                    chatbot,
+                    new_sources_hmtl,
+                    output_query,
+                    output_language,
+                    new_figures,
+                    current_graphs,
+                ],
+                concurrency_limit=8,
+                api_name=f"chat_{textbox.elem_id}",
+            )
+            .then(
+                finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
+            )
        )
        # Event for examples_hidden
-        (examples_hidden
-            .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
-            .then(chat_poc, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
-            .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
+        (
+            examples_hidden.change(
+                start_chat,
+                [examples_hidden, chatbot, search_only],
+                [examples_hidden, tabs, chatbot, sources_raw],
+                queue=False,
+                api_name=f"start_chat_{examples_hidden.elem_id}",
+            )
+            .then(
+                chat_poc,
+                [
+                    examples_hidden,
+                    chatbot,
+                    dropdown_audience,
+                    dropdown_sources,
+                    dropdown_reports,
+                    dropdown_external_sources,
+                    search_only,
+                ],
+                [
+                    chatbot,
+                    new_sources_hmtl,
+                    output_query,
+                    output_language,
+                    new_figures,
+                    current_graphs,
+                ],
+                concurrency_limit=8,
+                api_name=f"chat_{examples_hidden.elem_id}",
+            )
+            .then(
+                finish_chat,
+                None,
+                [textbox],
+                api_name=f"finish_chat_{examples_hidden.elem_id}",
+            )
+        )
+        (
+            follow_up_examples_hidden.change(
+                start_chat,
+                [follow_up_examples_hidden, chatbot, search_only],
+                [follow_up_examples_hidden, tabs, chatbot, sources_raw],
+                queue=False,
+                api_name=f"start_chat_{examples_hidden.elem_id}",
+            )
+            .then(
+                chat,
+                [
+                    follow_up_examples_hidden,
+                    chatbot,
+                    dropdown_audience,
+                    dropdown_sources,
+                    dropdown_reports,
+                    dropdown_external_sources,
+                    search_only,
+                ],
+                [
+                    chatbot,
+                    new_sources_hmtl,
+                    output_query,
+                    output_language,
+                    new_figures,
+                    current_graphs,
+                    follow_up_examples.dataset,
+                ],
+                concurrency_limit=8,
+                api_name=f"chat_{examples_hidden.elem_id}",
+            )
+            .then(
+                finish_chat,
+                None,
+                [textbox],
+                api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
+            )
        )
-
-
-    new_sources_hmtl.change(lambda x : x, inputs = [new_sources_hmtl], outputs = [sources_textbox])
-    current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
-    new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])
+
+    new_sources_hmtl.change(
+        lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
+    )
+    current_graphs.change(
+        lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
+    )
+    new_figures.change(
+        process_figures,
+        inputs=[sources_raw, new_figures],
+        outputs=[sources_raw, figures_cards, gallery_component],
+    )
 
    # Update sources numbers
    for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
-        component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
-
+        component.change(
+            update_sources_number_display,
+            [sources_textbox, figures_cards, current_graphs, papers_html],
+            [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
+        )
+
    # Search for papers
    for component in [textbox, examples_hidden, papers_direct_search]:
-        component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
-
+        component.submit(
+            find_papers,
+            [component, after, dropdown_external_sources],
+            [papers_html, citations_network, papers_summary],
+        )
 
    # if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
    #     # Drias search
    #     textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
 
+
 def main_ui():
    # config_open = gr.State(True)
-    with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme, elem_id="main-component") as demo:
-        config_components = create_config_modal()
-
+    with gr.Blocks(
+        title="Climate Q&A",
+        css_paths=os.getcwd() + "/style.css",
+        theme=theme,
+        elem_id="main-component",
+    ) as demo:
+        config_components = create_config_modal()
+
        with gr.Tabs():
-            cqa_components = cqa_tab(tab_name = "ClimateQ&A")
-            local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
-            create_drias_tab()
-
+            cqa_components = cqa_tab(tab_name="ClimateQ&A")
+            local_cqa_components = cqa_tab(tab_name="Beta - POC Adapt'Action")
+            create_drias_tab(share_client=share_client, user_id=user_id)
+
            create_about_tab()
-
-        event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
-        event_handling(local_cqa_components, config_components, tab_name = "Beta - POC Adapt'Action")
-
-        config_event_handling([cqa_components,local_cqa_components] ,config_components)
-
+
+        event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
+        event_handling(
+            local_cqa_components, config_components, tab_name="Beta - POC Adapt'Action"
+        )
+
+        config_event_handling([cqa_components, local_cqa_components], config_components)
+
    demo.queue()
-
+
    return demo
 
-
+
 demo = main_ui()
 demo.launch(ssr_mode=False)
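Note on the wiring above: the chained submit(...).then(...).then(...) pattern is what sequences start_chat, the streaming chat step, and finish_chat on a single textbox. A minimal self-contained sketch of the same pattern, with hypothetical stub handlers standing in for ClimateQ&A's real ones:

import gradio as gr

def start_chat(message, history):
    # Clear the textbox and append the user turn before streaming starts
    return "", history + [(message, None)]

def answer(history):
    # Stand-in for the streaming chat step; fills in the assistant turn
    history[-1] = (history[-1][0], f"Echo: {history[-1][0]}")
    return history

def finish_chat():
    # Re-enable the textbox once the chain has finished
    return gr.update(interactive=True)

with gr.Blocks() as demo_sketch:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    (
        textbox.submit(start_chat, [textbox, chatbot], [textbox, chatbot], queue=False)
        .then(answer, [chatbot], [chatbot], concurrency_limit=8)
        .then(finish_chat, None, [textbox])
    )

demo_sketch.launch()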
climateqa/chat.py CHANGED
@@ -61,6 +61,27 @@ def handle_numerical_data(event):
        return numerical_data, sql_query
    return None, None
 
+def log_drias_interaction_to_azure(query, sql_query, data, share_client, user_id):
+    try:
+        # Log interaction to Azure if not in local environment
+        if os.getenv("GRADIO_ENV") != "local":
+            timestamp = str(datetime.now().timestamp())
+            logs = {
+                "user_id": str(user_id),
+                "query": query,
+                "sql_query": sql_query,
+                # "data": data.to_dict() if data is not None else None,
+                "time": timestamp,
+            }
+            log_on_azure(f"drias_{timestamp}.json", logs, share_client)
+            print(f"Logged Drias interaction to Azure Blob Storage: {logs}")
+        else:
+            print("share_client or user_id is None, or GRADIO_ENV is local")
+    except Exception as e:
+        print(f"Error logging Drias interaction on Azure Blob Storage: {e}")
+        error_msg = f"Drias Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
+        raise gr.Error(error_msg)
+
 # Main chat function
 async def chat_stream(
    agent : CompiledStateGraph,
@@ -101,6 +122,7 @@
    audience_prompt = init_audience(audience)
    sources = sources or ["IPCC", "IPBES"]
    reports = reports or []
+    relevant_history_discussion = history[-2:] if len(history) > 1 else []
 
    # Prepare inputs for agent
    inputs = {
@@ -109,7 +131,8 @@
        "sources_input": sources,
        "relevant_content_sources_selection": relevant_content_sources_selection,
        "search_only": search_only,
-        "reports": reports
+        "reports": reports,
+        "chat_history": relevant_history_discussion,
    }
 
    # Get streaming events from agent
@@ -129,6 +152,7 @@
    retrieved_contents = []
    answer_message_content = ""
    vanna_data = {}
+    follow_up_examples = gr.Dataset(samples=[])
 
    # Define processing steps
    steps_display = {
@@ -200,7 +224,12 @@
                    sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
                    history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
 
-            yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
+            # Handle follow up questions
+            if event["name"] == "generate_follow_up" and event["event"] == "on_chain_end":
+                follow_up_examples = event["data"]["output"].get("follow_up_questions", [])
+                follow_up_examples = gr.Dataset(samples= [ [question] for question in follow_up_examples ])
+
+            yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
 
        except Exception as e:
            print(f"Event {event} has failed")
@@ -211,4 +240,4 @@
    # Call the function to log interaction
    log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
 
-    yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
+    yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
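The streaming yields above now carry a seventh value, the gr.Dataset update for follow-up questions. gr.Dataset expects one row per sample, with one cell per underlying component, which is why each question is wrapped in its own list. A small sketch of that shaping step (the question list here is illustrative):

import gradio as gr

follow_up_questions = ["Details the first point", "And what about France?"]

# One single-cell row per question — hence the [question] wrapping used in chat_stream
dataset_update = gr.Dataset(samples=[[q] for q in follow_up_questions])
print(dataset_update.samples)  # [['Details the first point'], ['And what about France?']]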
climateqa/engine/chains/answer_rag.py CHANGED
@@ -65,6 +65,7 @@ def make_rag_node(llm,with_docs = True):
    async def answer_rag(state,config):
        print("---- Answer RAG ----")
        start_time = time.time()
+        chat_history = state.get("chat_history",[])
        print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))
 
        answer = await rag_chain.ainvoke(state,config)
@@ -73,9 +74,10 @@ def make_rag_node(llm,with_docs = True):
        elapsed_time = end_time - start_time
        print("RAG elapsed time: ", elapsed_time)
        print("Answer size : ", len(answer))
-        # print(f"\n\nAnswer:\n{answer}")
 
-        return {"answer":answer}
+        chat_history.append({"question":state["query"],"answer":answer})
+
+        return {"answer":answer,"chat_history": chat_history}
 
    return answer_rag
 
climateqa/engine/chains/follow_up.py ADDED
@@ -0,0 +1,33 @@
+from typing import List
+from langchain.prompts import ChatPromptTemplate
+
+
+FOLLOW_UP_TEMPLATE = """Based on the previous question and answer, generate 2-3 relevant follow-up questions that would help explore the topic further.
+
+Previous Question: {user_input}
+Previous Answer: {answer}
+
+Generate short, concise, focused follow-up questions
+You don't need a full question as it will be reformulated later as a standalone question with the context. Eg. "Details the first point"
+"""
+
+def make_follow_up_node(llm):
+    prompt = ChatPromptTemplate.from_template(FOLLOW_UP_TEMPLATE)
+
+    def generate_follow_up(state):
+        print("---- Generate_follow_up ----")
+        if not state.get("answer"):
+            return state
+
+        response = llm.invoke(prompt.format(
+            user_input=state["user_input"],
+            answer=state["answer"]
+        ))
+
+        # Extract questions from response
+        follow_ups = [q.strip() for q in response.content.split("\n") if q.strip()]
+        state["follow_up_questions"] = follow_ups
+
+        return state
+
+    return generate_follow_up
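make_follow_up_node only needs an object whose invoke method returns something with a .content string, so the node can be exercised without a real model. A sketch under that assumption (the stub below is hypothetical, not part of the commit):

from climateqa.engine.chains.follow_up import make_follow_up_node

class StubLLM:
    def invoke(self, prompt):
        # Pretend the model returned two newline-separated follow-ups
        class Response:
            content = "What regions are most affected?\nDetails the first point"
        return Response()

node = make_follow_up_node(StubLLM())
state = {"user_input": "What is climate change?", "answer": "Climate change refers to ..."}
state = node(state)
print(state["follow_up_questions"])
# ['What regions are most affected?', 'Details the first point']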
climateqa/engine/chains/intent_categorization.py CHANGED
@@ -1,4 +1,3 @@
-
 from langchain_core.pydantic_v1 import BaseModel, Field
 from typing import List
 from typing import Literal
@@ -44,7 +43,7 @@ def make_intent_categorization_chain(llm):
    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
 
    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze, translate and categorize the user input message using the function provided. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
+        ("system", "You are a helpful assistant, you will analyze, detect the language, and categorize the user input message using the function provided. You MUST detect and return the language of the input message. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
        ("user", "input: {input}")
    ])
 
@@ -58,11 +57,19 @@ def make_intent_categorization_node(llm):
 
    def categorize_message(state):
        print("---- Categorize_message ----")
+        print(f"Input state: {state}")
 
        output = categorization_chain.invoke({"input": state["user_input"]})
-        print(f"\n\nOutput intent categorization: {output}\n")
-        if "language" not in output: output["language"] = "English"
+        print(f"\n\nRaw output from categorization: {output}\n")
+
+        if "language" not in output:
+            print("WARNING: Language field missing from output, setting default to English")
+            output["language"] = "English"
+        else:
+            print(f"Language detected: {output['language']}")
+
        output["query"] = state["user_input"]
+        print(f"Final output: {output}")
        return output
 
    return categorize_message
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -621,10 +621,7 @@ def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
 
 def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
 
-    async def retrieve_POC_docs_node(state, config):
-        if "POC region" not in state["relevant_content_sources_selection"] :
-            return {}
-
+    async def retrieve_POC_docs_node(state, config):
        source_type = "POC"
        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
 
@@ -665,10 +662,7 @@ def make_POC_by_ToC_retriever_node(
        k_summary=5,
    ):
 
-    async def retrieve_POC_docs_node(state, config):
-        if "POC region" not in state["relevant_content_sources_selection"] :
-            return {}
-
+    async def retrieve_POC_docs_node(state, config):
        search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        search_only = state["search_only"]
        search_only = state["search_only"]
climateqa/engine/chains/standalone_question.py ADDED
@@ -0,0 +1,42 @@
+from langchain.prompts import ChatPromptTemplate
+
+def make_standalone_question_chain(llm):
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", """You are a helpful assistant that transforms user questions into standalone questions
+        by incorporating context from the chat history if needed. The output should be a self-contained
+        question that can be understood without any additional context.
+
+        Examples:
+        Chat History: "Let's talk about renewable energy"
+        User Input: "What about solar?"
+        Output: "What are the key aspects of solar energy as a renewable energy source?"
+
+        Chat History: "What causes global warming?"
+        User Input: "And what are its effects?"
+        Output: "What are the effects of global warming on the environment and society?"
+        """),
+        ("user", """Chat History: {chat_history}
+        User Question: {question}
+
+        Transform this into a standalone question:
+        Make sure to keep the original language of the question.""")
+    ])
+
+    chain = prompt | llm
+    return chain
+
+def make_standalone_question_node(llm):
+    standalone_chain = make_standalone_question_chain(llm)
+
+    def transform_to_standalone(state):
+        chat_history = state.get("chat_history", "")
+        if chat_history == "":
+            return {}
+        output = standalone_chain.invoke({
+            "chat_history": chat_history,
+            "question": state["user_input"]
+        })
+        state["user_input"] = output.content
+        return state
+
+    return transform_to_standalone
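Because the chain is built as prompt | llm, anything LCEL can coerce to a runnable (including a plain callable) can stand in for the model when trying the node out. A sketch under that assumption:

from climateqa.engine.chains.standalone_question import make_standalone_question_node

class StubResponse:
    content = "What are the effects of global warming on the environment and society?"

def stub_llm(prompt_value):
    # Hypothetical stand-in; any LangChain chat model would slot in here
    return StubResponse()

node = make_standalone_question_node(stub_llm)

# With history, user_input is rewritten in place into a standalone question
state = {
    "user_input": "And what are its effects?",
    "chat_history": [{"question": "What causes global warming?", "answer": "..."}],
}
print(node(state)["user_input"])

# An empty chat_history short-circuits and leaves the state untouched
print(node({"user_input": "Hello", "chat_history": ""}))  # {}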
climateqa/engine/graph.py CHANGED
@@ -23,13 +23,15 @@ from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriev
23
  from .chains.answer_rag import make_rag_node
24
  from .chains.graph_retriever import make_graph_retriever_node
25
  from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
26
- # from .chains.set_defaults import set_defaults
 
27
 
28
  class GraphState(TypedDict):
29
  """
30
  Represents the state of our graph.
31
  """
32
  user_input : str
 
33
  language : str
34
  intent : str
35
  search_graphs_chitchat : bool
@@ -49,6 +51,7 @@ class GraphState(TypedDict):
49
  recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
50
  search_only : bool = False
51
  reports : List[str] = []
 
52
 
53
  def dummy(state):
54
  return
@@ -100,15 +103,6 @@ def route_continue_retrieve_documents(state):
100
  else:
101
  return "retrieve_documents"
102
 
103
- def route_continue_retrieve_local_documents(state):
104
- index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
105
- questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
106
- # if questions_poc_finished and state["search_only"]:
107
- # return END
108
- if questions_poc_finished or ("POC region" not in state["relevant_content_sources_selection"]):
109
- return "end_retrieve_local_documents"
110
- else:
111
- return "retrieve_local_data"
112
 
113
  def route_retrieve_documents(state):
114
  sources_to_retrieve = []
@@ -120,6 +114,11 @@ def route_retrieve_documents(state):
120
  return END
121
  return sources_to_retrieve
122
 
 
 
 
 
 
123
  def make_id_dict(values):
124
  return {k:k for k in values}
125
 
@@ -128,6 +127,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
128
  workflow = StateGraph(GraphState)
129
 
130
  # Define the node functions
 
131
  categorize_intent = make_intent_categorization_node(llm)
132
  transform_query = make_query_transform_node(llm)
133
  translate_query = make_translation_node(llm)
@@ -139,9 +139,11 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
139
  answer_rag = make_rag_node(llm, with_docs=True)
140
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
141
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
 
142
 
143
  # Define the nodes
144
  # workflow.add_node("set_defaults", set_defaults)
 
145
  workflow.add_node("categorize_intent", categorize_intent)
146
  workflow.add_node("answer_climate", dummy)
147
  workflow.add_node("answer_search", answer_search)
@@ -155,9 +157,11 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
155
  workflow.add_node("retrieve_documents", retrieve_documents)
156
  workflow.add_node("answer_rag", answer_rag)
157
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
 
 
158
 
159
  # Entry point
160
- workflow.set_entry_point("categorize_intent")
161
 
162
  # CONDITIONAL EDGES
163
  workflow.add_conditional_edges(
@@ -189,20 +193,29 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
189
  make_id_dict(["retrieve_graphs", END])
190
  )
191
 
 
 
 
 
 
 
192
  # Define the edges
 
193
  workflow.add_edge("translate_query", "transform_query")
194
  workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
195
  # workflow.add_edge("transform_query", "retrieve_local_data")
196
  # workflow.add_edge("transform_query", END) # TODO remove
197
 
198
  workflow.add_edge("retrieve_graphs", END)
199
- workflow.add_edge("answer_rag", END)
200
- workflow.add_edge("answer_rag_no_docs", END)
201
  workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
202
  workflow.add_edge("retrieve_graphs_chitchat", END)
203
 
204
  # workflow.add_edge("retrieve_local_data", "answer_search")
205
  workflow.add_edge("retrieve_documents", "answer_search")
 
 
206
 
207
  # Compile
208
  app = workflow.compile()
@@ -228,6 +241,8 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
228
  workflow = StateGraph(GraphState)
229
 
230
  # Define the node functions
 
 
231
  categorize_intent = make_intent_categorization_node(llm)
232
  transform_query = make_query_transform_node(llm)
233
  translate_query = make_translation_node(llm)
@@ -240,9 +255,11 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
240
  answer_rag = make_rag_node(llm, with_docs=True)
241
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
242
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
 
243
 
244
  # Define the nodes
245
  # workflow.add_node("set_defaults", set_defaults)
 
246
  workflow.add_node("categorize_intent", categorize_intent)
247
  workflow.add_node("answer_climate", dummy)
248
  workflow.add_node("answer_search", answer_search)
@@ -258,9 +275,10 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
258
  workflow.add_node("retrieve_documents", retrieve_documents)
259
  workflow.add_node("answer_rag", answer_rag)
260
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
 
261
 
262
  # Entry point
263
- workflow.set_entry_point("categorize_intent")
264
 
265
  # CONDITIONAL EDGES
266
  workflow.add_conditional_edges(
@@ -293,22 +311,21 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
293
  )
294
 
295
  # Define the edges
 
296
  workflow.add_edge("translate_query", "transform_query")
297
  workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
298
  workflow.add_edge("transform_query", "retrieve_local_data")
299
  # workflow.add_edge("transform_query", END) # TODO remove
300
 
301
  workflow.add_edge("retrieve_graphs", END)
302
- workflow.add_edge("answer_rag", END)
303
- workflow.add_edge("answer_rag_no_docs", END)
304
  workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
305
  workflow.add_edge("retrieve_graphs_chitchat", END)
306
 
307
  workflow.add_edge("retrieve_local_data", "answer_search")
308
  workflow.add_edge("retrieve_documents", "answer_search")
309
-
310
- # workflow.add_edge("transform_query", "retrieve_drias_data")
311
- # workflow.add_edge("retrieve_drias_data", END)
312
 
313
 
314
  # Compile
 
23
  from .chains.answer_rag import make_rag_node
24
  from .chains.graph_retriever import make_graph_retriever_node
25
  from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
26
+ from .chains.standalone_question import make_standalone_question_node
27
+ from .chains.follow_up import make_follow_up_node # Add this import
28
 
29
  class GraphState(TypedDict):
30
  """
31
  Represents the state of our graph.
32
  """
33
  user_input : str
34
+ chat_history : str
35
  language : str
36
  intent : str
37
  search_graphs_chitchat : bool
 
51
  recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
52
  search_only : bool = False
53
  reports : List[str] = []
54
+ follow_up_questions: List[str] = []
55
 
56
  def dummy(state):
57
  return
 
103
  else:
104
  return "retrieve_documents"
105
 
 
 
 
 
 
 
 
 
 
106
 
107
  def route_retrieve_documents(state):
108
  sources_to_retrieve = []
 
114
  return END
115
  return sources_to_retrieve
116
 
117
+ def route_follow_up(state):
118
+ if state["follow_up_questions"]:
119
+ return "process_follow_up"
120
+ return END
121
+
122
  def make_id_dict(values):
123
  return {k:k for k in values}
124
 
 
127
  workflow = StateGraph(GraphState)
128
 
129
  # Define the node functions
130
+ standalone_question_node = make_standalone_question_node(llm)
131
  categorize_intent = make_intent_categorization_node(llm)
132
  transform_query = make_query_transform_node(llm)
133
  translate_query = make_translation_node(llm)
 
139
  answer_rag = make_rag_node(llm, with_docs=True)
140
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
141
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
142
+ generate_follow_up = make_follow_up_node(llm)
143
 
144
  # Define the nodes
145
  # workflow.add_node("set_defaults", set_defaults)
146
+ workflow.add_node("standalone_question", standalone_question_node)
147
  workflow.add_node("categorize_intent", categorize_intent)
148
  workflow.add_node("answer_climate", dummy)
149
  workflow.add_node("answer_search", answer_search)
 
157
  workflow.add_node("retrieve_documents", retrieve_documents)
158
  workflow.add_node("answer_rag", answer_rag)
159
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
160
+ workflow.add_node("generate_follow_up", generate_follow_up)
161
+ # workflow.add_node("process_follow_up", standalone_question_node)
162
 
163
  # Entry point
164
+ workflow.set_entry_point("standalone_question")
165
 
166
  # CONDITIONAL EDGES
167
  workflow.add_conditional_edges(
 
193
  make_id_dict(["retrieve_graphs", END])
194
  )
195
 
196
+ # workflow.add_conditional_edges(
197
+ # "generate_follow_up",
198
+ # route_follow_up,
199
+ # make_id_dict(["process_follow_up", END])
200
+ # )
201
+
202
  # Define the edges
203
+ workflow.add_edge("standalone_question", "categorize_intent")
204
  workflow.add_edge("translate_query", "transform_query")
205
  workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
206
  # workflow.add_edge("transform_query", "retrieve_local_data")
207
  # workflow.add_edge("transform_query", END) # TODO remove
208
 
209
  workflow.add_edge("retrieve_graphs", END)
210
+ workflow.add_edge("answer_rag", "generate_follow_up")
211
+ workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
212
  workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
213
  workflow.add_edge("retrieve_graphs_chitchat", END)
214
 
215
  # workflow.add_edge("retrieve_local_data", "answer_search")
216
  workflow.add_edge("retrieve_documents", "answer_search")
217
+ workflow.add_edge("generate_follow_up",END)
218
+ # workflow.add_edge("process_follow_up", "categorize_intent")
219
 
220
  # Compile
221
  app = workflow.compile()
 
      workflow = StateGraph(GraphState)

      # Define the node functions
+     standalone_question_node = make_standalone_question_node(llm)
+
      categorize_intent = make_intent_categorization_node(llm)
      transform_query = make_query_transform_node(llm)
      translate_query = make_translation_node(llm)

      answer_rag = make_rag_node(llm, with_docs=True)
      answer_rag_no_docs = make_rag_node(llm, with_docs=False)
      chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
+     generate_follow_up = make_follow_up_node(llm)

      # Define the nodes
      # workflow.add_node("set_defaults", set_defaults)
+     workflow.add_node("standalone_question", standalone_question_node)
      workflow.add_node("categorize_intent", categorize_intent)
      workflow.add_node("answer_climate", dummy)
      workflow.add_node("answer_search", answer_search)

      workflow.add_node("retrieve_documents", retrieve_documents)
      workflow.add_node("answer_rag", answer_rag)
      workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
+     workflow.add_node("generate_follow_up", generate_follow_up)

      # Entry point
+     workflow.set_entry_point("standalone_question")

      # CONDITIONAL EDGES
      workflow.add_conditional_edges(

      )

      # Define the edges
+     workflow.add_edge("standalone_question", "categorize_intent")
      workflow.add_edge("translate_query", "transform_query")
      workflow.add_edge("transform_query", "retrieve_documents")  # TODO put back
      workflow.add_edge("transform_query", "retrieve_local_data")
      # workflow.add_edge("transform_query", END)  # TODO remove

      workflow.add_edge("retrieve_graphs", END)
+     workflow.add_edge("answer_rag", "generate_follow_up")
+     workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
      workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
      workflow.add_edge("retrieve_graphs_chitchat", END)

      workflow.add_edge("retrieve_local_data", "answer_search")
      workflow.add_edge("retrieve_documents", "answer_search")
+     workflow.add_edge("generate_follow_up", END)


      # Compile
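The new `generate_follow_up` node is produced by `make_follow_up_node(llm)`, whose definition is not part of this diff. A minimal sketch of what such a node factory could look like — the state keys (`query`, `answer`, `follow_up_questions`) are assumptions for illustration, not the repo's actual `GraphState` schema:

```python
# Hypothetical sketch of a follow-up node factory; the real
# make_follow_up_node lives elsewhere in the repo and may differ.
def make_follow_up_node(llm):
    def generate_follow_up(state: dict) -> dict:
        # Build a prompt from the user's question and the generated answer
        prompt = (
            "Based on the question and answer below, suggest 2 or 3 short "
            "follow-up questions the user might ask next.\n"
            f"Question: {state['query']}\n"
            f"Answer: {state.get('answer', '')}"
        )
        response = llm.invoke(prompt)
        # Store the suggestions so the UI can display them as clickable examples
        return {**state, "follow_up_questions": response.content}

    return generate_follow_up
```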
climateqa/engine/talk_to_data/config.py ADDED
@@ -0,0 +1,99 @@
+ DRIAS_TABLES = [
+     "total_winter_precipitation",
+     "total_summer_precipiation",  # (sic) spelling matches the source table name
+     "total_annual_precipitation",
+     "total_remarkable_daily_precipitation",
+     "frequency_of_remarkable_daily_precipitation",
+     "extreme_precipitation_intensity",
+     "mean_winter_temperature",
+     "mean_summer_temperature",
+     "mean_annual_temperature",
+     "number_of_tropical_nights",
+     "maximum_summer_temperature",
+     "number_of_days_with_tx_above_30",
+     "number_of_days_with_tx_above_35",
+     "number_of_days_with_a_dry_ground",
+ ]
+
+ # Mapping between table names and their indicator columns
+ INDICATOR_COLUMNS_PER_TABLE = {
+     "total_winter_precipitation": "total_winter_precipitation",
+     "total_summer_precipiation": "total_summer_precipitation",
+     "total_annual_precipitation": "total_annual_precipitation",
+     "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
+     "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
+     "extreme_precipitation_intensity": "extreme_precipitation_intensity",
+     "mean_winter_temperature": "mean_winter_temperature",
+     "mean_summer_temperature": "mean_summer_temperature",
+     "mean_annual_temperature": "mean_annual_temperature",
+     "number_of_tropical_nights": "number_tropical_nights",
+     "maximum_summer_temperature": "maximum_summer_temperature",
+     "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
+     "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
+     "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
+ }
+
+ DRIAS_MODELS = [
+     'ALL',
+     'RegCM4-6_MPI-ESM-LR',
+     'RACMO22E_EC-EARTH',
+     'RegCM4-6_HadGEM2-ES',
+     'HadREM3-GA7_EC-EARTH',
+     'HadREM3-GA7_CNRM-CM5',
+     'REMO2015_NorESM1-M',
+     'SMHI-RCA4_EC-EARTH',
+     'WRF381P_NorESM1-M',
+     'ALADIN63_CNRM-CM5',
+     'CCLM4-8-17_MPI-ESM-LR',
+     'HIRHAM5_IPSL-CM5A-MR',
+     'HadREM3-GA7_HadGEM2-ES',
+     'SMHI-RCA4_IPSL-CM5A-MR',
+     'HIRHAM5_NorESM1-M',
+     'REMO2009_MPI-ESM-LR',
+     'CCLM4-8-17_HadGEM2-ES'
+ ]
+ # Mapping between indicator columns and their units
+ INDICATOR_TO_UNIT = {
+     "total_winter_precipitation": "mm",
+     "total_summer_precipitation": "mm",
+     "total_annual_precipitation": "mm",
+     "total_remarkable_daily_precipitation": "mm",
+     "frequency_of_remarkable_daily_precipitation": "days",
+     "extreme_precipitation_intensity": "mm",
+     "mean_winter_temperature": "°C",
+     "mean_summer_temperature": "°C",
+     "mean_annual_temperature": "°C",
+     "number_tropical_nights": "days",
+     "maximum_summer_temperature": "°C",
+     "number_of_days_with_tx_above_30": "days",
+     "number_of_days_with_tx_above_35": "days",
+     "number_of_days_with_dry_ground": "days"
+ }
+
+ DRIAS_UI_TEXT = """
+ Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
+ I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
+
+ ❓ **How to use?**
+ You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
+ You can specify **location** and/or **year**.
+ You can choose from a list of climate models. By default, we take the **average across all models**.
+
+ For example, you can ask:
+ - What will the temperature be like in Paris?
+ - What will be the total rainfall in France in 2030?
+ - How frequent will extreme events be in Lyon?
+
+ **Examples of indicators in the data**:
+ - Mean temperature (annual, winter, summer)
+ - Total precipitation (annual, winter, summer)
+ - Number of days with remarkable precipitation, with dry ground, or with temperatures above 30°C
+
+ ⚠️ **Limitations**:
+ - You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
+ - You can only ask about **locations in France**.
+ - If you specify a year, there may be **no data for that year for some models**.
+ - You **cannot compare two models**.
+
+ 🛈 **Information**
+ Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
+ """
climateqa/engine/talk_to_data/main.py CHANGED
@@ -1,47 +1,115 @@
- from climateqa.engine.talk_to_data.myVanna import MyVanna
- from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates
- import sqlite3
- import os
- import pandas as pd
  from climateqa.engine.llm import get_llm
  import ast

-
-
  llm = get_llm(provider="openai")

- def ask_llm_to_add_table_names(sql_query, llm):
      sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
      return sql_with_table_names

- def ask_llm_column_names(sql_query, llm):
      columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
      columns_list = ast.literal_eval(columns.strip("```python\n").strip())
      return columns_list

- def ask_vanna(vn, db_vanna_path, query):

-     try:
-         location = detect_location_with_openai(query)
-         if location:

-             coords = loc2coords(location)
-             user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")

-             relevant_tables = detect_relevant_tables(user_input, llm)
-             coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
-             user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
-
-             sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
-
-             return sql_query, result_dataframe, figure
-
-         else:
-             empty_df = pd.DataFrame()
-             empty_fig = None
-             return "", empty_df, empty_fig
-     except Exception as e:
-         print(f"Error: {e}")
-         empty_df = pd.DataFrame()
-         empty_fig = None
-         return "", empty_df, empty_fig

+ from climateqa.engine.talk_to_data.workflow import drias_workflow
  from climateqa.engine.llm import get_llm
  import ast

  llm = get_llm(provider="openai")

+ def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
+     """Adds table names to the SQL query result rows using LLM.
+
+     This function modifies the SQL query to include the source table name in each row
+     of the result set, making it easier to track which data comes from which table.
+
+     Args:
+         sql_query (str): The original SQL query to modify
+         llm: The language model instance to use for generating the modified query
+
+     Returns:
+         str: The modified SQL query with table names included in the result rows
+     """
      sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
      return sql_with_table_names

+ def ask_llm_column_names(sql_query: str, llm) -> list[str]:
+     """Extracts column names from a SQL query using LLM.
+
+     This function analyzes a SQL query to identify which columns are being selected
+     in the result set.
+
+     Args:
+         sql_query (str): The SQL query to analyze
+         llm: The language model instance to use for column extraction
+
+     Returns:
+         list[str]: A list of column names being selected in the query
+     """
      columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
      columns_list = ast.literal_eval(columns.strip("```python\n").strip())
      return columns_list

+ async def ask_drias(query: str, index_state: int = 0) -> tuple:
+     """Main function to process a DRIAS query and return results.
+
+     This function orchestrates the DRIAS workflow, processing a user query to generate
+     SQL queries, dataframes, and visualizations. It handles multiple results and allows
+     pagination through them.
+
+     Args:
+         query (str): The user's question about climate data
+         index_state (int, optional): The index of the result to return. Defaults to 0.
+
+     Returns:
+         tuple: A tuple containing:
+             - sql_query (str): The SQL query used
+             - dataframe (pd.DataFrame): The resulting data
+             - figure (Callable): Function to generate the visualization
+             - sql_queries (list): All generated SQL queries
+             - result_dataframes (list): All resulting dataframes
+             - figures (list): All figure generation functions
+             - index_state (int): Current result index
+             - table_list (list): List of table names used
+             - error (str): Error message if any
+     """
+     final_state = await drias_workflow(query)
+     sql_queries = []
+     result_dataframes = []
+     figures = []
+     table_list = []
+
+     for plot_state in final_state['plot_states'].values():
+         for table_state in plot_state['table_states'].values():
+             if table_state['status'] == 'OK':
+                 if 'table_name' in table_state:
+                     table_list.append(' '.join(table_state['table_name'].capitalize().split('_')))
+                 if 'sql_query' in table_state and table_state['sql_query'] is not None:
+                     sql_queries.append(table_state['sql_query'])
+
+                 if 'dataframe' in table_state and table_state['dataframe'] is not None:
+                     result_dataframes.append(table_state['dataframe'])
+                 if 'figure' in table_state and table_state['figure'] is not None:
+                     figures.append(table_state['figure'])
+
+     if "error" in final_state and final_state["error"] != "":
+         # Keep the same arity as the success path (9 values)
+         return None, None, None, [], [], [], 0, [], final_state["error"]
+
+     sql_query = sql_queries[index_state]
+     dataframe = result_dataframes[index_state]
+     figure = figures[index_state](dataframe)

+     return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""


+ # def ask_vanna(vn, db_vanna_path, query):
+
+ #     try:
+ #         location = detect_location_with_openai(query)
+ #         if location:
+
+ #             coords = loc2coords(location)
+ #             user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")

+ #             relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
+ #             coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
+ #             user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
+
+ #             sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
+
+ #             return sql_query, result_dataframe, figure
+ #         else:
+ #             empty_df = pd.DataFrame()
+ #             empty_fig = None
+ #             return "", empty_df, empty_fig
+ #     except Exception as e:
+ #         print(f"Error: {e}")
+ #         empty_df = pd.DataFrame()
+ #         empty_fig = None
+ #         return "", empty_df, empty_fig
climateqa/engine/talk_to_data/plot.py ADDED
@@ -0,0 +1,402 @@
+ from typing import Callable, TypedDict
+ import pandas as pd
+ from plotly.graph_objects import Figure
+ import plotly.graph_objects as go
+ import plotly.express as px
+
+ from climateqa.engine.talk_to_data.sql_query import (
+     indicator_for_given_year_query,
+     indicator_per_year_at_location_query,
+ )
+ from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
+
+
+ class Plot(TypedDict):
+     """Represents a plot configuration in the DRIAS system.
+
+     This class defines the structure for configuring different types of plots
+     that can be generated from climate data.
+
+     Attributes:
+         name (str): The name of the plot type
+         description (str): A description of what the plot shows
+         params (list[str]): List of required parameters for the plot
+         plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
+         sql_query (Callable[..., str]): Function to generate the SQL query for the plot
+     """
+     name: str
+     description: str
+     params: list[str]
+     plot_function: Callable[..., Callable[..., Figure]]
+     sql_query: Callable[..., str]
+
+
+ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
+     """Generates a function to plot indicator evolution over time at a location.
+
+     This function creates a line plot showing how a climate indicator changes
+     over time at a specific location. It handles temperature, precipitation,
+     and other climate indicators.
+
+     Args:
+         params (dict): Dictionary containing:
+             - indicator_column (str): The column name for the indicator
+             - location (str): The location to plot
+             - model (str): The climate model to use
+
+     Returns:
+         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+
+     Example:
+         >>> plot_func = plot_indicator_evolution_at_location({
+         ...     'indicator_column': 'mean_temperature',
+         ...     'location': 'Paris',
+         ...     'model': 'ALL'
+         ... })
+         >>> fig = plot_func(df)
+     """
+     indicator = params["indicator_column"]
+     location = params["location"]
+     indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
+     unit = INDICATOR_TO_UNIT.get(indicator, "")
+
+     def plot_data(df: pd.DataFrame) -> Figure:
+         """Generates the actual plot from the data.
+
+         Args:
+             df (pd.DataFrame): DataFrame containing the data to plot
+
+         Returns:
+             Figure: A plotly Figure object showing the indicator evolution
+         """
+         fig = go.Figure()
+         if df['model'].nunique() != 1:
+             df_avg = df.groupby("year", as_index=False)[indicator].mean()
+
+             # Transform to list to avoid pandas encoding
+             indicators = df_avg[indicator].astype(float).tolist()
+             years = df_avg["year"].astype(int).tolist()
+
+             # Compute the 10-year rolling average
+             sliding_averages = (
+                 df_avg[indicator]
+                 .rolling(window=10, min_periods=1)
+                 .mean()
+                 .astype(float)
+                 .tolist()
+             )
+             model_label = "Model Average"
+
+         else:
+             df_model = df
+
+             # Transform to list to avoid pandas encoding
+             indicators = df_model[indicator].astype(float).tolist()
+             years = df_model["year"].astype(int).tolist()
+
+             # Compute the 10-year rolling average
+             sliding_averages = (
+                 df_model[indicator]
+                 .rolling(window=10, min_periods=1)
+                 .mean()
+                 .astype(float)
+                 .tolist()
+             )
+             model_label = f"Model : {df['model'].unique()[0]}"
+
+         # Indicator per year plot
+         fig.add_scatter(
+             x=years,
+             y=indicators,
+             name=f"Yearly {indicator_label}",
+             mode="lines",
+             marker=dict(color="#1f77b4"),
+             hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
+         )
+
+         # Sliding average dashed line
+         fig.add_scatter(
+             x=years,
+             y=sliding_averages,
+             mode="lines",
+             name="10 years rolling average",
+             line=dict(dash="dash"),
+             marker=dict(color="#d62728"),
+             hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
+         )
+         fig.update_layout(
+             title=f"Plot of {indicator_label} in {location} ({model_label})",
+             xaxis_title="Year",
+             yaxis_title=f"{indicator_label} ({unit})",
+             template="plotly_white",
+         )
+         return fig
+
+     return plot_data
+
+
+ indicator_evolution_at_location: Plot = {
+     "name": "Indicator evolution at location",
+     "description": "Plot the evolution of an indicator at a certain location",
+     "params": ["indicator_column", "location", "model"],
+     "plot_function": plot_indicator_evolution_at_location,
+     "sql_query": indicator_per_year_at_location_query,
+ }
+
+
+ def plot_indicator_number_of_days_per_year_at_location(
+     params: dict,
+ ) -> Callable[..., Figure]:
+     """Generates a function to plot the number of days per year for an indicator.
+
+     This function creates a bar chart showing the frequency of certain climate
+     events (like days above a temperature threshold) per year at a specific location.
+
+     Args:
+         params (dict): Dictionary containing:
+             - indicator_column (str): The column name for the indicator
+             - location (str): The location to plot
+             - model (str): The climate model to use
+
+     Returns:
+         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+     """
+     indicator = params["indicator_column"]
+     location = params["location"]
+     indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
+     unit = INDICATOR_TO_UNIT.get(indicator, "")
+
+     def plot_data(df: pd.DataFrame) -> Figure:
+         """Generates the figure from the dataframe.
+
+         Args:
+             df (pd.DataFrame): pandas dataframe with the required data
+
+         Returns:
+             Figure: Plotly figure
+         """
+         fig = go.Figure()
+         if df['model'].nunique() != 1:
+             df_avg = df.groupby("year", as_index=False)[indicator].mean()
+
+             # Transform to list to avoid pandas encoding
+             indicators = df_avg[indicator].astype(float).tolist()
+             years = df_avg["year"].astype(int).tolist()
+             model_label = "Model Average"
+
+         else:
+             df_model = df
+             # Transform to list to avoid pandas encoding
+             indicators = df_model[indicator].astype(float).tolist()
+             years = df_model["year"].astype(int).tolist()
+             model_label = f"Model : {df['model'].unique()[0]}"
+
+         # Bar plot
+         fig.add_trace(
+             go.Bar(
+                 x=years,
+                 y=indicators,
+                 width=0.5,
+                 marker=dict(color="#1f77b4"),
+                 hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
+             )
+         )
+
+         fig.update_layout(
+             title=f"{indicator_label} in {location} ({model_label})",
+             xaxis_title="Year",
+             yaxis_title=f"{indicator_label} ({unit})",
+             yaxis=dict(range=[0, max(indicators)]),
+             bargap=0.5,
+             template="plotly_white",
+         )
+
+         return fig
+
+     return plot_data
+
+
+ indicator_number_of_days_per_year_at_location: Plot = {
+     "name": "Indicator number of days per year at location",
+     "description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicators.",
+     "params": ["indicator_column", "location", "model"],
+     "plot_function": plot_indicator_number_of_days_per_year_at_location,
+     "sql_query": indicator_per_year_at_location_query,
+ }
+
+
+ def plot_distribution_of_indicator_for_given_year(
+     params: dict,
+ ) -> Callable[..., Figure]:
+     """Generates a function to plot the distribution of an indicator for a year.
+
+     This function creates a histogram showing the distribution of a climate
+     indicator across different locations for a specific year.
+
+     Args:
+         params (dict): Dictionary containing:
+             - indicator_column (str): The column name for the indicator
+             - year (str): The year to plot
+             - model (str): The climate model to use
+
+     Returns:
+         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+     """
+     indicator = params["indicator_column"]
+     year = params["year"]
+     indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
+     unit = INDICATOR_TO_UNIT.get(indicator, "")
+
+     def plot_data(df: pd.DataFrame) -> Figure:
+         """Generates the figure from the dataframe.
+
+         Args:
+             df (pd.DataFrame): pandas dataframe with the required data
+
+         Returns:
+             Figure: Plotly figure
+         """
+         fig = go.Figure()
+         if df['model'].nunique() != 1:
+             df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
+                 indicator
+             ].mean()
+
+             # Transform to list to avoid pandas encoding
+             indicators = df_avg[indicator].astype(float).tolist()
+             model_label = "Model Average"
+
+         else:
+             df_model = df
+
+             # Transform to list to avoid pandas encoding
+             indicators = df_model[indicator].astype(float).tolist()
+             model_label = f"Model : {df['model'].unique()[0]}"
+
+         fig.add_trace(
+             go.Histogram(
+                 x=indicators,
+                 opacity=0.8,
+                 histnorm="percent",
+                 marker=dict(color="#1f77b4"),
+                 hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
+             )
+         )
+
+         fig.update_layout(
+             title=f"Distribution of {indicator_label} in {year} ({model_label})",
+             xaxis_title=f"{indicator_label} ({unit})",
+             yaxis_title="Frequency (%)",
+             plot_bgcolor="rgba(0, 0, 0, 0)",
+             showlegend=False,
+         )
+
+         return fig
+
+     return plot_data
+
+
+ distribution_of_indicator_for_given_year: Plot = {
+     "name": "Distribution of an indicator for a given year",
+     "description": "Plot a histogram of the distribution of the values of an indicator for a given year",
+     "params": ["indicator_column", "model", "year"],
+     "plot_function": plot_distribution_of_indicator_for_given_year,
+     "sql_query": indicator_for_given_year_query,
+ }
+
+
+ def plot_map_of_france_of_indicator_for_given_year(
+     params: dict,
+ ) -> Callable[..., Figure]:
+     """Generates a function to plot a map of France for an indicator.
+
+     This function creates a choropleth map of France showing the spatial
+     distribution of a climate indicator for a specific year.
+
+     Args:
+         params (dict): Dictionary containing:
+             - indicator_column (str): The column name for the indicator
+             - year (str): The year to plot
+             - model (str): The climate model to use
+
+     Returns:
+         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+     """
+     indicator = params["indicator_column"]
+     year = params["year"]
+     indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
+     unit = INDICATOR_TO_UNIT.get(indicator, "")
+
+     def plot_data(df: pd.DataFrame) -> Figure:
+         fig = go.Figure()
+         if df['model'].nunique() != 1:
+             df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
+                 indicator
+             ].mean()
+
+             indicators = df_avg[indicator].astype(float).tolist()
+             latitudes = df_avg["latitude"].astype(float).tolist()
+             longitudes = df_avg["longitude"].astype(float).tolist()
+             model_label = "Model Average"
+
+         else:
+             df_model = df
+
+             # Transform to list to avoid pandas encoding
+             indicators = df_model[indicator].astype(float).tolist()
+             latitudes = df_model["latitude"].astype(float).tolist()
+             longitudes = df_model["longitude"].astype(float).tolist()
+             model_label = f"Model : {df['model'].unique()[0]}"
+
+         fig.add_trace(
+             go.Scattermapbox(
+                 lat=latitudes,
+                 lon=longitudes,
+                 mode="markers",
+                 marker=dict(
+                     size=10,
+                     color=indicators,       # Color mapped to values
+                     colorscale="Turbo",     # Color scale (can be 'Plasma', 'Jet', etc.)
+                     cmin=min(indicators),   # Minimum color range
+                     cmax=max(indicators),   # Maximum color range
+                     showscale=True,         # Show colorbar
+                 ),
+                 text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Add hover text showing the indicator value
+                 hoverinfo="text"  # Only show the custom text on hover
+             )
+         )
+
+         fig.update_layout(
+             mapbox_style="open-street-map",  # Use OpenStreetMap
+             mapbox_zoom=3,
+             mapbox_center={"lat": 46.6, "lon": 2.0},
+             coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),  # Add legend
+             title=f"{indicator_label} in {year} in France ({model_label})"  # Title
+         )
+         return fig
+
+     return plot_data
+
+
+ map_of_france_of_indicator_for_given_year: Plot = {
+     "name": "Map of France of an indicator for a given year",
+     "description": "Heatmap on the map of France of the values of an indicator for a given year",
+     "params": ["indicator_column", "year", "model"],
+     "plot_function": plot_map_of_france_of_indicator_for_given_year,
+     "sql_query": indicator_for_given_year_query,
+ }
+
+
+ PLOTS = [
+     indicator_evolution_at_location,
+     indicator_number_of_days_per_year_at_location,
+     distribution_of_indicator_for_given_year,
+     map_of_france_of_indicator_for_given_year,
+ ]
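Each `Plot` entry bundles a SQL builder with a figure factory, so rendering a chart is a three-step lookup. A sketch using the registry above (the parameter values are illustrative, not real grid coordinates):

```python
import asyncio

from climateqa.engine.talk_to_data.plot import PLOTS
from climateqa.engine.talk_to_data.sql_query import execute_sql_query

async def render(plot_name: str, table: str, params: dict):
    plot = next(p for p in PLOTS if p["name"] == plot_name)
    sql = plot["sql_query"](table, params)      # build the query
    df = await execute_sql_query(sql)           # fetch the data
    return plot["plot_function"](params)(df)    # build the figure

fig = asyncio.run(render(
    "Indicator evolution at location",
    "mean_annual_temperature",
    {"indicator_column": "mean_annual_temperature",
     "latitude": "48.86", "longitude": "2.35",
     "location": "Paris", "model": "ALL"},
))
```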
climateqa/engine/talk_to_data/sql_query.py ADDED
@@ -0,0 +1,113 @@
+ import asyncio
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import TypedDict
+ import duckdb
+ import pandas as pd
+
+ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
+     """Executes a SQL query on the DRIAS database and returns the results.
+
+     This function connects to the DuckDB database containing DRIAS climate data
+     and executes the provided SQL query. It handles the database connection and
+     returns the results as a pandas DataFrame.
+
+     Args:
+         sql_query (str): The SQL query to execute
+
+     Returns:
+         pd.DataFrame: A DataFrame containing the query results
+
+     Raises:
+         duckdb.Error: If there is an error executing the SQL query
+     """
+     def _execute_query():
+         # Execute the query
+         results = duckdb.sql(sql_query)
+         # Return the fetched data
+         return results.fetchdf()
+
+     # Run the query in a thread pool to avoid blocking the event loop
+     loop = asyncio.get_running_loop()
+     with ThreadPoolExecutor() as executor:
+         return await loop.run_in_executor(executor, _execute_query)
+
+
+ class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
+     """Parameters for querying an indicator's values over time at a location.
+
+     This class defines the parameters needed to query climate indicator data
+     for a specific location over multiple years.
+
+     Attributes:
+         indicator_column (str): The column name for the climate indicator
+         latitude (str): The latitude coordinate of the location
+         longitude (str): The longitude coordinate of the location
+         model (str): The climate model to use (optional)
+     """
+     indicator_column: str
+     latitude: str
+     longitude: str
+     model: str
+
+
+ def indicator_per_year_at_location_query(
+     table: str, params: IndicatorPerYearAtLocationQueryParams
+ ) -> str:
+     """SQL query to get the evolution of an indicator per year at a certain location.
+
+     Args:
+         table (str): sql table of the indicator
+         params (IndicatorPerYearAtLocationQueryParams): dictionary with the required params for the query
+
+     Returns:
+         str: the sql query
+     """
+     indicator_column = params.get("indicator_column")
+     latitude = params.get("latitude")
+     longitude = params.get("longitude")
+
+     if indicator_column is None or latitude is None or longitude is None:  # If one parameter is missing, return an empty query
+         return ""
+
+     table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
+
+     sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude}\nAND longitude = {longitude}\nORDER BY year"
+
+     return sql_query
+
+ class IndicatorForGivenYearQueryParams(TypedDict, total=False):
+     """Parameters for querying an indicator's values across locations for a year.
+
+     This class defines the parameters needed to query climate indicator data
+     across different locations for a specific year.
+
+     Attributes:
+         indicator_column (str): The column name for the climate indicator
+         year (str): The year to query
+         model (str): The climate model to use (optional)
+     """
+     indicator_column: str
+     year: str
+     model: str
+
+ def indicator_for_given_year_query(
+     table: str, params: IndicatorForGivenYearQueryParams
+ ) -> str:
+     """SQL query to get the values of an indicator with their latitudes, longitudes and models for a given year.
+
+     Args:
+         table (str): sql table of the indicator
+         params (IndicatorForGivenYearQueryParams): dictionary with the required params for the query
+
+     Returns:
+         str: the sql query
+     """
+     indicator_column = params.get("indicator_column")
+     year = params.get('year')
+     if year is None or indicator_column is None:  # If one parameter is missing, return an empty query
+         return ""
+
+     table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
+
+     sql_query = f"SELECT {indicator_column}, latitude, longitude, model\nFROM {table}\nWHERE year = {year}"
+     return sql_query
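For reference, these builders produce plain DuckDB SQL against the Hugging Face-hosted parquet files. For example (coordinates illustrative):

```python
from climateqa.engine.talk_to_data.sql_query import indicator_per_year_at_location_query

print(indicator_per_year_at_location_query(
    "mean_annual_temperature",
    {"indicator_column": "mean_annual_temperature",
     "latitude": "48.86", "longitude": "2.35"},
))
# SELECT year, mean_annual_temperature, model
# FROM 'hf://datasets/timeki/drias_db/mean_annual_temperature.parquet'
# WHERE latitude = 48.86
# AND longitude = 2.35
# ORDER BY year
```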
climateqa/engine/talk_to_data/utils.py CHANGED
@@ -1,12 +1,15 @@
  import re
- import openai
- import pandas as pd
  from geopy.geocoders import Nominatim
- import sqlite3
  import ast
  from climateqa.engine.llm import get_llm

- def detect_location_with_openai(sentence):
      """
      Detects locations in a sentence using OpenAI's API via LangChain.
      """
@@ -19,74 +22,260 @@ def detect_location_with_openai(sentence):
      Sentence: "{sentence}"
      """

-     response = llm.invoke(prompt)
      location_list = ast.literal_eval(response.content.strip("```python\n").strip())
      if location_list:
          return location_list[0]
      else:
          return ""

- def detectTable(sql_query):
      pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
      matches = re.findall(pattern, sql_query)
      return matches


- def loc2coords(location : str):
      geolocator = Nominatim(user_agent="city_to_latlong")
-     location = geolocator.geocode(location)
-     return (location.latitude, location.longitude)


- def coords2loc(coords : tuple):
      geolocator = Nominatim(user_agent="coords_to_city")
      try:
          location = geolocator.reverse(coords)
          return location.address
      except Exception as e:
          print(f"Error: {e}")
-         return "Unknown Location"


- def nearestNeighbourSQL(db: str, location: tuple, table : str):
-     conn = sqlite3.connect(db)
      long = round(location[1], 3)
      lat = round(location[0], 3)
-     cursor = conn.cursor()
-     cursor.execute(f"SELECT lat, lon FROM {table} WHERE lat BETWEEN {lat - 0.3} AND {lat + 0.3} AND lon BETWEEN {long - 0.3} AND {long + 0.3}")
-     results = cursor.fetchall()
-     return results[0]
-
- def detect_relevant_tables(user_question, llm):
-     table_names_list = [
-         "Frequency_of_rainy_days_index",
-         "Winter_precipitation_total",
-         "Summer_precipitation_total",
-         "Annual_precipitation_total",
-         # "Remarkable_daily_precipitation_total_(Q99)",
-         "Frequency_of_remarkable_daily_precipitation",
-         "Extreme_precipitation_intensity",
-         "Mean_winter_temperature",
-         "Mean_summer_temperature",
-         "Number_of_tropical_nights",
-         "Maximum_summer_temperature",
-         "Number_of_days_with_Tx_above_30C",
-         "Number_of_days_with_Tx_above_35C",
-         "Drought_index"
-     ]
      prompt = (
-         f"You are helping to build a sql query to retrieve relevant data for a user question."
-         f"The different tables are {table_names_list}."
-         f"The user question is {user_question}. Write the relevant tables to use. Answer only a python list of table name."
      )
-     table_names = ast.literal_eval(llm.invoke(prompt).content.strip("```python\n").strip())
      return table_names

  def replace_coordonates(coords, query, coords_tables):
      n = query.count(str(coords[0]))

      for i in range(n):
-         query = query.replace(str(coords[0]), str(coords_tables[i][0]),1)
-         query = query.replace(str(coords[1]), str(coords_tables[i][1]),1)
-     return query

  import re
+ from typing import Annotated, TypedDict
+ import duckdb
  from geopy.geocoders import Nominatim
  import ast
  from climateqa.engine.llm import get_llm
+ from climateqa.engine.talk_to_data.config import DRIAS_TABLES
+ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
+ from langchain_core.prompts import ChatPromptTemplate

+
+ async def detect_location_with_openai(sentence):
      """
      Detects locations in a sentence using OpenAI's API via LangChain.
      """

      Sentence: "{sentence}"
      """

+     response = await llm.ainvoke(prompt)
      location_list = ast.literal_eval(response.content.strip("```python\n").strip())
      if location_list:
          return location_list[0]
      else:
          return ""

+ class ArrayOutput(TypedDict):
+     """Represents the output of a function that returns an array.
+
+     This class is used to type-hint functions that return arrays,
+     ensuring consistent return types across the codebase.
+
+     Attributes:
+         array (str): A syntactically valid Python array string
+     """
+     array: Annotated[str, "Syntactically valid python array."]
+
+ async def detect_year_with_openai(sentence: str) -> str:
+     """
+     Detects years in a sentence using OpenAI's API via LangChain.
+     """
+     llm = get_llm()
+
+     prompt = """
+     Extract all years mentioned in the following sentence.
+     Return the result as a Python list. If no years are mentioned, return an empty list.
+
+     Sentence: "{sentence}"
+     """
+
+     prompt = ChatPromptTemplate.from_template(prompt)
+     structured_llm = llm.with_structured_output(ArrayOutput)
+     chain = prompt | structured_llm
+     response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
+     # literal_eval is safer than eval for parsing the model's list output
+     years_list = ast.literal_eval(response['array'])
+     if len(years_list) > 0:
+         return years_list[0]
+     else:
+         return ""
+
+
+ def detectTable(sql_query: str) -> list[str]:
+     """Extracts table names from a SQL query.
+
+     This function uses regular expressions to find all table names
+     referenced in a SQL query's FROM clause.
+
+     Args:
+         sql_query (str): The SQL query to analyze
+
+     Returns:
+         list[str]: A list of table names found in the query
+
+     Example:
+         >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
+         ['temperature_data']
+     """
      pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
      matches = re.findall(pattern, sql_query)
      return matches


+ def loc2coords(location: str) -> tuple[float, float]:
+     """Converts a location name to geographic coordinates.
+
+     This function uses the Nominatim geocoding service to convert
+     a location name (e.g., city name) to its latitude and longitude.
+
+     Args:
+         location (str): The name of the location to geocode
+
+     Returns:
+         tuple[float, float]: A tuple containing (latitude, longitude)
+
+     Raises:
+         AttributeError: If the location cannot be found
+     """
      geolocator = Nominatim(user_agent="city_to_latlong")
+     coords = geolocator.geocode(location)
+     return (coords.latitude, coords.longitude)


+ def coords2loc(coords: tuple[float, float]) -> str:
+     """Converts geographic coordinates to a location name.
+
+     This function uses the Nominatim reverse geocoding service to convert
+     latitude and longitude coordinates to a human-readable location name.
+
+     Args:
+         coords (tuple[float, float]): A tuple containing (latitude, longitude)
+
+     Returns:
+         str: The address of the location, or "Unknown Location" if not found
+
+     Example:
+         >>> coords2loc((48.8566, 2.3522))
+         'Paris, France'
+     """
      geolocator = Nominatim(user_agent="coords_to_city")
      try:
          location = geolocator.reverse(coords)
          return location.address
      except Exception as e:
          print(f"Error: {e}")
+         return "Unknown Location"


+ def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
+     """Finds the closest DRIAS grid point to the given (latitude, longitude)."""
      long = round(location[1], 3)
      lat = round(location[0], 3)
+
+     table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
+
+     results = duckdb.sql(
+         f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
+     ).fetchdf()
+
+     if len(results) == 0:
+         return "", ""
+     return results['latitude'].iloc[0], results['longitude'].iloc[0]
+
+
+ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
+     """Identifies relevant tables for a plot based on user input.
+
+     This function uses an LLM to analyze the user's question and the plot
+     description to determine which tables in the DRIAS database would be
+     most relevant for generating the requested visualization.
+
+     Args:
+         user_question (str): The user's question about climate data
+         plot (Plot): The plot configuration object
+         llm: The language model instance to use for analysis
+
+     Returns:
+         list[str]: A list of table names that are relevant for the plot
+
+     Example:
+         >>> await detect_relevant_tables(
+         ...     "What will the temperature be like in Paris?",
+         ...     indicator_evolution_at_location,
+         ...     llm
+         ... )
+         ['mean_annual_temperature', 'mean_summer_temperature']
+     """
+     # Get all table names
+     table_names_list = DRIAS_TABLES
+
      prompt = (
+         f"You are helping to build a plot following this description: {plot['description']}. "
+         f"You are given a list of tables and a user question. "
+         f"Based on the description of the plot, decide which tables are appropriate for that kind of plot. "
+         f"Write the 3 most relevant tables to use. Answer only a python list of table names. "
+         f"### List of tables: {table_names_list} "
+         f"### User question: {user_question} "
+         f"### List of table names: "
+     )
+
+     table_names = ast.literal_eval(
+         (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
      )
      return table_names

+
  def replace_coordonates(coords, query, coords_tables):
      n = query.count(str(coords[0]))

      for i in range(n):
+         query = query.replace(str(coords[0]), str(coords_tables[i][0]), 1)
+         query = query.replace(str(coords[1]), str(coords_tables[i][1]), 1)
+     return query
+
+
+ async def detect_relevant_plots(user_question: str, llm):
+     plots_description = ""
+     for plot in PLOTS:
+         plots_description += "Name: " + plot["name"]
+         plots_description += " - Description: " + plot["description"] + "\n"
+
+     prompt = (
+         f"You are helping to answer a question with insightful visualizations. "
+         f"You are given a user question and a list of plots with their name and description. "
+         f"Based on the descriptions of the plots, decide which plots are appropriate to answer this question. "
+         f"Write the most relevant plots to use. Answer only a python list of plot names. "
+         f"### Descriptions of the plots: {plots_description} "
+         f"### User question: {user_question} "
+         f"### Names of the plots: "
+     )
+     # prompt = (
+     #     f"You are helping to answer a question with insightful visualizations. "
+     #     f"Given a list of plots with their name and description: "
+     #     f"{plots_description} "
+     #     f"The user question is: {user_question}. "
+     #     f"Choose the most relevant plots to answer the question. "
+     #     f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
+     #     f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
+     # )
+
+     plot_names = ast.literal_eval(
+         (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
+     )
+     return plot_names
+
+
+ # Next Version
+ # class QueryOutput(TypedDict):
+ #     """Generated SQL query."""
+
+ #     query: Annotated[str, ..., "Syntactically valid SQL query."]
+
+
+ # class PlotlyCodeOutput(TypedDict):
+ #     """Generated Plotly code"""
+
+ #     code: Annotated[str, ..., "Syntactically valid Plotly python code."]
+ # def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
+ #     """Generate SQL query to fetch information."""
+ #     prompt_params = {
+ #         "dialect": db.dialect,
+ #         "table_info": db.get_table_info(),
+ #         "input": user_input,
+ #         "relevant_tables": relevant_tables,
+ #         "model": "ALADIN63_CNRM-CM5",
+ #     }
+
+ #     prompt = ChatPromptTemplate.from_template(query_prompt_template)
+ #     structured_llm = llm.with_structured_output(QueryOutput)
+ #     chain = prompt | structured_llm
+ #     result = chain.invoke(prompt_params)
+
+ #     return result["query"]
+
+
+ # def fetch_data_from_sql_query(db: str, sql_query: str):
+ #     conn = sqlite3.connect(db)
+ #     cursor = conn.cursor()
+ #     cursor.execute(sql_query)
+ #     column_names = [desc[0] for desc in cursor.description]
+ #     values = cursor.fetchall()
+ #     return {"column_names": column_names, "data": values}
+
+
+ # def generate_chart_code(user_input: str, sql_query: list[str], llm):
+ #     """Generate plotly python code for the chart based on the sql query and the user question"""
+
+ #     class PlotlyCodeOutput(TypedDict):
+ #         """Generated Plotly code"""
+
+ #         code: Annotated[str, ..., "Syntactically valid Plotly python code."]
+
+ #     prompt = ChatPromptTemplate.from_template(plot_prompt_template)
+ #     structured_llm = llm.with_structured_output(PlotlyCodeOutput)
+ #     chain = prompt | structured_llm
+ #     result = chain.invoke({"input": user_input, "sql_query": sql_query})
+ #     return result["code"]
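Putting the geocoding helpers together — resolve a place name to coordinates, then snap to the nearest DRIAS grid point. A sketch under the signatures above (Nominatim requires network access, and the returned grid coordinates depend on the dataset):

```python
from climateqa.engine.talk_to_data.utils import loc2coords, nearestNeighbourSQL

coords = loc2coords("Lyon")  # e.g. (45.75, 4.85)
lat, lon = nearestNeighbourSQL(coords, "mean_annual_temperature")
print(f"Nearest grid point: {lat}, {lon}")
```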
climateqa/engine/talk_to_data/workflow.py ADDED
@@ -0,0 +1,287 @@
+ import os
+
+ from typing import Any, Callable, NotRequired, TypedDict
+ import pandas as pd
+
+ from plotly.graph_objects import Figure
+ from climateqa.engine.llm import get_llm
+ from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
+ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
+ from climateqa.engine.talk_to_data.sql_query import execute_sql_query
+ from climateqa.engine.talk_to_data.utils import (
+     detect_relevant_plots,
+     detect_year_with_openai,
+     loc2coords,
+     detect_location_with_openai,
+     nearestNeighbourSQL,
+     detect_relevant_tables,
+ )
+
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
+
+ class TableState(TypedDict):
+     """Represents the state of a table in the DRIAS workflow.
+
+     This class defines the structure for tracking the state of a table during the
+     data processing workflow, including its name, parameters, SQL query, and results.
+
+     Attributes:
+         table_name (str): The name of the table in the database
+         params (dict[str, Any]): Parameters used for querying the table
+         sql_query (str, optional): The SQL query used to fetch data
+         dataframe (pd.DataFrame | None, optional): The resulting data
+         figure (Callable[..., Figure], optional): Function to generate visualization
+         status (str): The current status of the table processing ('OK' or 'ERROR')
+     """
+     table_name: str
+     params: dict[str, Any]
+     sql_query: NotRequired[str]
+     dataframe: NotRequired[pd.DataFrame | None]
+     figure: NotRequired[Callable[..., Figure]]
+     status: str
+
+ class PlotState(TypedDict):
+     """Represents the state of a plot in the DRIAS workflow.
+
+     This class defines the structure for tracking the state of a plot during the
+     data processing workflow, including its name and associated tables.
+
+     Attributes:
+         plot_name (str): The name of the plot
+         tables (list[str]): List of tables used in the plot
+         table_states (dict[str, TableState]): States of the tables used in the plot
+     """
+     plot_name: str
+     tables: list[str]
+     table_states: dict[str, TableState]
+
+ class State(TypedDict):
+     user_input: str
+     plots: list[str]
+     plot_states: dict[str, PlotState]
+     error: NotRequired[str]
+
+ async def drias_workflow(user_input: str) -> State:
+     """Performs the complete Talk to Drias workflow: from the user input to the
+     generated SQL queries, dataframes and figures.
+
+     Args:
+         user_input (str): initial user input
+
+     Returns:
+         State: Final state with all the results
+     """
+     state: State = {
+         'user_input': user_input,
+         'plots': [],
+         'plot_states': {}
+     }
+
+     llm = get_llm(provider="openai")
+
+     plots = await find_relevant_plots(state, llm)
+     state['plots'] = plots
+
+     if not state['plots']:
+         state['error'] = 'There is no plot to answer the question'
+         return state
+
+     have_relevant_table = False
+     have_sql_query = False
+     have_dataframe = False
+     for plot_name in state['plots']:
+
+         plot = next((p for p in PLOTS if p['name'] == plot_name), None)  # Find the associated plot object
+         if plot is None:
+             continue
+
+         plot_state: PlotState = {
+             'plot_name': plot_name,
+             'tables': [],
+             'table_states': {}
+         }
+
+         relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
+         if len(relevant_tables) > 0:
+             have_relevant_table = True
+
+         plot_state['tables'] = relevant_tables
+         if not relevant_tables:
+             # No relevant table for this plot; skip parameter extraction
+             state['plot_states'][plot_name] = plot_state
+             continue
+
+         params = {}
+         for param_name in plot['params']:
+             param = await find_param(state, param_name, relevant_tables[0])
+             if param:
+                 params.update(param)
+
+         for n, table in enumerate(plot_state['tables']):
+             if n > 2:
+                 break
+
+             table_state: TableState = {
+                 'table_name': table,
+                 # Copy so each table keeps its own params (indicator_column differs per table)
+                 'params': params.copy(),
+                 'status': 'OK'
+             }
+
+             table_state["params"]['indicator_column'] = find_indicator_column(table)
+
+             sql_query = plot['sql_query'](table, table_state['params'])
+
+             if sql_query == "":
+                 table_state['status'] = 'ERROR'
+                 continue
+             else:
+                 have_sql_query = True
+
+             table_state['sql_query'] = sql_query
+             df = await execute_sql_query(sql_query)
+
+             if len(df) > 0:
+                 have_dataframe = True
+
+             figure = plot['plot_function'](table_state['params'])
+             table_state['dataframe'] = df
+             table_state['figure'] = figure
+             plot_state['table_states'][table] = table_state
+
+         state['plot_states'][plot_name] = plot_state
+
+     if not have_relevant_table:
+         state['error'] = "There is no relevant table in our database to answer your question"
+     elif not have_sql_query:
+         state['error'] = "There is no relevant SQL query on our database that can help to answer your question"
+     elif not have_dataframe:
+         state['error'] = "There is no data in our tables that can answer your question"
+
+     return state
+
+ async def find_relevant_plots(state: State, llm) -> list[str]:
+     print("---- Find relevant plots ----")
+     relevant_plots = await detect_relevant_plots(state['user_input'], llm)
+     return relevant_plots
+
+ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
+     print(f"---- Find relevant tables for {plot['name']} ----")
+     relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
+     return relevant_tables
+
+ async def find_param(state: State, param_name: str, table: str) -> dict[str, Any] | None:
+     """Dispatches to the appropriate method to retrieve the desired parameter.
+
+     Args:
+         state (State): state of the workflow
+         param_name (str): name of the desired parameter
+         table (str): name of the table
+
+     Returns:
+         dict[str, Any] | None: the parameter as a dict, or None if the parameter is not handled
+     """
+     if param_name == 'location':
+         location = await find_location(state['user_input'], table)
+         return location
+     if param_name == 'year':
+         year = await find_year(state['user_input'])
+         return {'year': year}
+     return None
+
+ class Location(TypedDict):
+     location: str
+     latitude: NotRequired[str]
+     longitude: NotRequired[str]
+
+ async def find_location(user_input: str, table: str) -> Location:
+     print(f"---- Find location in table {table} ----")
+     location = await detect_location_with_openai(user_input)
+     output: Location = {'location': location}
+     if location:
+         coords = loc2coords(location)
+         neighbour = nearestNeighbourSQL(coords, table)
+         output.update({
+             "latitude": neighbour[0],
+             "longitude": neighbour[1],
+         })
+     return output
+
+ async def find_year(user_input: str) -> str:
+     """Extracts year information from user input using LLM.
+
+     This function uses an LLM to identify and extract year information from the
+     user's query, which is used to filter data in subsequent queries.
+
+     Args:
+         user_input (str): The user's query text
+
+     Returns:
+         str: The extracted year, or empty string if no year found
+     """
+     print("---- Find year ----")
+     year = await detect_year_with_openai(user_input)
+     return year
+
+ def find_indicator_column(table: str) -> str:
+     """Retrieves the name of the indicator column within a table.
+
+     This function maps table names to their corresponding indicator columns
+     using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.
+
+     Args:
+         table (str): Name of the table in the database
+
+     Returns:
+         str: Name of the indicator column for the specified table
+
+     Raises:
+         KeyError: If the table name is not found in the mapping
+     """
+     print(f"---- Find indicator column in table {table} ----")
+     return INDICATOR_COLUMNS_PER_TABLE[table]
+
+
+ # def make_write_query_node():
+
+ #     def write_query(state):
+ #         print("---- Write query ----")
+ #         for table in state["tables"]:
+ #             sql_query = QUERIES[state[table]['query_type']](
+ #                 table=table,
+ #                 indicator_column=state[table]["columns"],
+ #                 longitude=state[table]["longitude"],
+ #                 latitude=state[table]["latitude"],
+ #             )
+ #             state[table].update({"sql_query": sql_query})
+
+ #         return state
+
+ #     return write_query
+
+ # def make_fetch_data_node(db_path):
+
+ #     def fetch_data(state):
+ #         print("---- Fetch data ----")
+ #         for table in state["tables"]:
+ #             results = execute_sql_query(db_path, state[table]['sql_query'])
+ #             state[table].update(results)
+
+ #         return state
+
+ #     return fetch_data
+
+
+
+ ## V2
+
+
+ # def make_fetch_data_node(db_path: str, llm):
+ #     def fetch_data(state):
+ #         print("---- Fetch data ----")
+ #         db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
+ #         output = {}
+ #         sql_query = write_sql_query(state["query"], db, state["tables"], llm)
+ #         # TO DO : Add query checker
+ #         print(f"SQL query : {sql_query}")
+ #         output["sql_query"] = sql_query
+ #         output.update(fetch_data_from_sql_query(db_path, sql_query))
+ #         return output
+
+ #     return fetch_data
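To see what `drias_workflow` produces, the final `State` can be walked directly — this condensed sketch mirrors the traversal that `ask_drias` performs in main.py:

```python
import asyncio

from climateqa.engine.talk_to_data.workflow import drias_workflow

state = asyncio.run(drias_workflow("What will the temperature be like in Paris?"))
if state.get("error"):
    print(state["error"])
else:
    for plot_name, plot_state in state["plot_states"].items():
        for table, table_state in plot_state["table_states"].items():
            if table_state["status"] == "OK":
                print(plot_name, "->", table)
                print(table_state.get("sql_query", ""))
```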
front/tabs/__init__.py CHANGED
@@ -3,4 +3,7 @@ from .tab_examples import create_examples_tab
  from .tab_papers import create_papers_tab
  from .tab_figures import create_figures_tab
  from .chat_interface import create_chat_interface
- from .tab_about import create_about_tab

  from .tab_papers import create_papers_tab
  from .tab_figures import create_figures_tab
  from .chat_interface import create_chat_interface
+ from .tab_about import create_about_tab
+ from .main_tab import MainTabPanel
+ from .tab_config import ConfigPanel
+ from .main_tab import cqa_tab
front/tabs/chat_interface.py CHANGED
@@ -21,21 +21,21 @@ What do you want to learn ?
  """
 
  init_prompt_poc = """
- Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports, PCAET of Paris, the Plan Biodiversité 2018-2024, and Acclimaterra reports from la Région Nouvelle-Aquitaine **.
+ Bonjour, je suis ClimateQ&A, un assistant conversationnel conçu pour vous aider à comprendre le changement climatique et la perte de biodiversité. Je réponds à vos questions en **parcourant les rapports scientifiques du GIEC et de l'IPBES, le PCAET de Paris, le Plan Biodiversité 2018-2024, et les rapports Acclimaterra de la Région Nouvelle-Aquitaine**.
 
- How to use
- - **Language**: You can ask me your questions in any language.
- - **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- - **Sources**: You can choose to search in the IPCC or IPBES reports, and POC sources for local documents (PCAET, Plan Biodiversité, Acclimaterra).
- - **Relevant content sources**: You can choose to search for figures, papers, or graphs that can be relevant for your question.
+ Mode d'emploi
+ - **Language** : Vous pouvez me poser vos questions dans n'importe quelle langue.
+ - **Audience** : Vous pouvez préciser votre public (enfants, grand public, experts) pour obtenir une réponse plus adaptée.
+ - **Sources** : Vous pouvez choisir de chercher dans les rapports du GIEC ou de l'IPBES, et dans les sources POC pour les documents locaux (PCAET, Plan Biodiversité, Acclimaterra).
+ - **Relevant content sources** : Vous pouvez choisir de rechercher des images, des papiers scientifiques ou des graphiques qui peuvent être pertinents pour votre question.
 
  ⚠️ Limitations
- *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
+ *Veuillez noter que l'IA n'est pas parfaite et peut parfois donner des réponses non pertinentes. Si vous n'êtes pas satisfait de la réponse, veuillez poser une question plus précise ou nous faire part de vos commentaires pour nous aider à améliorer le système.*
 
- 🛈 Information
- Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.
+ 🛈 Informations
+ Veuillez noter que nous enregistrons vos questions à des fins de méta-analyse, évitez donc de partager toute information sensible ou personnelle.
 
- What do you want to learn ?
+ Que voulez-vous apprendre ?
  """
 
 
@@ -54,7 +54,10 @@ def create_chat_interface(tab):
  max_height="80vh",
  height="100vh"
  )
-
+ with gr.Accordion("Click here for follow-up question examples", elem_id="follow-up-examples", open=False):
+ follow_up_examples_hidden = gr.Textbox(visible=False, elem_id="follow-up-hidden")
+ follow_up_examples = gr.Examples(examples=["What evidence do we have of climate change ?"], label="", inputs=[follow_up_examples_hidden], elem_id="follow-up-button", run_on_click=False)
+
  with gr.Row(elem_id="input-message"):
 
  textbox = gr.Textbox(
@@ -68,7 +71,7 @@
 
  config_button = gr.Button("", elem_id="config-button")
 
- return chatbot, textbox, config_button
+ return chatbot, textbox, config_button, follow_up_examples, follow_up_examples_hidden
 
 
 
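The hidden textbox returned above is the hook for follow-up examples: clicking an example writes its text into the hidden component, and the app can route the resulting .change event into the chat pipeline (tab_drias.py below uses the same trick). A generic, self-contained sketch of the pattern, with illustrative names rather than the app's real wiring:

import gradio as gr

with gr.Blocks() as demo:
    hidden = gr.Textbox(visible=False)
    gr.Examples(examples=["What evidence do we have of climate change ?"], inputs=[hidden])
    chat = gr.Chatbot()

    def push_example(text, history):
        # Append the clicked example as a user turn (tuple-style history for brevity).
        return (history or []) + [(text, None)]

    # Example clicks populate `hidden`, which fires .change and feeds the chat.
    hidden.change(push_example, inputs=[hidden, chat], outputs=[chat])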
front/tabs/main_tab.py CHANGED
@@ -1,8 +1,37 @@
  import gradio as gr
+ from gradio.helpers import Examples
+ from typing import TypedDict
  from .chat_interface import create_chat_interface
  from .tab_examples import create_examples_tab
  from .tab_papers import create_papers_tab
  from .tab_figures import create_figures_tab
+ from dataclasses import dataclass
+
+ @dataclass
+ class MainTabPanel:
+ chatbot: gr.Chatbot
+ textbox: gr.Textbox
+ tabs: gr.Tabs
+ sources_raw: gr.State
+ new_figures: gr.State
+ current_graphs: gr.State
+ examples_hidden: gr.State
+ sources_textbox: gr.HTML
+ figures_cards: gr.HTML
+ gallery_component: gr.Gallery
+ config_button: gr.Button
+ papers_direct_search: gr.TextArea
+ papers_html: gr.HTML
+ citations_network: gr.Plot
+ papers_summary: gr.Textbox
+ tab_recommended_content: gr.Tab
+ tab_sources: gr.Tab
+ tab_figures: gr.Tab
+ tab_graphs: gr.Tab
+ tab_papers: gr.Tab
+ graph_container: gr.HTML
+ follow_up_examples: Examples
+ follow_up_examples_hidden: gr.Textbox
 
  def cqa_tab(tab_name):
  # State variables
@@ -11,14 +40,14 @@ def cqa_tab(tab_name):
  with gr.Row(elem_id="chatbot-row"):
  # Left column - Chat interface
  with gr.Column(scale=2):
- chatbot, textbox, config_button = create_chat_interface(tab_name)
+ chatbot, textbox, config_button, follow_up_examples, follow_up_examples_hidden = create_chat_interface(tab_name)
 
  # Right column - Content panels
  with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
  with gr.Tabs(elem_id="right_panel_tab") as tabs:
  # Examples tab
  with gr.TabItem("Examples", elem_id="tab-examples", id=0):
- examples_hidden, dropdown_samples, samples = create_examples_tab()
+ examples_hidden = create_examples_tab(tab_name)
 
  # Sources tab
  with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
@@ -34,7 +63,7 @@
 
  # Papers subtab
  with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
- papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
+ papers_direct_search, papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
 
  # Graphs subtab
  with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
@@ -42,27 +71,30 @@
  "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
  elem_id="graphs-container"
  )
- return {
- "chatbot": chatbot,
- "textbox": textbox,
- "tabs": tabs,
- "sources_raw": sources_raw,
- "new_figures": new_figures,
- "current_graphs": current_graphs,
- "examples_hidden": examples_hidden,
- "dropdown_samples": dropdown_samples,
- "samples": samples,
- "sources_textbox": sources_textbox,
- "figures_cards": figures_cards,
- "gallery_component": gallery_component,
- "config_button": config_button,
- "papers_html": papers_html,
- "citations_network": citations_network,
- "papers_summary": papers_summary,
- "tab_recommended_content": tab_recommended_content,
- "tab_sources": tab_sources,
- "tab_figures": tab_figures,
- "tab_graphs": tab_graphs,
- "tab_papers": tab_papers,
- "graph_container": graphs_container
- }
+
+
+ return MainTabPanel(
+ chatbot=chatbot,
+ textbox=textbox,
+ tabs=tabs,
+ sources_raw=sources_raw,
+ new_figures=new_figures,
+ current_graphs=current_graphs,
+ examples_hidden=examples_hidden,
+ sources_textbox=sources_textbox,
+ figures_cards=figures_cards,
+ gallery_component=gallery_component,
+ config_button=config_button,
+ papers_direct_search=papers_direct_search,
+ papers_html=papers_html,
+ citations_network=citations_network,
+ papers_summary=papers_summary,
+ tab_recommended_content=tab_recommended_content,
+ tab_sources=tab_sources,
+ tab_figures=tab_figures,
+ tab_graphs=tab_graphs,
+ tab_papers=tab_papers,
+ graph_container=graphs_container,
+ follow_up_examples=follow_up_examples,
+ follow_up_examples_hidden=follow_up_examples_hidden
+ )
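Returning a MainTabPanel dataclass instead of the previous dict gives callers attribute access that IDEs and type checkers can verify, and construction fails immediately if a field is missing. A toy, runnable stand-in illustrating the difference (names here are illustrative):

from dataclasses import dataclass

@dataclass
class Panel:
    chatbot: str
    textbox: str

panel = Panel(chatbot="chatbot", textbox="textbox")
print(panel.chatbot)  # attribute access, validated when Panel is constructed
# panel.chatbots would raise AttributeError at once; with a dict, the typo
# "chatbots" would only surface as a KeyError wherever it is first used.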
front/tabs/tab_config.py CHANGED
@@ -2,8 +2,10 @@ import gradio as gr
  from gradio_modal import Modal
  from climateqa.constants import POSSIBLE_REPORTS
  from typing import TypedDict
+ from dataclasses import dataclass
 
- class ConfigPanel(TypedDict):
+ @dataclass
+ class ConfigPanel:
  config_open: gr.State
  config_modal: Modal
  dropdown_sources: gr.CheckboxGroup
@@ -14,6 +16,7 @@ class ConfigPanel(TypedDict):
  after: gr.Slider
  output_query: gr.Textbox
  output_language: gr.Textbox
+ close_config_modal_button: gr.Button
 
 
  def create_config_modal():
@@ -37,9 +40,9 @@
  )
 
  dropdown_external_sources = gr.CheckboxGroup(
- choices=["Figures (IPCC/IPBES)", "Papers (OpenAlex)", "Graphs (OurWorldInData)","POC region"],
+ choices=["Figures (IPCC/IPBES)", "Papers (OpenAlex)", "Graphs (OurWorldInData)"],
  label="Select database to search for relevant content",
- value=["Figures (IPCC/IPBES)","POC region"],
+ value=["Figures (IPCC/IPBES)"],
  interactive=True
  )
 
@@ -95,29 +98,16 @@
 
  close_config_modal_button = gr.Button("Validate and Close", elem_id="close-config-modal")
 
-
- # return ConfigPanel(
- # config_open=config_open,
- # config_modal=config_modal,
- # dropdown_sources=dropdown_sources,
- # dropdown_reports=dropdown_reports,
- # dropdown_external_sources=dropdown_external_sources,
- # search_only=search_only,
- # dropdown_audience=dropdown_audience,
- # after=after,
- # output_query=output_query,
- # output_language=output_language
- # )
- return {
- "config_open" : config_open,
- "config_modal": config_modal,
- "dropdown_sources": dropdown_sources,
- "dropdown_reports": dropdown_reports,
- "dropdown_external_sources": dropdown_external_sources,
- "search_only": search_only,
- "dropdown_audience": dropdown_audience,
- "after": after,
- "output_query": output_query,
- "output_language": output_language,
- "close_config_modal_button": close_config_modal_button
- }
+ return ConfigPanel(
+ config_open=config_open,
+ config_modal=config_modal,
+ dropdown_sources=dropdown_sources,
+ dropdown_reports=dropdown_reports,
+ dropdown_external_sources=dropdown_external_sources,
+ search_only=search_only,
+ dropdown_audience=dropdown_audience,
+ after=after,
+ output_query=output_query,
+ output_language=output_language,
+ close_config_modal_button=close_config_modal_button
+ )
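The ConfigPanel dataclass now also exposes close_config_modal_button, so a caller can wire the modal's open/close behavior against named fields. A minimal sketch of that pattern with gradio_modal, assuming illustrative element names rather than the app's real wiring:

import gradio as gr
from gradio_modal import Modal

with gr.Blocks() as demo:
    open_button = gr.Button("Open config")
    with Modal(visible=False) as config_modal:
        close_button = gr.Button("Validate and Close")
    # Toggling visibility by returning a Modal update is the documented
    # gradio_modal pattern; the app would use config_panel.config_modal and
    # config_panel.close_config_modal_button instead of these local names.
    open_button.click(lambda: Modal(visible=True), None, config_modal)
    close_button.click(lambda: Modal(visible=False), None, config_modal)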
front/tabs/tab_drias.py ADDED
@@ -0,0 +1,362 @@
1
+ import gradio as gr
2
+ from typing import TypedDict
3
+ import os
4
+ import pandas as pd
5
+
6
+ from climateqa.engine.talk_to_data.main import ask_drias
7
+ from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
8
+ from climateqa.chat import log_drias_interaction_to_azure
9
+
10
+
11
+ class DriasUIElements(TypedDict):
12
+ tab: gr.Tab
13
+ details_accordion: gr.Accordion
14
+ examples_hidden: gr.Textbox
15
+ examples: gr.Examples
16
+ drias_direct_question: gr.Textbox
17
+ result_text: gr.Textbox
18
+ table_names_display: gr.DataFrame
19
+ query_accordion: gr.Accordion
20
+ drias_sql_query: gr.Textbox
21
+ chart_accordion: gr.Accordion
22
+ model_selection: gr.Dropdown
23
+ drias_display: gr.Plot
24
+ table_accordion: gr.Accordion
25
+ drias_table: gr.DataFrame
26
+ pagination_display: gr.Markdown
27
+ prev_button: gr.Button
28
+ next_button: gr.Button
29
+
30
+
31
+ async def ask_drias_query(query: str, index_state: int):
32
+ result = await ask_drias(query, index_state)
33
+ return result
34
+
35
+
36
+ def show_results(sql_queries_state, dataframes_state, plots_state):
37
+ if not sql_queries_state or not dataframes_state or not plots_state:
38
+ # If all results are empty, show "No result"
39
+ return (
40
+ gr.update(visible=True),
41
+ gr.update(visible=False),
42
+ gr.update(visible=False),
43
+ gr.update(visible=False),
44
+ gr.update(visible=False),
45
+ gr.update(visible=False),
46
+ gr.update(visible=False),
47
+ gr.update(visible=False),
48
+ )
49
+ else:
50
+ # Show the appropriate components with their data
51
+ return (
52
+ gr.update(visible=False),
53
+ gr.update(visible=True),
54
+ gr.update(visible=True),
55
+ gr.update(visible=True),
56
+ gr.update(visible=True),
57
+ gr.update(visible=True),
58
+ gr.update(visible=True),
59
+ gr.update(visible=True),
60
+ )
61
+
62
+
63
+ def filter_by_model(dataframes, figures, index_state, model_selection):
64
+ df = dataframes[index_state]
65
+ if df.empty:
66
+ return df, None
67
+ if "model" not in df.columns:
68
+ return df, figures[index_state](df)
69
+ if model_selection != "ALL":
70
+ df = df[df["model"] == model_selection]
71
+ if df.empty:
72
+ return df, None
73
+ figure = figures[index_state](df)
74
+ return df, figure
75
+
76
+
77
+ def update_pagination(index, sql_queries):
78
+ pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
79
+ return pagination
80
+
81
+
82
+ def show_previous(index, sql_queries, dataframes, plots):
83
+ if index > 0:
84
+ index -= 1
85
+ return (
86
+ sql_queries[index],
87
+ dataframes[index],
88
+ plots[index](dataframes[index]),
89
+ index,
90
+ )
91
+
92
+
93
+ def show_next(index, sql_queries, dataframes, plots):
94
+ if index < len(sql_queries) - 1:
95
+ index += 1
96
+ return (
97
+ sql_queries[index],
98
+ dataframes[index],
99
+ plots[index](dataframes[index]),
100
+ index,
101
+ )
102
+
103
+
104
+ def display_table_names(table_names):
105
+ return [table_names]
106
+
107
+
108
+ def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
109
+ index = evt.index[1]
110
+ figure = plots[index](dataframes[index])
111
+ return (
112
+ sql_queries[index],
113
+ dataframes[index],
114
+ figure,
115
+ index,
116
+ )
117
+
118
+
119
+ def create_drias_ui() -> DriasUIElements:
120
+ """Create and return all UI elements for the DRIAS tab."""
121
+ with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
122
+ with gr.Accordion(label="Details") as details_accordion:
123
+ gr.Markdown(DRIAS_UI_TEXT)
124
+
125
+ # Add examples for common questions
126
+ examples_hidden = gr.Textbox(visible=False, elem_id="drias-examples-hidden")
127
+ examples = gr.Examples(
128
+ examples=[
129
+ ["What will the temperature be like in Paris?"],
130
+ ["What will be the total rainfall in France in 2030?"],
131
+ ["How frequent will extreme events be in Lyon?"],
132
+ ["Comment va évoluer la température en France entre 2030 et 2050 ?"]
133
+ ],
134
+ label="Example Questions",
135
+ inputs=[examples_hidden],
136
+ outputs=[examples_hidden],
137
+ )
138
+
139
+ with gr.Row():
140
+ drias_direct_question = gr.Textbox(
141
+ label="Direct Question",
142
+ placeholder="You can write direct question here",
143
+ elem_id="direct-question",
144
+ interactive=True,
145
+ )
146
+
147
+ result_text = gr.Textbox(
148
+ label="", elem_id="no-result-label", interactive=False, visible=True
149
+ )
150
+
151
+ table_names_display = gr.DataFrame(
152
+ [], label="List of relevant indicators", headers=None, interactive=False, elem_id="table-names", visible=False
153
+ )
154
+
155
+ with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
156
+ drias_sql_query = gr.Textbox(
157
+ label="", elem_id="sql-query", interactive=False
158
+ )
159
+
160
+ with gr.Accordion(label="Chart", visible=False) as chart_accordion:
161
+ model_selection = gr.Dropdown(
162
+ label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
163
+ )
164
+ drias_display = gr.Plot(elem_id="vanna-plot")
165
+
166
+ with gr.Accordion(
167
+ label="Data used", open=False, visible=False
168
+ ) as table_accordion:
169
+ drias_table = gr.DataFrame([], elem_id="vanna-table")
170
+
171
+ pagination_display = gr.Markdown(
172
+ value="", visible=False, elem_id="pagination-display"
173
+ )
174
+
175
+ with gr.Row():
176
+ prev_button = gr.Button("Previous", visible=False)
177
+ next_button = gr.Button("Next", visible=False)
178
+
179
+ return DriasUIElements(
180
+ tab=tab,
181
+ details_accordion=details_accordion,
182
+ examples_hidden=examples_hidden,
183
+ examples=examples,
184
+ drias_direct_question=drias_direct_question,
185
+ result_text=result_text,
186
+ table_names_display=table_names_display,
187
+ query_accordion=query_accordion,
188
+ drias_sql_query=drias_sql_query,
189
+ chart_accordion=chart_accordion,
190
+ model_selection=model_selection,
191
+ drias_display=drias_display,
192
+ table_accordion=table_accordion,
193
+ drias_table=drias_table,
194
+ pagination_display=pagination_display,
195
+ prev_button=prev_button,
196
+ next_button=next_button
197
+ )
198
+
199
+ def log_drias_to_azure(query: str, sql_query: str, data, share_client, user_id):
200
+ """Log Drias interaction to Azure storage."""
201
+ print("log_drias_to_azure")
202
+ if share_client is not None and user_id is not None:
203
+ log_drias_interaction_to_azure(
204
+ query=query,
205
+ sql_query=sql_query,
206
+ data=data,
207
+ share_client=share_client,
208
+ user_id=user_id
209
+ )
210
+ else:
211
+ print("share_client or user_id is None")
212
+
213
+ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
214
+ """Set up all event handlers for the DRIAS tab."""
215
+ # Create state variables
216
+ sql_queries_state = gr.State([])
217
+ dataframes_state = gr.State([])
218
+ plots_state = gr.State([])
219
+ index_state = gr.State(0)
220
+ table_names_list = gr.State([])
221
+
222
+ def log_drias_interaction(query: str, sql_query: str, data: pd.DataFrame):
223
+ log_drias_to_azure(query, sql_query, data, share_client, user_id)
224
+
225
+
226
+ # Handle example selection
227
+ ui_elements["examples_hidden"].change(
228
+ lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
229
+ inputs=[ui_elements["examples_hidden"]],
230
+ outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
231
+ ).then(
232
+ ask_drias_query,
233
+ inputs=[ui_elements["examples_hidden"], index_state],
234
+ outputs=[
235
+ ui_elements["drias_sql_query"],
236
+ ui_elements["drias_table"],
237
+ ui_elements["drias_display"],
238
+ sql_queries_state,
239
+ dataframes_state,
240
+ plots_state,
241
+ index_state,
242
+ table_names_list,
243
+ ui_elements["result_text"],
244
+ ],
245
+ ).then(
246
+ log_drias_interaction,
247
+ inputs=[ui_elements["examples_hidden"], ui_elements["drias_sql_query"], ui_elements["drias_table"]],
248
+ outputs=[],
249
+ ).then(
250
+ show_results,
251
+ inputs=[sql_queries_state, dataframes_state, plots_state],
252
+ outputs=[
253
+ ui_elements["result_text"],
254
+ ui_elements["query_accordion"],
255
+ ui_elements["table_accordion"],
256
+ ui_elements["chart_accordion"],
257
+ ui_elements["prev_button"],
258
+ ui_elements["next_button"],
259
+ ui_elements["pagination_display"],
260
+ ui_elements["table_names_display"],
261
+ ],
262
+ ).then(
263
+ update_pagination,
264
+ inputs=[index_state, sql_queries_state],
265
+ outputs=[ui_elements["pagination_display"]],
266
+ ).then(
267
+ display_table_names,
268
+ inputs=[table_names_list],
269
+ outputs=[ui_elements["table_names_display"]],
270
+ )
271
+
272
+ # Handle direct question submission
273
+ ui_elements["drias_direct_question"].submit(
274
+ lambda: gr.Accordion(open=False),
275
+ inputs=None,
276
+ outputs=[ui_elements["details_accordion"]]
277
+ ).then(
278
+ ask_drias_query,
279
+ inputs=[ui_elements["drias_direct_question"], index_state],
280
+ outputs=[
281
+ ui_elements["drias_sql_query"],
282
+ ui_elements["drias_table"],
283
+ ui_elements["drias_display"],
284
+ sql_queries_state,
285
+ dataframes_state,
286
+ plots_state,
287
+ index_state,
288
+ table_names_list,
289
+ ui_elements["result_text"],
290
+ ],
291
+ ).then(
292
+ log_drias_interaction,
293
+ inputs=[ui_elements["drias_direct_question"], ui_elements["drias_sql_query"], ui_elements["drias_table"]],
294
+ outputs=[],
295
+ ).then(
296
+ show_results,
297
+ inputs=[sql_queries_state, dataframes_state, plots_state],
298
+ outputs=[
299
+ ui_elements["result_text"],
300
+ ui_elements["query_accordion"],
301
+ ui_elements["table_accordion"],
302
+ ui_elements["chart_accordion"],
303
+ ui_elements["prev_button"],
304
+ ui_elements["next_button"],
305
+ ui_elements["pagination_display"],
306
+ ui_elements["table_names_display"],
307
+ ],
308
+ ).then(
309
+ update_pagination,
310
+ inputs=[index_state, sql_queries_state],
311
+ outputs=[ui_elements["pagination_display"]],
312
+ ).then(
313
+ display_table_names,
314
+ inputs=[table_names_list],
315
+ outputs=[ui_elements["table_names_display"]],
316
+ )
317
+
318
+ # Handle model selection change
319
+ ui_elements["model_selection"].change(
320
+ filter_by_model,
321
+ inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
322
+ outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
323
+ )
324
+
325
+ # Handle pagination buttons
326
+ ui_elements["prev_button"].click(
327
+ show_previous,
328
+ inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
329
+ outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
330
+ ).then(
331
+ update_pagination,
332
+ inputs=[index_state, sql_queries_state],
333
+ outputs=[ui_elements["pagination_display"]],
334
+ )
335
+
336
+ ui_elements["next_button"].click(
337
+ show_next,
338
+ inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
339
+ outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
340
+ ).then(
341
+ update_pagination,
342
+ inputs=[index_state, sql_queries_state],
343
+ outputs=[ui_elements["pagination_display"]],
344
+ )
345
+
346
+ # Handle table selection
347
+ ui_elements["table_names_display"].select(
348
+ fn=on_table_click,
349
+ inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
350
+ outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
351
+ ).then(
352
+ update_pagination,
353
+ inputs=[index_state, sql_queries_state],
354
+ outputs=[ui_elements["pagination_display"]],
355
+ )
356
+
357
+ def create_drias_tab(share_client=None, user_id=None):
358
+ """Create the DRIAS tab with all its components and event handlers."""
359
+ ui_elements = create_drias_ui()
360
+ setup_drias_events(ui_elements, share_client=share_client, user_id=user_id)
361
+
362
+
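create_drias_ui builds the layout and setup_drias_events does the wiring: each user action runs ask_drias_query, then logging, then visibility updates, then pagination, as successive .then() steps. A generic, runnable sketch of that chaining style (the components and handlers here are illustrative, not the DRIAS ones):

import gradio as gr

with gr.Blocks() as demo:
    question = gr.Textbox(label="Direct Question")
    answer = gr.Textbox(label="Result")
    status = gr.Markdown()

    # Each .then() step runs after the previous one completes, which is how
    # setup_drias_events sequences query -> logging -> visibility -> pagination.
    question.submit(
        lambda q: f"echo: {q}", inputs=[question], outputs=[answer]
    ).then(
        lambda: "done", inputs=None, outputs=[status]
    )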
style.css CHANGED
@@ -29,8 +29,6 @@ main.flex.flex-1.flex-col {
  }
 
 
- }
-
  .tab-nav {
  border: none !important;
  }
@@ -111,10 +109,18 @@
  border: none;
  }
 
- #input-textbox > label > textarea {
+ #input-textbox > label > div > textarea {
  border-radius: 40px;
  padding-left: 30px;
  resize: none;
+ background-color: #d7e2ed; /* Light blue background */
+ border: 2px solid #4b8ec3; /* Blue border */
+ font-size: 16px; /* Increase font size */
+ color: #333; /* Text color */
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Add shadow */
+ }
+ #input-textbox > label > div > textarea::placeholder {
+ color: #4b4747; /* Darker placeholder color */
  }
 
  #input-message > div {
@@ -474,6 +480,33 @@
  text-decoration: none !important;
  }
 
+ /* Follow-up Examples Styles */
+ #follow-up-examples {
+ max-height: 20vh;
+ overflow-y: auto;
+ gap: 8px;
+ display: flex;
+ flex-direction: column;
+ overflow-y: hidden;
+ background: rgb(229, 235, 237);
+ }
+
+ #follow-up-button {
+ overflow-y: visible;
+ display: block;
+ padding: 8px 12px;
+ margin: 4px 0;
+ border-radius: 8px;
+ background-color: #f0f8ff;
+ transition: background-color 0.2s;
+ background: rgb(240, 240, 236);
+
+ }
+
+ #follow-up-button:hover {
+ background-color: #e0f0ff;
+ }
+
  /* Media Queries */
  /* Desktop Media Query */
  @media screen and (min-width: 1024px) {
@@ -487,7 +520,6 @@
  height: calc(100vh - 190px) !important;
  overflow-y: scroll !important;
  }
- div#tab-vanna,
  div#sources-figures,
  div#graphs-container,
  div#tab-citations {
@@ -496,6 +528,15 @@
  overflow-y: scroll !important;
  }
 
+ div#chatbot-row {
+ max-height: calc(100vh - 200px) !important;
+ }
+
+ div#chatbot {
+ height: 70vh !important;
+ max-height: 70vh !important;
+ }
+
  div#chatbot-row {
  max-height: calc(100vh - 90px) !important;
  }
@@ -514,7 +555,11 @@
  /* Mobile Media Query */
  @media screen and (max-width: 767px) {
  div#chatbot {
- height: 500px !important;
+ height: 400px !important; /* Reduced from 500px */
+ }
+
+ #follow-up-examples {
+ max-height: 150px;
  }
 
  #submit-button {
@@ -607,14 +652,61 @@
  }
 
  #vanna-display {
- max-height: 300px;
+ max-height: 200px;
  /* overflow-y: scroll; */
  }
  #sql-query{
- max-height: 100px;
+ max-height: 300px;
  overflow-y:scroll;
  }
- #vanna-details{
- max-height: 500px;
- overflow-y:scroll;
+
+ #sql-query textarea{
+ min-height: 100px !important;
+ }
+
+ #sql-query span{
+ display: none;
+ }
+ div#tab-vanna{
+ max-height: 100vh;
+ overflow-y: hidden;
+ }
+ #vanna-plot{
+ max-height: 500px;
+ }
+
+ #pagination-display{
+ text-align: center;
+ font-weight: bold;
+ font-size: 16px;
+ }
+
+ #table-names table{
+ overflow: hidden;
+ }
+ #table-names thead{
+ display: none;
+ }
+
+ /* DRIAS Data Table Styles */
+ #vanna-table {
+ height: 400px !important;
+ overflow-y: auto !important;
+ }
+
+ #vanna-table > div[class*="table"] {
+ height: 400px !important;
+ overflow-y: visible !important;
+ }
+
+ #vanna-table .table-wrap {
+ height: 400px !important;
+ overflow-y: visible !important;
+ }
+
+ #vanna-table thead {
+ position: sticky;
+ top: 0;
+ background: white;
+ z-index: 1;
  }