armanddemasson commited on
Commit
c3024c3
·
1 Parent(s): ac63459

feat: updated talk to drias based on talk to ipcc

Browse files
climateqa/engine/talk_to_data/drias/config.py CHANGED
@@ -1,7 +1,7 @@
1
 
2
  DRIAS_TABLES = [
3
  "total_winter_precipitation",
4
- "total_summer_precipiation",
5
  "total_annual_precipitation",
6
  "total_remarkable_daily_precipitation",
7
  "frequency_of_remarkable_daily_precipitation",
@@ -18,7 +18,7 @@ DRIAS_TABLES = [
18
 
19
  DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
20
  "total_winter_precipitation": "total_winter_precipitation",
21
- "total_summer_precipiation": "total_summer_precipitation",
22
  "total_annual_precipitation": "total_annual_precipitation",
23
  "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
24
  "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
@@ -70,11 +70,14 @@ DRIAS_INDICATOR_TO_UNIT = {
70
  "number_of_days_with_dry_ground": "days"
71
  }
72
 
 
 
 
 
73
  DRIAS_UI_TEXT = """
74
  Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
75
  I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
76
 
77
- ❓ **How to use?**
78
  You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
79
  You can specify **location** and/or **year**.
80
  You can choose from a list of climate models. By default, we take the **average of each model**.
 
1
 
2
  DRIAS_TABLES = [
3
  "total_winter_precipitation",
4
+ "total_summer_precipitation",
5
  "total_annual_precipitation",
6
  "total_remarkable_daily_precipitation",
7
  "frequency_of_remarkable_daily_precipitation",
 
18
 
19
  DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
20
  "total_winter_precipitation": "total_winter_precipitation",
21
+ "total_summer_precipitation": "total_summer_precipitation",
22
  "total_annual_precipitation": "total_annual_precipitation",
23
  "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
24
  "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
 
70
  "number_of_days_with_dry_ground": "days"
71
  }
72
 
73
+ DRIAS_PLOT_PARAMETERS = [
74
+ 'year',
75
+ 'location'
76
+ ]
77
  DRIAS_UI_TEXT = """
78
  Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
79
  I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
80
 
 
81
  You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
82
  You can specify **location** and/or **year**.
83
  You can choose from a list of climate models. By default, we take the **average of each model**.
climateqa/engine/talk_to_data/drias/plots.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
-
 
3
  from typing import Callable
4
  import pandas as pd
5
  from plotly.graph_objects import Figure
@@ -11,6 +12,7 @@ from climateqa.engine.talk_to_data.drias.queries import (
11
  )
12
  from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
13
 
 
14
  def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
15
  """Generates a function to plot indicator evolution over time at a location.
16
 
@@ -122,10 +124,11 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
122
  hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
123
  )
124
  fig.update_layout(
125
- title=f"Plot of {indicator_label} in {location} ({model_label})",
126
  xaxis_title="Year",
127
  yaxis_title=f"{indicator_label} ({unit})",
128
  template="plotly_white",
 
129
  )
130
  return fig
131
 
@@ -138,6 +141,7 @@ indicator_evolution_at_location: Plot = {
138
  "params": ["indicator_column", "location", "model"],
139
  "plot_function": plot_indicator_evolution_at_location,
140
  "sql_query": indicator_per_year_at_location_query,
 
141
  }
142
 
143
 
@@ -206,6 +210,7 @@ def plot_indicator_number_of_days_per_year_at_location(
206
  yaxis_title=f"{indicator_label} ({unit})",
207
  yaxis=dict(range=[0, max(indicators)]),
208
  bargap=0.5,
 
209
  template="plotly_white",
210
  )
211
 
@@ -220,6 +225,7 @@ indicator_number_of_days_per_year_at_location: Plot = {
220
  "params": ["indicator_column", "location", "model"],
221
  "plot_function": plot_indicator_number_of_days_per_year_at_location,
222
  "sql_query": indicator_per_year_at_location_query,
 
223
  }
224
 
225
 
@@ -242,6 +248,8 @@ def plot_distribution_of_indicator_for_given_year(
242
  """
243
  indicator = params["indicator_column"]
244
  year = params["year"]
 
 
245
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
246
  unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
247
 
@@ -288,6 +296,7 @@ def plot_distribution_of_indicator_for_given_year(
288
  yaxis_title="Frequency (%)",
289
  plot_bgcolor="rgba(0, 0, 0, 0)",
290
  showlegend=False,
 
291
  )
292
 
293
  return fig
@@ -301,6 +310,7 @@ distribution_of_indicator_for_given_year: Plot = {
301
  "params": ["indicator_column", "model", "year"],
302
  "plot_function": plot_distribution_of_indicator_for_given_year,
303
  "sql_query": indicator_for_given_year_query,
 
304
  }
305
 
306
 
@@ -323,6 +333,8 @@ def plot_map_of_france_of_indicator_for_given_year(
323
  """
324
  indicator = params["indicator_column"]
325
  year = params["year"]
 
 
326
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
327
  unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
328
 
@@ -347,28 +359,60 @@ def plot_map_of_france_of_indicator_for_given_year(
347
  longitudes = df_model["longitude"].astype(float).tolist()
348
  model_label = f"Model : {df['model'].unique()[0]}"
349
 
350
-
351
- fig.add_trace(
352
- go.Scattermapbox(
353
- lat=latitudes,
354
- lon=longitudes,
355
- mode="markers",
356
- marker=dict(
357
- size=10,
358
- color=indicators, # Color mapped to values
359
- colorscale="Turbo", # Color scale (can be 'Plasma', 'Jet', etc.)
360
- cmin=min(indicators), # Minimum color range
361
- cmax=max(indicators), # Maximum color range
362
- showscale=True, # Show colorbar
363
- ),
364
- text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
365
- hoverinfo="text" # Only show the custom text on hover
366
- )
367
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
  fig.update_layout(
370
  mapbox_style="open-street-map", # Use OpenStreetMap
371
- mapbox_zoom=3,
 
372
  mapbox_center={"lat": 46.6, "lon": 2.0},
373
  coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
374
  title=f"{indicator_label} in {year} in France ({model_label}) " # Title
@@ -380,10 +424,11 @@ def plot_map_of_france_of_indicator_for_given_year(
380
 
381
  map_of_france_of_indicator_for_given_year: Plot = {
382
  "name": "Map of France of an indicator for a given year",
383
- "description": "Heatmap on the map of France of the values of an in indicator for a given year",
384
  "params": ["indicator_column", "year", "model"],
385
  "plot_function": plot_map_of_france_of_indicator_for_given_year,
386
  "sql_query": indicator_for_given_year_query,
 
387
  }
388
 
389
  DRIAS_PLOTS = [
 
1
  import os
2
+ import geojson
3
+ from math import cos, radians
4
  from typing import Callable
5
  import pandas as pd
6
  from plotly.graph_objects import Figure
 
12
  )
13
  from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
14
 
15
+
16
  def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
17
  """Generates a function to plot indicator evolution over time at a location.
18
 
 
124
  hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
125
  )
126
  fig.update_layout(
127
+ title=f"Evolution of {indicator_label} in {location} ({model_label})",
128
  xaxis_title="Year",
129
  yaxis_title=f"{indicator_label} ({unit})",
130
  template="plotly_white",
131
+ height=900,
132
  )
133
  return fig
134
 
 
141
  "params": ["indicator_column", "location", "model"],
142
  "plot_function": plot_indicator_evolution_at_location,
143
  "sql_query": indicator_per_year_at_location_query,
144
+ 'short_name': 'Indicator Evolution'
145
  }
146
 
147
 
 
210
  yaxis_title=f"{indicator_label} ({unit})",
211
  yaxis=dict(range=[0, max(indicators)]),
212
  bargap=0.5,
213
+ height=900,
214
  template="plotly_white",
215
  )
