armanddemasson committed on
Commit
1eae86b
·
1 Parent(s): 1700186

feat: added drias model choice and changed TTD UI

Browse files
app.py CHANGED
@@ -12,7 +12,7 @@ from climateqa.engine.reranker import get_reranker
12
  from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
- from climateqa.engine.talk_to_data.main import ask_drias
16
  from climateqa.engine.talk_to_data.myVanna import MyVanna
17
 
18
  from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
@@ -87,8 +87,8 @@ vn.connect_to_sqlite(db_vanna_path)
87
  # def ask_vanna_query(query):
88
  # return ask_vanna(vn, db_vanna_path, query)
89
 
90
- def ask_drias_query(query, index_state):
91
- return ask_drias(db_vanna_path, query, index_state)
92
 
93
  async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
94
  print("chat cqa - message received")
@@ -139,27 +139,40 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
139
 
140
  # vanna_display = gr.Plot()
141
  # vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
 
142
  def create_drias_tab():
143
  with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
144
- drias_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here", elem_id="direct-question", interactive=True)
 
 
145
 
146
- with gr.Accordion("Details", elem_id="vanna-details", open=False) as drias_details:
147
- drias_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
 
 
148
  drias_table = gr.DataFrame([], elem_id="vanna-table")
149
- drias_display = gr.Plot()
150
 
151
- # Navigation buttons
 
 
 
152
  prev_button = gr.Button("Previous")
153
  next_button = gr.Button("Next")
154
 
155
- sql_queries_state = gr.State([])
156
- dataframes_state = gr.State([])
157
- plots_state = gr.State([])
158
- index_state = gr.State(0) # To track the current position
159
 
160
  drias_direct_question.submit(
161
  ask_drias_query,
162
- inputs=[drias_direct_question, index_state],
 
 
 
 
 
 
163
  outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state]
164
  )
165
 
@@ -184,8 +197,6 @@ def create_drias_tab():
184
  inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
185
  outputs=[drias_sql_query, drias_table, drias_display, index_state]
186
  )
187
-
188
-
189
 
190
  # # UI Layout Components
191
  def cqa_tab(tab_name):
 
12
  from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
+ from climateqa.engine.talk_to_data.main import ask_drias, DRIAS_MODELS
16
  from climateqa.engine.talk_to_data.myVanna import MyVanna
17
 
18
  from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
 
87
  # def ask_vanna_query(query):
88
  # return ask_vanna(vn, db_vanna_path, query)
89
 
90
+ def ask_drias_query(query: str, index_state: int, drias_model: str):
91
+ return ask_drias(db_vanna_path, query, index_state, drias_model)
92
 
93
  async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
94
  print("chat cqa - message received")
 
139
 
140
  # vanna_display = gr.Plot()
141
  # vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
142
+
143
  def create_drias_tab():
144
  with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
145
+ with gr.Row():
146
+ drias_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here", elem_id="direct-question", interactive=True)
147
+ model_selection = gr.Dropdown(label="Model", choices=DRIAS_MODELS ,elem_id="drias-model", value="ALL", interactive=True)
148
 
149
+ with gr.Accordion(label="SQL Query Used"):
150
+ drias_sql_query = gr.Textbox(label="", elem_id="sql-query", interactive=False)
151
+
152
+ with gr.Accordion(label='Data used', open=False):
153
  drias_table = gr.DataFrame([], elem_id="vanna-table")
 
154
 
155
+ with gr.Accordion(label="Chart"):
156
+ drias_display = gr.Plot(elem_id="vanna-plot")
157
+
158
+ with gr.Row():
159
  prev_button = gr.Button("Previous")
160
  next_button = gr.Button("Next")
161
 
162
+ sql_queries_state = gr.State([])
163
+ dataframes_state = gr.State([])
164
+ plots_state = gr.State([])
165
+ index_state = gr.State(0)
166
 
167
  drias_direct_question.submit(
168
  ask_drias_query,
169
+ inputs=[drias_direct_question, index_state, model_selection],
170
+ outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state]
171
+ )
172
+
173
+ model_selection.change(
174
+ ask_drias_query,
175
+ inputs=[drias_direct_question, index_state, model_selection],
176
  outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state]
177
  )
178
 
 
197
  inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