216
 
 
225
  "params": ["indicator_column", "location", "model"],
226
  "plot_function": plot_indicator_number_of_days_per_year_at_location,
227
  "sql_query": indicator_per_year_at_location_query,
228
+ "short_name": "Indicator Yearly Frequency",
229
  }
230
 
231
 
 
248
  """
249
  indicator = params["indicator_column"]
250
  year = params["year"]
251
+ if year is None:
252
+ year = 2030
253
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
254
  unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
255
 
 
296
  yaxis_title="Frequency (%)",
297
  plot_bgcolor="rgba(0, 0, 0, 0)",
298
  showlegend=False,
299
+ height=900,
300
  )
301
 
302
  return fig
 
310
  "params": ["indicator_column", "model", "year"],
311
  "plot_function": plot_distribution_of_indicator_for_given_year,
312
  "sql_query": indicator_for_given_year_query,
313
+ 'short_name': 'Indicator Distribution'
314
  }
315
 
316
 
 
333
  """
334
  indicator = params["indicator_column"]
335
  year = params["year"]
336
+ if year is None:
337
+ year = 2030
338
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
339
  unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
340
 
 
359
  longitudes = df_model["longitude"].astype(float).tolist()
360
  model_label = f"Model : {df['model'].unique()[0]}"
361
 
362
+ side_km = 8
363
+ delta_lat = side_km / 111
364
+ features = []
365
+ for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
366
+ delta_lon = side_km / (111 * cos(radians(lat)))
367
+ half_lat = delta_lat / 2
368
+ half_lon = delta_lon / 2
369
+ features.append(geojson.Feature(
370
+ geometry=geojson.Polygon([[
371
+ [lon - half_lon, lat - half_lat],
372
+ [lon + half_lon, lat - half_lat],
373
+ [lon + half_lon, lat + half_lat],
374
+ [lon - half_lon, lat + half_lat],
375
+ [lon - half_lon, lat - half_lat]
376
+ ]]),
377
+ properties={"value": val},
378
+ id=str(idx)
379
+ ))
380
+
381
+ geojson_data = geojson.FeatureCollection(features)
382
+
383
+ custom_colorscale = [
384
+ [0.0, "rgb(5, 48, 97)"],
385
+ [0.10, "rgb(33, 102, 172)"],
386
+ [0.20, "rgb(67, 147, 195)"],
387
+ [0.30, "rgb(146, 197, 222)"],
388
+ [0.40, "rgb(209, 229, 240)"],
389
+ [0.50, "rgb(247, 247, 247)"],
390
+ [0.60, "rgb(253, 219, 199)"],
391
+ [0.75, "rgb(244, 165, 130)"],
392
+ [0.85, "rgb(214, 96, 77)"],
393
+ [0.90, "rgb(178, 24, 43)"],
394
+ [1.0, "rgb(103, 0, 31)"]
395
+ ]
396
+
397
+ fig = go.Figure(go.Choroplethmapbox(
398
+ geojson=geojson_data,
399
+ locations=[str(i) for i in range(len(indicators))],
400
+ featureidkey="id",
401
+ z=indicators,
402
+ colorscale=custom_colorscale,
403
+ zmin=min(indicators),
404
+ zmax=max(indicators),
405
+ marker_opacity=0.7,
406
+ marker_line_width=0,
407
+ colorbar_title=f"{indicator_label} ({unit})",
408
+ text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
409
+ hoverinfo="text"
410
+ ))
411
 