198
  outputs=[drias_sql_query, drias_table, drias_display, index_state]
199
  )
 
 
200
 
201
  # # UI Layout Components
202
  def cqa_tab(tab_name):
climateqa/engine/talk_to_data/main.py CHANGED
@@ -13,13 +13,12 @@ def ask_llm_column_names(sql_query, llm):
13
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
14
  return columns_list
15
 
16
- def ask_drias(db_drias_path:str, query:str , index_state: int):
17
- final_state = drias_workflow(db_drias_path, query)
18
  sql_queries = []
19
  result_dataframes = []
20
  figures = []
21
 
22
-
23
  for plot_state in final_state['plot_states'].values():
24
  for table_state in plot_state['table_states'].values():
25
  if table_state['status'] == 'OK':
@@ -30,9 +29,29 @@ def ask_drias(db_drias_path:str, query:str , index_state: int):
30
  result_dataframes.append(table_state['dataframe'])
31
  if 'figure' in table_state and table_state['figure'] is not None:
32
  figures.append(table_state['figure'](table_state['dataframe']))
 
33
 
34
  return sql_queries[index_state], result_dataframes[index_state], figures[index_state], sql_queries, result_dataframes, figures, index_state
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # def ask_vanna(vn,db_vanna_path, query):
37
 
38
  # try :
 
13
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
14
  return columns_list
15
 
16
+ def ask_drias(db_drias_path:str, query:str, index_state: int = 0, drias_model: str = "ALL"):
17
+ final_state = drias_workflow(db_drias_path, query, drias_model)
18
  sql_queries = []
19
  result_dataframes = []
20
  figures = []
21
 
 
22
  for plot_state in final_state['plot_states'].values():
23
  for table_state in plot_state['table_states'].values():
24
  if table_state['status'] == 'OK':
 
29
  result_dataframes.append(table_state['dataframe'])
30
  if 'figure' in table_state and table_state['figure'] is not None:
31
  figures.append(table_state['figure'](table_state['dataframe']))
32
+
33
 
34
  return sql_queries[index_state], result_dataframes[index_state], figures[index_state], sql_queries, result_dataframes, figures, index_state
35
 
36
+ DRIAS_MODELS = [
37
+ 'ALL',
38
+ 'RegCM4-6_MPI-ESM-LR',
39
+ 'RACMO22E_EC-EARTH',
40
+ 'RegCM4-6_HadGEM2-ES',
41
+ 'HadREM3-GA7_EC-EARTH',
42
+ 'HadREM3-GA7_CNRM-CM5',
43
+ 'REMO2015_NorESM1-M',
44
+ 'SMHI-RCA4_EC-EARTH',
45
+ 'WRF381P_NorESM1-M',
46
+ 'ALADIN63_CNRM-CM5',
47
+ 'CCLM4-8-17_MPI-ESM-LR',
48
+ 'HIRHAM5_IPSL-CM5A-MR',
49
+ 'HadREM3-GA7_HadGEM2-ES',
50
+ 'SMHI-RCA4_IPSL-CM5A-MR',
51
+ 'HIRHAM5_NorESM1-M',
52
+ 'REMO2009_MPI-ESM-LR',
53
+ 'CCLM4-8-17_HadGEM2-ES'
54
+ ]
55
  # def ask_vanna(vn,db_vanna_path, query):
56
 
57
  # try :
climateqa/engine/talk_to_data/plot.py CHANGED
@@ -53,7 +53,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
53
  # Compute the 10-year rolling average