412
  fig.update_layout(
413
  mapbox_style="open-street-map", # Use OpenStreetMap
414
+ mapbox_zoom=5,
415
+ height=900,
416
  mapbox_center={"lat": 46.6, "lon": 2.0},
417
  coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
418
  title=f"{indicator_label} in {year} in France ({model_label}) " # Title
 
424
 
425
  map_of_france_of_indicator_for_given_year: Plot = {
426
  "name": "Map of France of an indicator for a given year",
427
+ "description": "Heatmap on the map of France of the values of an indicator for a given year",
428
  "params": ["indicator_column", "year", "model"],
429
  "plot_function": plot_map_of_france_of_indicator_for_given_year,
430
  "sql_query": indicator_for_given_year_query,
431
+ 'short_name': 'Map of France'
432
  }
433
 
434
  DRIAS_PLOTS = [
climateqa/engine/talk_to_data/drias/queries.py CHANGED
@@ -71,6 +71,8 @@ def indicator_for_given_year_query(
71
  """
72
  indicator_column = params.get("indicator_column")
73
  year = params.get('year')
 
 
74
  if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
75
  return ""
76
 
 
71
  """
72
  indicator_column = params.get("indicator_column")
73
  year = params.get('year')
74
+ if year is None:
75
+ year = 2050
76
  if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
77
  return ""
78
 
climateqa/engine/talk_to_data/workflow/drias.py CHANGED
@@ -6,131 +6,154 @@ from climateqa.engine.llm import get_llm
6
  from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
7
  from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
8
  from climateqa.engine.talk_to_data.objects.plot import Plot
9
- from climateqa.engine.talk_to_data.objects.states import PlotState, State, TableState
10
- from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE
11
  from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
12
 
13
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
14
 
15
- async def process_table(
 
16
  table: str,
17
- params: dict[str, Any],
18
  plot: Plot,
19
- ) -> TableState:
20
- """Processes a table to extract relevant data and generate visualizations.
21
-
22
- This function retrieves the SQL query for the specified table, executes it,
23
- and generates a visualization based on the results.
24
-
25
  Args:
26
- table (str): The name of the table to process
27
- params (dict[str, Any]): Parameters used for querying the table
28
- plot (Plot): The plot object containing SQL query and visualization function
29
-
 
30
  Returns:
31
- TableState: The state of the processed table
32
  """
33
- table_state: TableState = {
34
- 'table_name': table,
35
- 'params': params.copy(),
36
  'status': 'OK',
37
- 'dataframe': None,
 
38
  'sql_query': None,
 
39
  'figure': None
40
  }
 
 
 
 
41
 
42
- table_state['params']['indicator_column'] = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
43
- sql_query = plot['sql_query'](table, table_state['params'])
 
 
44
 
45
- if sql_query == "":
46
- table_state['status'] = 'ERROR'
47
- return table_state
48
- table_state['sql_query'] = sql_query
49
- df = await execute_sql_query(sql_query)
50
 
51
- table_state['dataframe'] = df
52
- table_state['figure'] = plot['plot_function'](table_state['params'])
53
 
54
- return table_state
 
 
 
 
 
 
55
 
 
 
 
 
56
 
57
  async def drias_workflow(user_input: str) -> State:
58
- """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
 
59
 
60
  Args:
61
- user_input (str): initial user input
62
 
63
  Returns:
64
- State: Final state with all the results
65
  """
66
  state: State = {
67
  'user_input': user_input,
68
  'plots': [],
69
- 'plot_states': {},
70
  'error': ''
71
  }
72
 
73
  llm = get_llm(provider="openai")
74
-
75
  plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)
76
 
77
- state['plots'] = plots
78
-
79
- if len(state['plots']) < 1:
80
  state['error'] = 'There is no plot to answer to the question'
81
  return state
82
 
83
- have_relevant_table = False
84
- have_sql_query = False
85
- have_dataframe = False
86
 
87
- for plot_name in state['plots']:
88
-
89
- plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None) # Find the associated plot object
 
 
 
 
 
 
 
90
  if plot is None:
91
  continue
92
-
93
- plot_state: PlotState = {
94
- 'plot_name': plot_name,
95
- 'tables': [],
96
- 'table_states': {}
97
- }
98
-
99
- plot_state['plot_name'] = plot_name
100
 
101
  relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
102
-
103
- if len(relevant_tables) > 0 :
104
- have_relevant_table = True
105
-
106
- plot_state['tables'] = relevant_tables
107
-
108
- params = {}
109
- for param_name in plot['params']:
110
- param = await find_param(state, param_name, relevant_tables[0])
111
- if param:
112
- params.update(param)
113
-
114
- tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
115
- results = await asyncio.gather(*tasks)
116
-
117
- # Store results back in plot_state
118
- have_dataframe = False
119
- have_sql_query = False
120
- for table_state in results:
121
- if table_state['sql_query']:
122
- have_sql_query = True
123
- if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
124
- have_dataframe = True
125
- plot_state['table_states'][table_state['table_name']] = table_state
126
-
127
- state['plot_states'][plot_name] = plot_state
128
-
129
- if not have_relevant_table:
 
 
 
 
 
 
 
 
 
 
130
  state['error'] = "There is no relevant table in our database to answer your question"
131
- elif not have_sql_query:
132
  state['error'] = "There is no relevant sql query on our database that can help to answer your question"
133
- elif not have_dataframe:
134
  state['error'] = "There is no data in our table that can answer to your question"
135
-
136
- return state
 
6
  from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
7
  from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
8
  from climateqa.engine.talk_to_data.objects.plot import Plot
9
+ from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
10
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE, DRIAS_PLOT_PARAMETERS
11
  from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
12
 
13
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
14
 
15
+ async def process_output(
16
+ output_title: str,
17
  table: str,
 
18
  plot: Plot,
19
+ params: dict[str, Any]
20
+ ) -> tuple[str, TTDOutput, dict[str, bool]]:
21
+ """
22
+ Processes a table for a given plot and parameters: builds the SQL query, executes it,
23
+ and generates the corresponding figure.
24
+
25
  Args:
26
+ output_title (str): Title for the output (used as key in outputs dict).
27
+ table (str): The name of the table to process.
28
+ plot (Plot): The plot object containing SQL query and visualization function.
29
+ params (dict[str, Any]): Parameters used for querying the table.
30
+
31
  Returns:
32
+ tuple: (output_title, results dict, errors dict)
33
  """
34
+ results: TTDOutput = {
 
 
35
  'status': 'OK',
36
+ 'plot': plot,
37
+ 'table': table,
38
  'sql_query': None,
39
+ 'dataframe': None,
40
  'figure': None
41
  }
42
+ errors = {
43
+ 'have_sql_query': False,
44
+ 'have_dataframe': False
45
+ }
46
 
47
+ # Find the indicator column for this table
48
+ indicator_column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
49
+ if indicator_column:
50
+ params['indicator_column'] = indicator_column
51
 
52
+ # Build the SQL query
53
+ sql_query = plot['sql_query'](table, params)
54
+ if not sql_query:
55
+ results['status'] = 'ERROR'
56
+ return output_title, results, errors
57
 
58
+ results['sql_query'] = sql_query
59
+ errors['have_sql_query'] = True
60
 
61
+ # Execute the SQL query
62
+ df = await execute_sql_query(sql_query)
63
+ if df is not None and len(df) > 0:
64
+ results['dataframe'] = df
65
+ errors['have_dataframe'] = True
66
+ else:
67
+ results['status'] = 'NO_DATA'
68
 
69
+ # Generate the figure (always, even if df is empty, for consistency)
70
+ results['figure'] = plot['plot_function'](params)
71
+
72
+ return output_title, results, errors
73
 
74
  async def drias_workflow(user_input: str) -> State:
75
+ """
76
+ Orchestrates the DRIAS workflow: from user input to SQL queries, dataframes, and figures.
77
 
78
  Args:
79
+ user_input (str): The user's question.
80
 
81
  Returns:
82
+ State: Final state with all results and error messages if any.
83
  """
84
  state: State = {
85
  'user_input': user_input,
86
  'plots': [],
87
+ 'outputs': {},
88
  'error': ''
89
  }
90
 
91
  llm = get_llm(provider="openai")
 
92
  plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)
93
 
94
+ if not plots:
 
 
95
  state['error'] = 'There is no plot to answer to the question'
96
  return state
97
 
98
+ plots = plots[:2] # limit to 2 types of plots
99
+ state['plots'] = plots
 
100
 
101
+ errors = {
102
+ 'have_relevant_table': False,
103
+ 'have_sql_query': False,
104
+ 'have_dataframe': False
105
+ }
106
+ outputs = {}
107
+
108
+ # Find relevant tables for each plot and prepare outputs
109
+ for plot_name in plots:
110
+ plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None)
111
  if plot is None:
112
  continue
 
 
 
 
 
 
 
 
113
 
114
  relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
115
+ if relevant_tables:
116
+ errors['have_relevant_table'] = True
117
+
118
+ for table in relevant_tables:
119
+ output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
120
+ outputs[output_title] = {
121
+ 'table': table,
122
+ 'plot': plot,
123
+ 'status': 'OK'
124
+ }
125
+
126
+ # Gather all required parameters
127
+ params = {}
128
+ for param_name in DRIAS_PLOT_PARAMETERS:
129
+ param = await find_param(state, param_name, mode='DRIAS')
130
+ if param:
131
+ params.update(param)
132
+
133
+ # Process all outputs in parallel using process_output
134
+ tasks = [
135
+ process_output(output_title, output['table'], output['plot'], params.copy())
136
+ for output_title, output in outputs.items()
137
+ ]
138
+ results = await asyncio.gather(*tasks)
139
+
140
+ # Update outputs with results and error flags
141
+ for output_title, task_results, task_errors in results:
142
+ outputs[output_title]['sql_query'] = task_results['sql_query']
143
+ outputs[output_title]['dataframe'] = task_results['dataframe']
144
+ outputs[output_title]['figure'] = task_results['figure']
145
+ outputs[output_title]['status'] = task_results['status']
146
+ errors['have_sql_query'] |= task_errors['have_sql_query']
147
+ errors['have_dataframe'] |= task_errors['have_dataframe']
148
+
149
+ state['outputs'] = outputs
150
+
151
+ # Set error messages if needed
152
+ if not errors['have_relevant_table']:
153
  state['error'] = "There is no relevant table in our database to answer your question"
154
+ elif not errors['have_sql_query']:
155
  state['error'] = "There is no relevant sql query on our database that can help to answer your question"
156
+ elif not errors['have_dataframe']:
157
  state['error'] = "There is no data in our table that can answer to your question"
158
+
159
+ return state
front/tabs/tab_drias.py CHANGED
@@ -11,9 +11,10 @@ class DriasUIElements(TypedDict):
11
  details_accordion: gr.Accordion
12
  examples_hidden: gr.Textbox
13
  examples: gr.Examples
 
14
  drias_direct_question: gr.Textbox
15
  result_text: gr.Textbox
16
- table_names_display: gr.DataFrame
17
  query_accordion: gr.Accordion
18
  drias_sql_query: gr.Textbox
19
  chart_accordion: gr.Accordion
@@ -21,9 +22,6 @@ class DriasUIElements(TypedDict):
21
  drias_display: gr.Plot
22
  table_accordion: gr.Accordion
23
  drias_table: gr.DataFrame
24
- pagination_display: gr.Markdown
25
- prev_button: gr.Button
26
- next_button: gr.Button
27
 
28
 