54
  sliding_averages = (
55
  df_avg[indicator]
56
- .rolling(window=10, min_periods=5)
57
  .mean()
58
  .astype(float)
59
  .tolist()
@@ -68,7 +68,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
68
  # Compute the 10-year rolling average
69
  sliding_averages = (
70
  df_model[indicator]
71
- .rolling(window=10, min_periods=5)
72
  .mean()
73
  .astype(float)
74
  .tolist()
@@ -241,6 +241,7 @@ def plot_distribution_of_indicator_for_given_year(
241
  yaxis_title="Frequency",
242
  plot_bgcolor="rgba(0, 0, 0, 0)",
243
  showlegend=False,
 
244
  )
245
 
246
  return fig
@@ -313,8 +314,8 @@ def plot_map_of_france_of_indicator_for_given_year(
313
  mapbox_style="open-street-map", # Use OpenStreetMap
314
  mapbox_zoom=3,
315
  mapbox_center={"lat": 46.6, "lon": 2.0},
316
- coloraxis_colorbar=dict(title=f"{indicator_label}"), # Add legend
317
- title=f"{indicator_label} in {year} in France", # Title
318
  )
319
  return fig
320
 
 
53
  # Compute the 10-year rolling average
54
  sliding_averages = (
55
  df_avg[indicator]
56
+ .rolling(window=10, min_periods=1)
57
  .mean()
58
  .astype(float)
59
  .tolist()
 
68
  # Compute the 10-year rolling average
69
  sliding_averages = (
70
  df_model[indicator]
71
+ .rolling(window=10, min_periods=1)
72
  .mean()
73
  .astype(float)
74
  .tolist()
 
241
  yaxis_title="Frequency",
242
  plot_bgcolor="rgba(0, 0, 0, 0)",
243
  showlegend=False,
244
+ pan=False
245
  )
246
 
247
  return fig
 
314
  mapbox_style="open-street-map", # Use OpenStreetMap
315
  mapbox_zoom=3,
316
  mapbox_center={"lat": 46.6, "lon": 2.0},
317
+ coloraxis_colorbar=dict(title=f"{indicator_label} {'(Model Average)' if model == 'ALL' else '(Model : ' + model + ')'}"), # Add legend
318
+ title=f"{indicator_label} in {year} in France ", # Title
319
  )
320
  return fig
321
 
climateqa/engine/talk_to_data/sql_query.py CHANGED
@@ -60,10 +60,16 @@ def indicator_per_year_at_location_query(
60
  indicator_column = params.get("indicator_column")
61
  latitude = params.get("latitude")
62
  longitude = params.get("longitude")
 
63
 
64
  if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
65
  return ""
66
- sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nand longitude={longitude} \nOrder by Year"
 
 
 
 
 
67
  return sql_query
68
 
69
  class IndicatorForGivenYearQueryParams(TypedDict, total=False):
@@ -85,9 +91,12 @@ def indicator_for_given_year_query(
85
  """
86
  indicator_column = params.get("indicator_column")
87
  year = params.get('year')
88
-
89
  if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
90
  return ""
91
-
92
- sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
 
 
 
93
  return sql_query
 
60
  indicator_column = params.get("indicator_column")
61
  latitude = params.get("latitude")
62
  longitude = params.get("longitude")
63
+ model = params.get('model')
64
 
65
  if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
66
  return ""
67
+
68
+ if model == 'ALL':
69
+ sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nand longitude={longitude} \nOrder by Year"
70
+ else:
71
+ sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nand longitude={longitude} \nand model='{model}' \nOrder by Year"
72
+
73
  return sql_query
74
 
75
  class IndicatorForGivenYearQueryParams(TypedDict, total=False):
 
91
  """
92
  indicator_column = params.get("indicator_column")
93
  year = params.get('year')
94
+ model = params.get('model')
95
  if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
96
  return ""
97
+
98
+ if model == 'ALL':
99
+ sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
100
+ else:
101
+ sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}\nand model = '{model}'"
102
  return sql_query
climateqa/engine/talk_to_data/utils.py CHANGED
@@ -27,6 +27,31 @@ def detect_location_with_openai(sentence):
27
  return location_list[0]
28
  else:
29
  return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  def detectTable(sql_query):
@@ -65,6 +90,17 @@ def nearestNeighbourSQL(db: str, location: tuple, table: str):
65
 
66
 
67
  def detect_relevant_tables(db: str, user_question: str, plot: Plot, llm) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
68
  conn = sqlite3.connect(db)
69
  cursor = conn.cursor()
70
 
 
27
  return location_list[0]
28
  else:
29
  return ""
30
+
31
+ def detect_year_with_openai(sentence: str):
32
+ """
33
+ Detects years in a sentence using OpenAI's API via LangChain.
34
+ """
35
+ llm = get_llm()
36
+
37
+ prompt = f"""
38
+ Extract all years mentioned in the following sentence.
39
+ Return the result as a Python list. If no year are mentioned, return an empty list.
40
+
41
+ Sentence: "{sentence}"
42
+ """
43
+
44
+ response = llm.invoke(prompt)
45
+ if response is None:
46
+ return None
47
+ response_split = response.content.strip("```python\n").split('=')
48
+ years_list = []
49
+ if len(response_split) > 1:
50
+ years_list = ast.literal_eval(response_split[1])
51
+ if years_list and len(years_list) > 0:
52
+ return years_list[0]
53
+ else:
54
+ return None
55
 
56
 
57
  def detectTable(sql_query):
 
90
 
91
 
92
  def detect_relevant_tables(db: str, user_question: str, plot: Plot, llm) -> list[str]:
93
+ """Detect relevant tables regarding the plot and the user input
94
+
95
+ Args:
96
+ db (str): database path
97
+ user_question (str): initial user input
98
+ plot (Plot): plot object for which the relevant tables are detected
99
+ llm (_type_): LLM
100
+
101
+ Returns:
102
+ list[str]: list of table names
103
+ """
104
  conn = sqlite3.connect(db)
105
  cursor = conn.cursor()
106
 
climateqa/engine/talk_to_data/workflow.py CHANGED
@@ -38,7 +38,7 @@ class State(TypedDict):
38
  plots: list[str]
39
  plot_states: dict[str, PlotState]
40
 
41
- def drias_workflow(db_drias_path: str, user_input: str) -> State:
42
  """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
43
 
44
  Args:
@@ -87,7 +87,7 @@ def drias_workflow(db_drias_path: str, user_input: str) -> State:
87
  'status': 'OK'
88
  }
89
  table_state['params'] = {
90
- 'model': 'ALL'
91
  }
92
  for param_name in plot['params']:
93
  param = find_param(state, param_name, table, db_drias_path)
@@ -99,6 +99,7 @@ def drias_workflow(db_drias_path: str, user_input: str) -> State:
99
  if sql_query == "":
100
  table_state['status'] = 'ERROR'
101
  continue
 
102
 
103
  table_state['sql_query'] = sql_query
104
  results = execute_sql_query(db_drias_path, sql_query)
 
38
  plots: list[str]
39
  plot_states: dict[str, PlotState]
40
 
41
+ def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State:
42
  """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
43
 
44
  Args:
 
87
  'status': 'OK'
88
  }
89
  table_state['params'] = {
90
+ 'model': model
91
  }
92
  for param_name in plot['params']:
93
  param = find_param(state, param_name, table, db_drias_path)
 
99
  if sql_query == "":
100
  table_state['status'] = 'ERROR'
101
  continue
102
+ print(sql_query)
103
 
104
  table_state['sql_query'] = sql_query
105
  results = execute_sql_query(db_drias_path, sql_query)
style.css CHANGED
@@ -487,7 +487,6 @@ a {
487
  height: calc(100vh - 190px) !important;
488
  overflow-y: scroll !important;
489
  }
490
- div#tab-vanna,
491
  div#sources-figures,
492
  div#graphs-container,
493
  div#tab-citations {
@@ -607,14 +606,25 @@ a {
607
  }
608
 
609
  #vanna-display {
610
- max-height: 300px;
611
  /* overflow-y: scroll; */
612
  }
613
  #sql-query{
614
- max-height: 100px;
615
  overflow-y:scroll;
616
  }
617
- #vanna-details{
618
- max-height: 500px;
 
 
 
 
619
  overflow-y:scroll;
 
 
 
620
  }
 
 
 
 
 
487
  height: calc(100vh - 190px) !important;
488
  overflow-y: scroll !important;
489
  }
 
490
  div#sources-figures,
491
  div#graphs-container,
492
  div#tab-citations {
 
606
  }
607
 
608
  #vanna-display {
609
+ max-height: 200px;
610
  /* overflow-y: scroll; */
611
  }
612
  #sql-query{
613
+ max-height: 300px;
614
  overflow-y:scroll;
615
  }
616
+
617
+ #sql-query span{
618
+ display: none;
619
+ }
620
+ div#tab-vanna{
621
+ max-height: 100vh;
622
  overflow-y:scroll;
623
+ }
624
+ #vanna-plot{
625
+ max-height:500px
626
  }
627
+
628
+ #drias-model{
629
+ max-width: 25%;
630
+ }