29
  async def ask_drias_query(query: str, index_state: int, user_id: str):
@@ -31,7 +29,7 @@ async def ask_drias_query(query: str, index_state: int, user_id: str):
31
  return result
32
 
33
 
34
- def show_results(sql_queries_state, dataframes_state, plots_state):
35
  if not sql_queries_state or not dataframes_state or not plots_state:
36
  # If all results are empty, show "No result"
37
  return (
@@ -40,9 +38,6 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
40
  gr.update(visible=False),
41
  gr.update(visible=False),
42
  gr.update(visible=False),
43
- gr.update(visible=False),
44
- gr.update(visible=False),
45
- gr.update(visible=False),
46
  )
47
  else:
48
  # Show the appropriate components with their data
@@ -51,10 +46,7 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
51
  gr.update(visible=True),
52
  gr.update(visible=True),
53
  gr.update(visible=True),
54
- gr.update(visible=True),
55
- gr.update(visible=True),
56
- gr.update(visible=True),
57
- gr.update(visible=True),
58
  )
59
 
60
 
@@ -72,39 +64,8 @@ def filter_by_model(dataframes, figures, index_state, model_selection):
72
  return df, figure
73
 
74
 
75
- def update_pagination(index, sql_queries):
76
- pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
77
- return pagination
78
-
79
-
80
- def show_previous(index, sql_queries, dataframes, plots):
81
- if index > 0:
82
- index -= 1
83
- return (
84
- sql_queries[index],
85
- dataframes[index],
86
- plots[index](dataframes[index]),
87
- index,
88
- )
89
-
90
-
91
- def show_next(index, sql_queries, dataframes, plots):
92
- if index < len(sql_queries) - 1:
93
- index += 1
94
- return (
95
- sql_queries[index],
96
- dataframes[index],
97
- plots[index](dataframes[index]),
98
- index,
99
- )
100
-
101
-
102
- def display_table_names(table_names):
103
- return [[table_name] for table_name in table_names]
104
-
105
-
106
- def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
107
- index = evt.index[1]
108
  figure = plots[index](dataframes[index])
109
  return (
110
  sql_queries[index],
@@ -117,7 +78,7 @@ def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plo
117
  def create_drias_ui() -> DriasUIElements:
118
  """Create and return all UI elements for the DRIAS tab."""
119
  with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
120
- with gr.Accordion(label="Details") as details_accordion:
121
  gr.Markdown(DRIAS_UI_TEXT)
122
 
123
  # Add examples for common questions
@@ -141,19 +102,35 @@ def create_drias_ui() -> DriasUIElements:
141
  elem_id="direct-question",
142
  interactive=True,
143
  )
 
 
 
 
 
 
 
 
 
144
 
145
  result_text = gr.Textbox(
146
  label="", elem_id="no-result-label", interactive=False, visible=True
147
  )
148
- table_names_display = gr.DataFrame(
149
- [], label="List of relevant indicators", headers=["Indicator Name"], interactive=False, elem_id="table-names", visible=False
150
- )
151
-
152
- with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
153
- drias_sql_query = gr.Textbox(
154
- label="", elem_id="sql-query", interactive=False
 
155
  )
156
 
 
 
 
 
 
 
157
  with gr.Accordion(label="Chart", visible=False) as chart_accordion:
158
  model_selection = gr.Dropdown(
159
  label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
@@ -165,19 +142,12 @@ def create_drias_ui() -> DriasUIElements:
165
  ) as table_accordion:
166
  drias_table = gr.DataFrame([], elem_id="vanna-table")
167
 
168
- pagination_display = gr.Markdown(
169
- value="", visible=False, elem_id="pagination-display"
170
- )
171
-
172
- with gr.Row():
173
- prev_button = gr.Button("Previous", visible=False)
174
- next_button = gr.Button("Next", visible=False)
175
-
176
  return DriasUIElements(
177
  tab=tab,
178
  details_accordion=details_accordion,
179
  examples_hidden=examples_hidden,
180
  examples=examples,
 
181
  drias_direct_question=drias_direct_question,
182
  result_text=result_text,
183
  table_names_display=table_names_display,
@@ -188,9 +158,6 @@ def create_drias_ui() -> DriasUIElements:
188
  drias_display=drias_display,
189
  table_accordion=table_accordion,
190
  drias_table=drias_table,
191
- pagination_display=pagination_display,
192
- prev_button=prev_button,
193
- next_button=next_button
194
  )
195
 
196
 
@@ -210,6 +177,10 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
210
  lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
211
  inputs=[ui_elements["examples_hidden"]],
212
  outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
 
 
 
 
213
  ).then(
214
  ask_drias_query,
215
  inputs=[ui_elements["examples_hidden"], index_state, user_id],
@@ -226,25 +197,14 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
226
  ],
227
  ).then(
228
  show_results,
229
- inputs=[sql_queries_state, dataframes_state, plots_state],
230
  outputs=[
231
  ui_elements["result_text"],
232
  ui_elements["query_accordion"],
233
  ui_elements["table_accordion"],
234
  ui_elements["chart_accordion"],
235
- ui_elements["prev_button"],
236
- ui_elements["next_button"],
237
- ui_elements["pagination_display"],
238
  ui_elements["table_names_display"],
239
  ],
240
- ).then(
241
- update_pagination,
242
- inputs=[index_state, sql_queries_state],
243
- outputs=[ui_elements["pagination_display"]],
244
- ).then(
245
- display_table_names,
246
- inputs=[table_names_list],
247
- outputs=[ui_elements["table_names_display"]],
248
  )
249
 
250
  # Handle direct question submission
@@ -252,6 +212,10 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
252
  lambda: gr.Accordion(open=False),
253
  inputs=None,
254
  outputs=[ui_elements["details_accordion"]]
 
 
 
 
255
  ).then(
256
  ask_drias_query,
257
  inputs=[ui_elements["drias_direct_question"], index_state, user_id],
@@ -268,27 +232,15 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
268
  ],
269
  ).then(
270
  show_results,
271
- inputs=[sql_queries_state, dataframes_state, plots_state],
272
  outputs=[
273
  ui_elements["result_text"],
274
  ui_elements["query_accordion"],
275
  ui_elements["table_accordion"],
276
  ui_elements["chart_accordion"],
277
- ui_elements["prev_button"],
278
- ui_elements["next_button"],
279
- ui_elements["pagination_display"],
280
  ui_elements["table_names_display"],
281
  ],
282
- ).then(
283
- update_pagination,
284
- inputs=[index_state, sql_queries_state],
285
- outputs=[ui_elements["pagination_display"]],
286
- ).then(
287
- display_table_names,
288
- inputs=[table_names_list],
289
- outputs=[ui_elements["table_names_display"]],
290
  )
291
-
292
  # Handle model selection change
293
  ui_elements["model_selection"].change(
294
  filter_by_model,
@@ -296,36 +248,12 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
296
  outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
297
  )
298
 
299
- # Handle pagination buttons
300
- ui_elements["prev_button"].click(
301
- show_previous,
302
- inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
303
- outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
304
- ).then(
305
- update_pagination,
306
- inputs=[index_state, sql_queries_state],
307
- outputs=[ui_elements["pagination_display"]],
308
- )
309
-
310
- ui_elements["next_button"].click(
311
- show_next,
312
- inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
313
- outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
314
- ).then(
315
- update_pagination,
316
- inputs=[index_state, sql_queries_state],
317
- outputs=[ui_elements["pagination_display"]],
318
- )
319
 
320
  # Handle table selection
321
- ui_elements["table_names_display"].select(
322
  fn=on_table_click,
323
- inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
324
  outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
325
- ).then(
326
- update_pagination,
327
- inputs=[index_state, sql_queries_state],
328
- outputs=[ui_elements["pagination_display"]],
329
  )
330
 
331
  def create_drias_tab(share_client=None, user_id=None):
 
11
  details_accordion: gr.Accordion
12
  examples_hidden: gr.Textbox
13
  examples: gr.Examples
14
+ image_examples: gr.Row
15
  drias_direct_question: gr.Textbox
16
  result_text: gr.Textbox
17
+ table_names_display: gr.Radio
18
  query_accordion: gr.Accordion
19
  drias_sql_query: gr.Textbox
20
  chart_accordion: gr.Accordion
 
22
  drias_display: gr.Plot
23
  table_accordion: gr.Accordion
24
  drias_table: gr.DataFrame
 
 
 
25
 
26
 
27
  async def ask_drias_query(query: str, index_state: int, user_id: str):
 
29
  return result
30
 
31
 
32
+ def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
33
  if not sql_queries_state or not dataframes_state or not plots_state:
34
  # If all results are empty, show "No result"
35
  return (
 
38
  gr.update(visible=False),
39
  gr.update(visible=False),
40
  gr.update(visible=False),
 
 
 
41
  )
42
  else:
43
  # Show the appropriate components with their data
 
46
  gr.update(visible=True),
47
  gr.update(visible=True),
48
  gr.update(visible=True),
49
+ gr.update(choices=table_names, value=table_names[0], visible=True),
 
 
 
50
  )
51
 
52
 
 
64
  return df, figure
65
 
66
 
67
+ def on_table_click(selected_label, table_names, sql_queries, dataframes, plots):
68
+ index = table_names.index(selected_label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  figure = plots[index](dataframes[index])
70
  return (
71
  sql_queries[index],
 
78
  def create_drias_ui() -> DriasUIElements:
79
  """Create and return all UI elements for the DRIAS tab."""
80
  with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
81
+ with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
82
  gr.Markdown(DRIAS_UI_TEXT)
83
 
84
  # Add examples for common questions
 
102
  elem_id="direct-question",
103
  interactive=True,
104
  )
105
+
106
+
107
+ with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
108
+ gr.Markdown("### Examples of possible visualizations")
109
+
110
+ with gr.Row():
111
+ gr.Image("./front/assets/talk_to_drias_winter_temp_paris_example.png", label="Evolution of Mean Winter Temperature in Paris", elem_classes=["example-img"])
112
+ gr.Image("./front/assets/talk_to_drias_annual_temperature_france_example.png", label="Mean Annual Temperature in 2030 in France", elem_classes=["example-img"])
113
+ gr.Image("./front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png", label="Frequency of Remarkable Daily Precipitation in Lyon", elem_classes=["example-img"])
114
 
115
  result_text = gr.Textbox(
116
  label="", elem_id="no-result-label", interactive=False, visible=True
117
  )
118
+
119
+ with gr.Row():
120
+ table_names_display = gr.Radio(
121
+ choices=[],
122
+ label="Relevant figures created",
123
+ interactive=True,
124
+ elem_id="table-names",
125
+ visible=False
126
  )
127
 
128
+ with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
129
+ drias_sql_query = gr.Textbox(
130
+ label="", elem_id="sql-query", interactive=False
131
+ )
132
+
133
+
134
  with gr.Accordion(label="Chart", visible=False) as chart_accordion:
135
  model_selection = gr.Dropdown(
136
  label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
 
142
  ) as table_accordion:
143
  drias_table = gr.DataFrame([], elem_id="vanna-table")
144
 
 
 
 
 
 
 
 
 
145
  return DriasUIElements(
146
  tab=tab,
147
  details_accordion=details_accordion,
148
  examples_hidden=examples_hidden,
149
  examples=examples,
150
+ image_examples=image_examples,
151
  drias_direct_question=drias_direct_question,
152
  result_text=result_text,
153
  table_names_display=table_names_display,
 
158
  drias_display=drias_display,
159
  table_accordion=table_accordion,
160
  drias_table=drias_table,
 
 
 
161
  )
162
 
163
 
 
177
  lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
178
  inputs=[ui_elements["examples_hidden"]],
179
  outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
180
+ ).then(
181
+ lambda : gr.update(visible=False),
182
+ inputs=None,
183
+ outputs=ui_elements["image_examples"]
184
  ).then(
185
  ask_drias_query,
186
  inputs=[ui_elements["examples_hidden"], index_state, user_id],
 
197
  ],
198
  ).then(
199
  show_results,
200
+ inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
201
  outputs=[
202
  ui_elements["result_text"],
203
  ui_elements["query_accordion"],
204
  ui_elements["table_accordion"],
205
  ui_elements["chart_accordion"],
 
 
 
206
  ui_elements["table_names_display"],
207
  ],
 
 
 
 
 
 
 
 
208
  )
209
 
210
  # Handle direct question submission
 
212
  lambda: gr.Accordion(open=False),
213
  inputs=None,
214
  outputs=[ui_elements["details_accordion"]]
215
+ ).then(
216
+ lambda : gr.update(visible=False),
217
+ inputs=None,
218
+ outputs=ui_elements["image_examples"]
219
  ).then(
220
  ask_drias_query,
221
  inputs=[ui_elements["drias_direct_question"], index_state, user_id],
 
232
  ],
233
  ).then(
234
  show_results,
235
+ inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
236
  outputs=[
237
  ui_elements["result_text"],
238
  ui_elements["query_accordion"],
239
  ui_elements["table_accordion"],
240
  ui_elements["chart_accordion"],
 
 
 
241
  ui_elements["table_names_display"],
242
  ],
 
 
 
 
 
 
 
 
243
  )
 
244
  # Handle model selection change
245
  ui_elements["model_selection"].change(
246
  filter_by_model,
 
248
  outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
249
  )
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  # Handle table selection
253
+ ui_elements["table_names_display"].change(
254
  fn=on_table_click,
255
+ inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plots_state],
256
  outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
 
 
 
 
257
  )
258
 
259
  def create_drias_tab(share_client=None, user_id=None):