timeki committed on
Commit
1ed4e25
·
2 Parent(s): 11910c7 c6723cc

Merge hf-origin/main into main

Browse files
.gitattributes CHANGED
@@ -45,4 +45,4 @@ documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
45
  climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
46
  climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
47
  data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
48
- front/assets/*.png filter=lfs diff=lfs merge=lfs -text
 
45
  climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
46
  climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
47
  data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
48
+ front/assets/*.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -16,7 +16,10 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
16
  from front.tabs import create_config_modal, cqa_tab, create_about_tab
17
  from front.tabs import MainTabPanel, ConfigPanel
18
  from front.tabs.tab_drias import create_drias_tab
 
19
  from front.tabs.tab_ipcc import create_ipcc_tab
 
 
20
  from front.utils import process_figures
21
  from gradio_modal import Modal
22
 
@@ -533,8 +536,13 @@ def main_ui():
533
  with gr.Tabs():
534
  cqa_components = cqa_tab(tab_name="ClimateQ&A")
535
  local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
 
536
  drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
537
  ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
 
 
 
 
538
  create_about_tab()
539
 
540
  event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
 
# app.py import section — conflict resolution.
# The merge commit left raw conflict markers (<<<<<<< HEAD / ======= /
# >>>>>>> hf-origin/main) in the file, which is a SyntaxError at import time.
# Resolved in favour of HEAD: the IPCC tab import exists only on this branch
# and is used further down in main_ui().
from front.tabs import create_config_modal, cqa_tab, create_about_tab
from front.tabs import MainTabPanel, ConfigPanel
from front.tabs.tab_drias import create_drias_tab
from front.tabs.tab_ipcc import create_ipcc_tab
from front.utils import process_figures
from gradio_modal import Modal
 
536
  with gr.Tabs():
537
  cqa_components = cqa_tab(tab_name="ClimateQ&A")
538
  local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
539
+ <<<<<<< HEAD
540
  drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
541
  ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
542
+ =======
543
+ create_drias_tab(share_client=share_client, user_id=user_id)
544
+
545
+ >>>>>>> hf-origin/main
546
  create_about_tab()
547
 
548
  event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
climateqa/engine/talk_to_data/main.py CHANGED
@@ -121,4 +121,4 @@ async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None)
121
 
122
  log_drias_interaction_to_huggingface(query, sql_query, user_id)
123
 
124
- return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
 
121
 
122
  log_drias_interaction_to_huggingface(query, sql_query, user_id)
123
 
124
+ return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
climateqa/engine/talk_to_data/myVanna.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dotenv import load_dotenv
from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
from vanna.openai import OpenAI_Chat
import os

# Load environment variables (API keys) from a local .env file, if present.
load_dotenv()

# NOTE(review): assigned but not used in this module — presumably consumed by
# the config passed to MyVanna by callers; confirm before removing.
OPENAI_API_KEY = os.getenv('THEO_API_KEY')

class MyVanna(MyCustomVectorDB, OpenAI_Chat):
    """Vanna assistant combining the custom vector store with OpenAI chat.

    Both bases are initialised explicitly with the same ``config`` mapping
    (Vanna's documented composition pattern) rather than via ``super()``.
    """
    def __init__(self, config=None):
        MyCustomVectorDB.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)
climateqa/engine/talk_to_data/plot.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, TypedDict
2
+ from matplotlib.figure import figaspect
3
+ import pandas as pd
4
+ from plotly.graph_objects import Figure
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
+
8
+ from climateqa.engine.talk_to_data.sql_query import (
9
+ indicator_for_given_year_query,
10
+ indicator_per_year_at_location_query,
11
+ )
12
+ from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
13
+
14
+
15
+
16
+
17
class Plot(TypedDict):
    """Represents a plot configuration in the DRIAS system.

    This class defines the structure for configuring different types of plots
    that can be generated from climate data.

    Attributes:
        name (str): The name of the plot type
        description (str): A description of what the plot shows
        params (list[str]): List of required parameters for the plot
        plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
        sql_query (Callable[..., str]): Function to generate the SQL query for the plot
    """
    name: str
    description: str
    params: list[str]
    # Factory: called with the params dict, returns the DataFrame -> Figure function.
    plot_function: Callable[..., Callable[..., Figure]]
    # Called with (table, params); returns the SQL text, or "" when params are missing.
    sql_query: Callable[..., str]
35
+
36
+
37
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
    """Build a plotting function for an indicator's evolution at a location.

    The returned function draws a line plot of the yearly indicator values,
    plus a dashed 10-year rolling average once a full window is available.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location (used in the plot title)
            - model (str): The climate model to use

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame (columns
        ``year``, the indicator column, ``model``) and returns a plotly Figure.

    Example:
        >>> plot_func = plot_indicator_evolution_at_location({
        ...     'indicator_column': 'mean_temperature',
        ...     'location': 'Paris',
        ...     'model': 'ALL'
        ... })
        >>> fig = plot_func(df)
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the line plot (yearly values + rolling average) from df."""
        fig = go.Figure()

        # The two original branches duplicated all of the rolling-average and
        # trace code; only the source frame and the label differed.
        if df["model"].nunique() != 1:
            # Several models: average them per year.
            series = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            series = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to plain lists to avoid pandas-specific encoding.
        indicators = series[indicator].astype(float).tolist()
        years = series["year"].astype(int).tolist()

        # Compute the 10-year rolling average (NaN until a full window exists).
        rolling_window = 10
        sliding_averages = (
            series[indicator]
            .rolling(window=rolling_window, min_periods=rolling_window)
            .mean()
            .astype(float)
            .tolist()
        )

        # Only add the rolling average if we have enough data points.
        if any(pd.notna(x) for x in sliding_averages):
            fig.add_scatter(
                x=years,
                y=sliding_averages,
                mode="lines",
                name="10 years rolling average",
                line=dict(dash="dash"),
                marker=dict(color="#d62728"),
                hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
            )

        # Indicator per year plot.
        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
            marker=dict(color="#1f77b4"),
            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
        )
        fig.update_layout(
            title=f"Plot of {indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            template="plotly_white",
        )
        return fig

    return plot_data


indicator_evolution_at_location: Plot = {
    "name": "Indicator evolution at location",
    "description": "Plot an evolution of the indicator at a certain location",
    "params": ["indicator_column", "location", "model"],
    "plot_function": plot_indicator_evolution_at_location,
    "sql_query": indicator_per_year_at_location_query,
}
165
+
166
+
167
def plot_indicator_number_of_days_per_year_at_location(
    params: dict,
) -> Callable[..., Figure]:
    """Build a bar chart of a frequency indicator per year at a location.

    This creates a bar chart showing the frequency of certain climate events
    (like days above a temperature threshold) per year at a specific location.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - location (str): The location (used in the plot title)
            - model (str): The climate model to use

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame (columns
        ``year``, the indicator column, ``model``) and returns a plotly Figure.
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the bar chart from the dataframe."""
        fig = go.Figure()

        if df["model"].nunique() != 1:
            # Several models: average them per year.
            series = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            series = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Transform to plain lists to avoid pandas-specific encoding.
        indicators = series[indicator].astype(float).tolist()
        years = series["year"].astype(int).tolist()

        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
                hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
            )
        )

        fig.update_layout(
            title=f"{indicator_label} in {location} ({model_label})",
            xaxis_title="Year",
            yaxis_title=f"{indicator_label} ({unit})",
            # Bug fix: bare max() raises ValueError on an empty dataframe;
            # default=0 keeps the layout valid when there is no data.
            yaxis=dict(range=[0, max(indicators, default=0)]),
            bargap=0.5,
            template="plotly_white",
        )

        return fig

    return plot_data


indicator_number_of_days_per_year_at_location: Plot = {
    "name": "Indicator number of days per year at location",
    "description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
    "params": ["indicator_column", "location", "model"],
    "plot_function": plot_indicator_number_of_days_per_year_at_location,
    "sql_query": indicator_per_year_at_location_query,
}
247
+
248
+
249
def plot_distribution_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Build a histogram of an indicator's values across locations for a year.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year (used in the plot title)
            - model (str): The climate model to use

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns
        a plotly Figure.
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        """Histogram (normalised to percent) of the indicator values in df."""
        fig = go.Figure()

        if df['model'].nunique() != 1:
            # Several models: average them per grid point first.
            per_point = df.groupby(["latitude", "longitude"], as_index=False)[
                indicator
            ].mean()
            # Plain Python lists avoid pandas-specific encoding.
            values = per_point[indicator].astype(float).tolist()
            model_label = "Model Average"
        else:
            values = df[indicator].astype(float).tolist()
            model_label = f"Model : {df['model'].unique()[0]}"

        fig.add_trace(
            go.Histogram(
                x=values,
                opacity=0.8,
                histnorm="percent",
                marker=dict(color="#1f77b4"),
                hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
            )
        )

        fig.update_layout(
            title=f"Distribution of {indicator_label} in {year} ({model_label})",
            xaxis_title=f"{indicator_label} ({unit})",
            yaxis_title="Frequency (%)",
            plot_bgcolor="rgba(0, 0, 0, 0)",
            showlegend=False,
        )

        return fig

    return plot_data


distribution_of_indicator_for_given_year: Plot = {
    "name": "Distribution of an indicator for a given year",
    "description": "Plot an histogram of the distribution for a given year of the values of an indicator",
    "params": ["indicator_column", "model", "year"],
    "plot_function": plot_distribution_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
}
328
+
329
+
330
def plot_map_of_france_of_indicator_for_given_year(
    params: dict,
) -> Callable[..., Figure]:
    """Build a scatter-map of France colour-coded by an indicator for a year.

    Args:
        params (dict): Dictionary containing:
            - indicator_column (str): The column name for the indicator
            - year (str): The year (used in the plot title)
            - model (str): The climate model to use

    Returns:
        Callable[..., Figure]: A function that takes a DataFrame and returns
        a plotly Figure.
    """
    indicator = params["indicator_column"]
    year = params["year"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        fig = go.Figure()

        if df['model'].nunique() != 1:
            # Several models: average them per grid point first.
            source = df.groupby(["latitude", "longitude"], as_index=False)[
                indicator
            ].mean()
            model_label = "Model Average"
        else:
            source = df
            model_label = f"Model : {df['model'].unique()[0]}"

        # Plain Python lists avoid pandas-specific encoding in the figure.
        values = source[indicator].astype(float).tolist()
        lats = source["latitude"].astype(float).tolist()
        lons = source["longitude"].astype(float).tolist()

        fig.add_trace(
            go.Scattermapbox(
                lat=lats,
                lon=lons,
                mode="markers",
                marker=dict(
                    size=10,
                    color=values,        # colour encodes the indicator value
                    colorscale="Turbo",
                    cmin=min(values),
                    cmax=max(values),
                    showscale=True,      # show the colour bar
                ),
                # Hover shows only the formatted indicator value.
                text=[f"{indicator_label}: {v:.2f} {unit}" for v in values],
                hoverinfo="text"
            )
        )

        fig.update_layout(
            mapbox_style="open-street-map",  # OpenStreetMap base layer
            mapbox_zoom=3,
            mapbox_center={"lat": 46.6, "lon": 2.0},
            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
            title=f"{indicator_label} in {year} in France ({model_label}) "
        )
        return fig

    return plot_data


map_of_france_of_indicator_for_given_year: Plot = {
    "name": "Map of France of an indicator for a given year",
    "description": "Heatmap on the map of France of the values of an in indicator for a given year",
    "params": ["indicator_column", "year", "model"],
    "plot_function": plot_map_of_france_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
}


# Registry consumed by the workflow when matching a detected plot name.
PLOTS = [
    indicator_evolution_at_location,
    indicator_number_of_days_per_year_at_location,
    distribution_of_indicator_for_given_year,
    map_of_france_of_indicator_for_given_year,
]
climateqa/engine/talk_to_data/sql_query.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from typing import TypedDict
4
+ import duckdb
5
+ import pandas as pd
6
+
7
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
    """Executes a SQL query on the DRIAS database and returns the results.

    This function connects to the DuckDB database containing DRIAS climate
    data and executes the provided SQL query, returning the results as a
    pandas DataFrame. The blocking work runs in a worker thread so the
    asyncio event loop is never blocked.

    Args:
        sql_query (str): The SQL query to execute

    Returns:
        pd.DataFrame: A DataFrame containing the query results

    Raises:
        duckdb.Error: If there is an error executing the SQL query
    """
    def _execute_query() -> pd.DataFrame:
        # Fresh in-memory connection per call; closed deterministically
        # (the original leaked the connection).
        con = duckdb.connect()
        try:
            return con.sql(sql_query).fetchdf()
        finally:
            con.close()

    # asyncio.to_thread replaces the manual get_event_loop() +
    # per-call ThreadPoolExecutor, and uses the loop's default executor.
    return await asyncio.to_thread(_execute_query)
34
+
35
+
36
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
    """Parameters for querying an indicator's values over time at a location.

    Attributes:
        indicator_column (str): The column name for the climate indicator
        latitude (str): The latitude coordinate of the location
        longitude (str): The longitude coordinate of the location
        model (str): The climate model to use (optional)
    """
    indicator_column: str
    latitude: str
    longitude: str
    model: str


def indicator_per_year_at_location_query(
    table: str, params: IndicatorPerYearAtLocationQueryParams
) -> str:
    """Build the SQL query for an indicator's yearly evolution at a point.

    Args:
        table (str): sql table of the indicator (resolved to a parquet file
            hosted on the Hugging Face hub)
        params (IndicatorPerYearAtLocationQueryParams): required parameters

    Returns:
        str: the sql query, or "" when a required parameter is missing
    """
    required = (
        params.get("indicator_column"),
        params.get("latitude"),
        params.get("longitude"),
    )
    if any(value is None for value in required):
        # Incomplete parameters -> signal "no query" with an empty string.
        return ""
    indicator_column, latitude, longitude = required

    # NOTE(review): values are interpolated straight into the SQL text; they
    # come from the internal extraction pipeline, but parameter binding would
    # be safer if user-controlled input can ever reach this point.
    source = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
    return (
        f"SELECT year, {indicator_column}, model\n"
        f"FROM {source}\n"
        f"WHERE latitude = {latitude} \n"
        f"And longitude = {longitude} \n"
        f"Order by Year"
    )
78
+
79
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
    """Parameters for querying an indicator's values across locations for a year.

    Attributes:
        indicator_column (str): The column name for the climate indicator
        year (str): The year to query
        model (str): The climate model to use (optional)
    """
    indicator_column: str
    year: str
    model: str


def indicator_for_given_year_query(
    table: str, params: IndicatorForGivenYearQueryParams
) -> str:
    """Build the SQL query for an indicator's values across the grid in a year.

    Args:
        table (str): sql table of the indicator (resolved to a parquet file
            hosted on the Hugging Face hub)
        params (IndicatorForGivenYearQueryParams): required parameters

    Returns:
        str: the sql query, or "" when a required parameter is missing
    """
    indicator_column = params.get("indicator_column")
    year = params.get("year")
    if indicator_column is None or year is None:
        # Incomplete parameters -> signal "no query" with an empty string.
        return ""

    source = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
    return (
        f"Select {indicator_column}, latitude, longitude, model\n"
        f"From {source}\n"
        f"Where year = {year}"
    )
climateqa/engine/talk_to_data/talk_to_drias.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import Any, Callable, TypedDict, Optional
4
+ from numpy import sort
5
+ import pandas as pd
6
+ import asyncio
7
+ from plotly.graph_objects import Figure
8
+ from climateqa.engine.llm import get_llm
9
+ from climateqa.engine.talk_to_data import sql_query
10
+ from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
11
+ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
12
+ from climateqa.engine.talk_to_data.sql_query import execute_sql_query
13
+ from climateqa.engine.talk_to_data.utils import (
14
+ detect_relevant_plots,
15
+ detect_year_with_openai,
16
+ loc2coords,
17
+ detect_location_with_openai,
18
+ nearestNeighbourSQL,
19
+ detect_relevant_tables,
20
+ )
21
+
22
+
23
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
24
+
25
class TableState(TypedDict):
    """Represents the state of a table in the DRIAS workflow.

    This class defines the structure for tracking the state of a table during
    the data processing workflow, including its name, parameters, SQL query,
    and results.

    Attributes:
        table_name (str): The name of the table in the database
        params (dict[str, Any]): Parameters used for querying the table
        sql_query (str, optional): The SQL query used to fetch data
        dataframe (pd.DataFrame | None, optional): The resulting data
        figure (Callable[..., Figure], optional): Function to generate visualization
        status (str): The current status of the table processing ('OK' or 'ERROR')
    """
    table_name: str
    params: dict[str, Any]
    sql_query: Optional[str]
    dataframe: Optional[pd.DataFrame | None]
    figure: Optional[Callable[..., Figure]]
    # 'OK' or 'ERROR' — set by process_table.
    status: str
45
+
46
class PlotState(TypedDict):
    """Represents the state of a plot in the DRIAS workflow.

    This class defines the structure for tracking the state of a plot during
    the data processing workflow, including its name and associated tables.

    Attributes:
        plot_name (str): The name of the plot
        tables (list[str]): List of tables used in the plot
        table_states (dict[str, TableState]): States of the tables used in the plot
    """
    plot_name: str
    tables: list[str]
    # Keyed by table name.
    table_states: dict[str, TableState]
60
+
61
class State(TypedDict):
    """Top-level state of the Talk-to-DRIAS workflow.

    Attributes:
        user_input (str): The original user question
        plots (list[str]): Names of the relevant plots detected by the LLM
        plot_states (dict[str, PlotState]): Per-plot processing results
        error (Optional[str]): User-facing error message ('' when none)
    """
    user_input: str
    plots: list[str]
    plot_states: dict[str, PlotState]
    error: Optional[str]
66
+
67
async def find_relevant_plots(state: State, llm) -> list[str]:
    """Ask the LLM which plot kinds are relevant for the user's question."""
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm)
71
+
72
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
    """Ask the LLM which DRIAS tables are relevant for the given plot kind."""
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm)
76
+
77
async def find_param(state: State, param_name: str, table: str) -> dict[str, Any] | None:
    """Dispatch to the extractor matching the requested parameter.

    Args:
        state (State): state of the workflow (provides the user input)
        param_name (str): name of the desired parameter
        table (str): name of the table (used to snap locations to the grid)

    Returns:
        dict[str, Any] | None: the extracted parameter(s), or None when the
        parameter name has no dedicated extractor.
    """
    query = state['user_input']
    if param_name == 'location':
        return await find_location(query, table)
    if param_name == 'year':
        return {'year': await find_year(query)}
    return None
95
+
96
class Location(TypedDict):
    # Detected place name; coordinates are filled in only when a place was found.
    location: str
    latitude: Optional[str]
    longitude: Optional[str]


async def find_location(user_input: str, table: str) -> Location:
    """Extract a location from the query and snap it to the table's grid.

    The detected place name is geocoded, then matched to the nearest
    (latitude, longitude) pair present in the given table.
    """
    print(f"---- Find location in table {table} ----")
    place = await detect_location_with_openai(user_input)
    result: Location = {'location': place}
    if place:
        nearest = nearestNeighbourSQL(loc2coords(place), table)
        result["latitude"] = nearest[0]
        result["longitude"] = nearest[1]
    return result
113
+
114
async def find_year(user_input: str) -> str:
    """Extract year information from user input using the LLM.

    Args:
        user_input (str): The user's query text

    Returns:
        str: The extracted year, or empty string if no year found
    """
    print(f"---- Find year ---")
    return await detect_year_with_openai(user_input)
129
+
130
def find_indicator_column(table: str) -> str:
    """Retrieves the name of the indicator column within a table.

    This function maps table names to their corresponding indicator columns
    using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.

    Args:
        table (str): Name of the table in the database

    Returns:
        str: Name of the indicator column for the specified table

    Raises:
        KeyError: If the table name is not found in the mapping
    """
    print(f"---- Find indicator column in table {table} ----")
    return INDICATOR_COLUMNS_PER_TABLE[table]
147
+
148
+
149
async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Build the SQL query, fetch the data and prepare the figure for a table.

    Args:
        table (str): The name of the table to process
        params (dict[str, Any]): Query parameters (copied, then completed with
            the table's indicator column)
        plot (Plot): Plot config providing the SQL builder and plot function

    Returns:
        TableState: status 'ERROR' when no query could be built, otherwise
        'OK' with the query, dataframe and figure builder filled in.
    """
    result: TableState = {
        'table_name': table,
        'params': params.copy(),
        'status': 'OK',
        'dataframe': None,
        'sql_query': None,
        'figure': None,
    }
    result['params']['indicator_column'] = find_indicator_column(table)

    query = plot['sql_query'](table, result['params'])
    if not query:
        # The builder returned "": required parameters were missing.
        result['status'] = 'ERROR'
        return result

    result['sql_query'] = query
    result['dataframe'] = await execute_sql_query(query)
    result['figure'] = plot['plot_function'](result['params'])
    return result
189
+
190
async def drias_workflow(user_input: str) -> State:
    """Performs the complete workflow of Talk To Drias: from user input to
    sql queries, dataframes and figures generated.

    Args:
        user_input (str): initial user input

    Returns:
        State: Final state with all the results (and a user-facing ``error``
        message when nothing could be produced)
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    plots = await find_relevant_plots(state, llm)
    state['plots'] = plots

    if len(state['plots']) < 1:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # Success flags are accumulated across ALL plots.  (Bug fix: they were
    # previously reset inside the loop, so the final error message only
    # reflected the last plot processed.)
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:

        # Find the associated plot object.
        plot = next((p for p in PLOTS if p['name'] == plot_name), None)
        if plot is None:
            continue

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': [],
            'table_states': {}
        }

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

        # Bug fix: guard against an empty table list before indexing
        # relevant_tables[0] below (previously raised IndexError).
        if not relevant_tables:
            state['plot_states'][plot_name] = plot_state
            continue
        have_relevant_table = True

        plot_state['tables'] = relevant_tables

        params = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        # Process at most 3 tables per plot, concurrently.
        tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state.
        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

        state['plot_states'][plot_name] = plot_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
270
+
271
+ # def make_write_query_node():
272
+
273
+ # def write_query(state):
274
+ # print("---- Write query ----")
275
+ # for table in state["tables"]:
276
+ # sql_query = QUERIES[state[table]['query_type']](
277
+ # table=table,
278
+ # indicator_column=state[table]["columns"],
279
+ # longitude=state[table]["longitude"],
280
+ # latitude=state[table]["latitude"],
281
+ # )
282
+ # state[table].update({"sql_query": sql_query})
283
+
284
+ # return state
285
+
286
+ # return write_query
287
+
288
+ # def make_fetch_data_node(db_path):
289
+
290
+ # def fetch_data(state):
291
+ # print("---- Fetch data ----")
292
+ # for table in state["tables"]:
293
+ # results = execute_sql_query(db_path, state[table]['sql_query'])
294
+ # state[table].update(results)
295
+
296
+ # return state
297
+
298
+ # return fetch_data
299
+
300
+
301
+
302
+ ## V2
303
+
304
+
305
+ # def make_fetch_data_node(db_path: str, llm):
306
+ # def fetch_data(state):
307
+ # print("---- Fetch data ----")
308
+ # db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
309
+ # output = {}
310
+ # sql_query = write_sql_query(state["query"], db, state["tables"], llm)
311
+ # # TO DO : Add query checker
312
+ # print(f"SQL query : {sql_query}")
313
+ # output["sql_query"] = sql_query
314
+ # output.update(fetch_data_from_sql_query(db_path, sql_query))
315
+ # return output
316
+
317
+ # return fetch_data
climateqa/engine/talk_to_data/utils.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Annotated, TypedDict
3
+ import duckdb
4
+ from geopy.geocoders import Nominatim
5
+ import ast
6
+ from climateqa.engine.llm import get_llm
7
+ from climateqa.engine.talk_to_data.config import DRIAS_TABLES
8
+ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+
11
+
12
async def detect_location_with_openai(sentence):
    """Detect the first location mentioned in a sentence via the LLM.

    Asks the model to extract all locations (cities, countries, states or
    geographical areas) and return them as a Python list literal.

    Args:
        sentence (str): The user sentence to analyse.

    Returns:
        str: The first detected location, or "" when none is found or the
            model reply cannot be parsed as a Python list.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    # Remove an optional ```python ... ``` code fence before parsing.
    # NOTE: the previous str.strip("```python\n") removed a *character set*
    # from both ends, not the fence prefix, and could corrupt the payload.
    raw = response.content.strip()
    if raw.startswith("```"):
        raw = raw.strip("`").removeprefix("python").strip()
    try:
        location_list = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        # Malformed model output: behave as if no location was found.
        return ""
    if location_list:
        return location_list[0]
    return ""
31
+
32
class ArrayOutput(TypedDict):
    """Structured-output schema holding a single array field.

    Used with ``llm.with_structured_output`` so chains that extract lists
    (e.g. years) always return their result under the ``array`` key as a
    string that can be parsed as a Python list literal.

    Attributes:
        array (str): A syntactically valid Python array string.
    """
    array: Annotated[str, "Syntactically valid python array."]
42
+
43
async def detect_year_with_openai(sentence: str) -> str:
    """Detect the first year mentioned in a sentence via the LLM.

    The model is asked (through a structured-output chain) to return all
    years found in the sentence as a Python list literal.

    Args:
        sentence (str): The user sentence to analyse.

    Returns:
        str: The first detected year, or "" when none is found or the
            model reply cannot be parsed.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # ast.literal_eval instead of eval: the model output is untrusted text
    # and must never be executed as arbitrary Python code.
    try:
        years_list = ast.literal_eval(response["array"])
    except (ValueError, SyntaxError):
        return ""
    if len(years_list) > 0:
        return years_list[0]
    return ""
65
+
66
+
67
def detectTable(sql_query: str) -> list[str]:
    """Extract the table names referenced in a SQL query.

    Scans the query case-insensitively and collects every identifier
    (optionally quoted with backticks/quotes and optionally dotted, e.g.
    ``schema.table``) that directly follows a FROM keyword.

    Args:
        sql_query (str): The SQL query to inspect.

    Returns:
        list[str]: The table names found after FROM clauses.

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    # One identifier segment: `quoted`, "quoted", 'quoted' or a bare word.
    segment = r'(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)'
    pattern = rf'(?i)\bFROM\s+({segment}(?:\.{segment})*)'
    return re.findall(pattern, sql_query)
86
+
87
+
88
def loc2coords(location: str) -> tuple[float, float]:
    """Convert a location name to geographic coordinates.

    Uses the Nominatim geocoding service to resolve a location name
    (e.g. a city name) to its latitude and longitude.

    Args:
        location (str): The name of the location to geocode.

    Returns:
        tuple[float, float]: A tuple containing (latitude, longitude).

    Raises:
        ValueError: If the location cannot be found.
    """
    geolocator = Nominatim(user_agent="city_to_latlong")
    coords = geolocator.geocode(location)
    # geocode() returns None for unknown places; fail with a clear error
    # instead of the opaque AttributeError raised by `None.latitude`.
    if coords is None:
        raise ValueError(f"Location not found: {location!r}")
    return (coords.latitude, coords.longitude)
106
+
107
+
108
def coords2loc(coords: tuple[float, float]) -> str:
    """Convert geographic coordinates to a location name.

    Uses the Nominatim reverse geocoding service to resolve latitude and
    longitude coordinates into a human-readable address.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude).

    Returns:
        str: The address of the location, or "Unknown Location" if not found.

    Example:
        >>> coords2loc((48.8566, 2.3522))
        'Paris, France'
    """
    locator = Nominatim(user_agent="coords_to_city")
    try:
        # reverse() may return None or raise; both end up in the except
        # branch (None.address raises AttributeError).
        place = locator.reverse(coords)
        return place.address
    except Exception as e:
        print(f"Error: {e}")
        return "Unknown Location"
131
+
132
+
133
def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
    """Find the grid point of a DRIAS table closest to the given coordinates.

    Searches the table's parquet file within a +/-0.3 degree bounding box
    around the requested point and returns the closest grid point.

    Args:
        location (tuple): (latitude, longitude) of the requested point.
        table (str): Name of the DRIAS table to query.

    Returns:
        tuple: (latitude, longitude) of the nearest grid point, or
            ("", "") when no grid point lies within the bounding box.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)

    table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"

    # Order by squared euclidean distance so the *closest* grid point is
    # returned; previously an arbitrary row of the bounding box was used.
    results = duckdb.sql(
        f"SELECT latitude, longitude FROM {table} "
        f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
        f"AND longitude BETWEEN {long - 0.3} AND {long + 0.3} "
        f"ORDER BY (latitude - {lat}) * (latitude - {lat}) "
        f"+ (longitude - {long}) * (longitude - {long}) "
        f"LIMIT 1"
    ).fetchdf()

    if len(results) == 0:
        return "", ""
    return results['latitude'].iloc[0], results['longitude'].iloc[0]
147
+
148
+
149
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
    """Identify the DRIAS tables relevant to a plot and a user question.

    Prompts the LLM with the plot description, the list of available DRIAS
    tables and the user question, and asks it for the three most relevant
    table names as a Python list literal.

    Args:
        user_question (str): The user's question about climate data.
        plot (Plot): The plot configuration object.
        llm: The language model instance to use for analysis.

    Returns:
        list[str]: The table names judged relevant for the plot.

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    table_names_list = DRIAS_TABLES

    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    answer = await llm.ainvoke(prompt)
    cleaned = answer.content.strip("```python\n").strip()
    return ast.literal_eval(cleaned)
189
+
190
+
191
def replace_coordonates(coords, query, coords_tables):
    """Substitute requested coordinates in a SQL query with table grid points.

    Each occurrence of the requested (latitude, longitude) pair embedded in
    the query text is replaced, in order, by the matching pair from
    ``coords_tables``.

    Args:
        coords: (latitude, longitude) originally interpolated in the query.
        query (str): The SQL query containing the original coordinates.
        coords_tables: One (latitude, longitude) replacement per occurrence.

    Returns:
        str: The query with coordinates substituted.
    """
    lat_text = str(coords[0])
    long_text = str(coords[1])
    occurrences = query.count(lat_text)

    for idx in range(occurrences):
        # Replace left-to-right, one occurrence per table coordinate pair.
        query = query.replace(lat_text, str(coords_tables[idx][0]), 1)
        query = query.replace(long_text, str(coords_tables[idx][1]), 1)
    return query
198
+
199
+
200
async def detect_relevant_plots(user_question: str, llm):
    """Identify which plots are relevant to answer a user question.

    Builds a catalogue of the available plots (name + description), then
    prompts the LLM to pick the relevant ones, expecting a Python list
    literal of plot names.

    Args:
        user_question (str): The user's question about climate data.
        llm: The language model instance to use for analysis.

    Returns:
        list[str]: The names of the plots judged relevant.
    """
    plots_description = "".join(
        "Name: " + plot["name"] + " - Description: " + plot["description"] + "\n"
        for plot in PLOTS
    )

    prompt = (
        f"You are helping to answer a quesiton with insightful visualizations."
        f"You are given an user question and a list of plots with their name and description."
        f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
        f"Write the most relevant tables to use. Answer only a python list of plot name."
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}"
        f"### Name of the plot : "
    )

    answer = await llm.ainvoke(prompt)
    cleaned = answer.content.strip("```python\n").strip()
    return ast.literal_eval(cleaned)
229
+
230
+
231
+ # Next Version
232
+ # class QueryOutput(TypedDict):
233
+ # """Generated SQL query."""
234
+
235
+ # query: Annotated[str, ..., "Syntactically valid SQL query."]
236
+
237
+
238
+ # class PlotlyCodeOutput(TypedDict):
239
+ # """Generated Plotly code"""
240
+
241
+ # code: Annotated[str, ..., "Synatically valid Plotly python code."]
242
+ # def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
243
+ # """Generate SQL query to fetch information."""
244
+ # prompt_params = {
245
+ # "dialect": db.dialect,
246
+ # "table_info": db.get_table_info(),
247
+ # "input": user_input,
248
+ # "relevant_tables": relevant_tables,
249
+ # "model": "ALADIN63_CNRM-CM5",
250
+ # }
251
+
252
+ # prompt = ChatPromptTemplate.from_template(query_prompt_template)
253
+ # structured_llm = llm.with_structured_output(QueryOutput)
254
+ # chain = prompt | structured_llm
255
+ # result = chain.invoke(prompt_params)
256
+
257
+ # return result["query"]
258
+
259
+
260
+ # def fetch_data_from_sql_query(db: str, sql_query: str):
261
+ # conn = sqlite3.connect(db)
262
+ # cursor = conn.cursor()
263
+ # cursor.execute(sql_query)
264
+ # column_names = [desc[0] for desc in cursor.description]
265
+ # values = cursor.fetchall()
266
+ # return {"column_names": column_names, "data": values}
267
+
268
+
269
+ # def generate_chart_code(user_input: str, sql_query: list[str], llm):
270
+ # """ "Generate plotly python code for the chart based on the sql query and the user question"""
271
+
272
+ # class PlotlyCodeOutput(TypedDict):
273
+ # """Generated Plotly code"""
274
+
275
+ # code: Annotated[str, ..., "Synatically valid Plotly python code."]
276
+
277
+ # prompt = ChatPromptTemplate.from_template(plot_prompt_template)
278
+ # structured_llm = llm.with_structured_output(PlotlyCodeOutput)
279
+ # chain = prompt | structured_llm
280
+ # result = chain.invoke({"input": user_input, "sql_query": sql_query})
281
+ # return result["code"]
climateqa/engine/talk_to_data/vanna_class.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vanna.base import VannaBase
2
+ from pinecone import Pinecone
3
+ from climateqa.engine.embeddings import get_embeddings_function
4
+ import pandas as pd
5
+ import hashlib
6
+
7
class MyCustomVectorDB(VannaBase):
    """Vector store backed by Pinecone for Vanna training data.

    Stores and retrieves DDL statements, documentation snippets and
    question/SQL pairs as embeddings, each in its own Pinecone namespace
    ('ddl', 'documentation', 'question_sql').

    Args:
        config (dict): Configuration dictionary containing the Pinecone API
            key and the index name:
            - pc_api_key (str): Pinecone API key
            - index_name (str): Pinecone index name
            - top_k (int): Number of top results to return (default = 2)
    """

    def __init__(self, config):
        super().__init__(config=config)
        # dict.get never raises, so the previous try/except validation was
        # dead code; validate the values explicitly instead.
        self.api_key = config.get('pc_api_key')
        self.index_name = config.get('index_name')
        if not self.api_key or not self.index_name:
            raise Exception("Please provide the Pinecone API key and the index name")

        self.pc = Pinecone(api_key=self.api_key)
        self.index = self.pc.Index(self.index_name)
        self.top_k = config.get('top_k', 2)
        self.embeddings = get_embeddings_function()

    def check_embedding(self, id, namespace):
        """Return True when a vector with this id already exists in the namespace."""
        fetched = self.index.fetch(ids=[id], namespace=namespace)
        return fetched['vectors'] != {}

    def generate_hash_id(self, data: str) -> str:
        """Generate a unique hash ID for the given data.

        Args:
            data (str): The input data to hash (e.g., a concatenated string
                of user attributes).

        Returns:
            str: A unique SHA-256 hash ID as a hexadecimal string.
        """
        return hashlib.sha256(data.encode('utf-8')).hexdigest()

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Embed and store a DDL statement; return its id (idempotent)."""
        id = self.generate_hash_id(ddl) + '_ddl'

        if self.check_embedding(id, 'ddl'):
            print(f"DDL having id {id} already exists")
            return id

        self.index.upsert(
            vectors=[(id, self.embeddings.embed_query(ddl), {'ddl': ddl})],
            namespace='ddl'
        )
        return id

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Embed and store a documentation snippet; return its id (idempotent)."""
        id = self.generate_hash_id(doc) + '_doc'

        if self.check_embedding(id, 'documentation'):
            print(f"Documentation having id {id} already exists")
            return id

        self.index.upsert(
            vectors=[(id, self.embeddings.embed_query(doc), {'doc': doc})],
            namespace='documentation'
        )
        return id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Embed and store a question/SQL pair; return its id (idempotent)."""
        id = self.generate_hash_id(question) + '_sql'

        if self.check_embedding(id, 'question_sql'):
            print(f"Question-SQL pair having id {id} already exists")
            return id

        self.index.upsert(
            vectors=[(id, self.embeddings.embed_query(question + sql), {'question': question, 'sql': sql})],
            namespace='question_sql'
        )
        return id

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the top-k DDL statements most similar to the question."""
        res = self.index.query(
            vector=self.embeddings.embed_query(question),
            top_k=self.top_k,
            namespace='ddl',
            include_metadata=True
        )
        return [match['metadata']['ddl'] for match in res['matches']]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the top-k documentation snippets most similar to the question."""
        res = self.index.query(
            vector=self.embeddings.embed_query(question),
            top_k=self.top_k,
            namespace='documentation',
            include_metadata=True
        )
        return [match['metadata']['doc'] for match in res['matches']]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return the top-k (question, sql) pairs most similar to the question."""
        res = self.index.query(
            vector=self.embeddings.embed_query(question),
            top_k=self.top_k,
            namespace='question_sql',
            include_metadata=True
        )
        return [(match['metadata']['question'], match['metadata']['sql']) for match in res['matches']]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all stored training data (every namespace) as a DataFrame."""
        list_of_data = []
        namespaces = ['ddl', 'documentation', 'question_sql']

        for namespace in namespaces:
            data = self.index.query(
                top_k=10000,
                namespace=namespace,
                include_metadata=True,
                include_values=False
            )
            for match in data['matches']:
                list_of_data.append(match['metadata'])

        return pd.DataFrame(list_of_data)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete a stored training item by id; return True when handled.

        Bug fix: the previous version called ``self.Index`` (capital I,
        AttributeError) for the ddl/doc branches and deleted from namespaces
        '_ddl'/'_sql'/'_doc' although vectors are upserted into
        'ddl'/'question_sql'/'documentation', so deletions never matched.
        """
        if id.endswith("_ddl"):
            self.index.delete(ids=[id], namespace="ddl")
            return True
        if id.endswith("_sql"):
            self.index.delete(ids=[id], namespace="question_sql")
            return True
        if id.endswith("_doc"):
            self.index.delete(ids=[id], namespace="documentation")
            return True
        return False

    def generate_embedding(self, text, **kwargs):
        # Required by the VannaBase interface; embeddings are produced through
        # self.embeddings instead.
        pass

    def get_sql_prompt(
        self,
        initial_prompt: str,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        **kwargs,
    ):
        """Build the LLM message log used to generate SQL.

        Example:
        ```python
        vn.get_sql_prompt(
            question="What are the top 10 customers by sales?",
            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
            doc_list=["The customers table contains information about customers and their sales."],
        )
        ```

        Args:
            initial_prompt (str): Optional system prompt; a default is built
                from the SQL dialect when None.
            question (str): The question to generate SQL for.
            question_sql_list (list): Example questions with their SQL.
            ddl_list (list): DDL statements for context.
            doc_list (list): Documentation snippets for context.

        Returns:
            any: The message log for the LLM to generate SQL.
        """
        if initial_prompt is None:
            initial_prompt = f"You are a {self.dialect} expert. " + \
                "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

        initial_prompt = self.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=self.max_tokens
        )

        if self.static_documentation != "":
            doc_list.append(self.static_documentation)

        initial_prompt = self.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=self.max_tokens
        )

        initial_prompt += (
            "===Response Guidelines \n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
            "3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
            "4. Please use the most relevant table(s). \n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
            f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
            f"7. Add a description of the table in the result of the sql query, if relevant. \n"
            "8 Make sure to include the relevant KPI in the SQL query. The query should return impactfull data \n"
        )

        message_log = [self.system_message(initial_prompt)]

        for example in question_sql_list:
            if example is None:
                print("example is None")
            elif "question" in example and "sql" in example:
                message_log.append(self.user_message(example["question"]))
                message_log.append(self.assistant_message(example["sql"]))

        message_log.append(self.user_message(question))

        return message_log
253
+
254
+
255
+ # def get_sql_prompt(
256
+ # self,
257
+ # initial_prompt : str,
258
+ # question: str,
259
+ # question_sql_list: list,
260
+ # ddl_list: list,
261
+ # doc_list: list,
262
+ # **kwargs,
263
+ # ):
264
+ # """
265
+ # Example:
266
+ # ```python
267
+ # vn.get_sql_prompt(
268
+ # question="What are the top 10 customers by sales?",
269
+ # question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
270
+ # ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
271
+ # doc_list=["The customers table contains information about customers and their sales."],
272
+ # )
273
+
274
+ # ```
275
+
276
+ # This method is used to generate a prompt for the LLM to generate SQL.
277
+
278
+ # Args:
279
+ # question (str): The question to generate SQL for.
280
+ # question_sql_list (list): A list of questions and their corresponding SQL statements.
281
+ # ddl_list (list): A list of DDL statements.
282
+ # doc_list (list): A list of documentation.
283
+
284
+ # Returns:
285
+ # any: The prompt for the LLM to generate SQL.
286
+ # """
287
+
288
+ # if initial_prompt is None:
289
+ # initial_prompt = f"You are a {self.dialect} expert. " + \
290
+ # "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
291
+
292
+ # initial_prompt = self.add_ddl_to_prompt(
293
+ # initial_prompt, ddl_list, max_tokens=self.max_tokens
294
+ # )
295
+
296
+ # if self.static_documentation != "":
297
+ # doc_list.append(self.static_documentation)
298
+
299
+ # initial_prompt = self.add_documentation_to_prompt(
300
+ # initial_prompt, doc_list, max_tokens=self.max_tokens
301
+ # )
302
+
303
+ # initial_prompt += (
304
+ # "===Response Guidelines \n"
305
+ # "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
306
+ # "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
307
+ # "3. If the provided context is insufficient, please explain why it can't be generated. \n"
308
+ # "4. Please use the most relevant table(s). \n"
309
+ # "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
310
+ # f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
311
+ # )
312
+
313
+ # message_log = [self.system_message(initial_prompt)]
314
+
315
+ # for example in question_sql_list:
316
+ # if example is None:
317
+ # print("example is None")
318
+ # else:
319
+ # if example is not None and "question" in example and "sql" in example:
320
+ # message_log.append(self.user_message(example["question"]))
321
+ # message_log.append(self.assistant_message(example["sql"]))
322
+
323
+ # message_log.append(self.user_message(question))
324
+
325
+ # return message_log
requirements.txt CHANGED
@@ -26,4 +26,4 @@ duckdb==1.2.1
26
  openai==1.61.1
27
  pydantic==2.9.2
28
  pydantic-settings==2.2.1
29
- geojson==3.2.0
 
26
  openai==1.61.1
27
  pydantic==2.9.2
28
  pydantic-settings==2.2.1
29
+ geojson==3.2.0
style.css CHANGED
@@ -656,11 +656,20 @@ a {
656
  /* overflow-y: scroll; */
657
  }
658
  #sql-query{
 
659
  max-height: 100%;
660
  }
661
 
662
  #sql-query textarea{
663
  min-height: 200px !important;
 
 
 
 
 
 
 
 
664
  }
665
 
666
  #sql-query span{
@@ -741,4 +750,4 @@ div#tab-vanna{
741
  #example-img-container {
742
  flex-direction: column;
743
  align-items: left;
744
- }
 
656
  /* overflow-y: scroll; */
657
  }
658
/* Resolved leftover merge conflict (<<<<<<< / ======= / >>>>>>> markers
   were committed): keep the main (HEAD) sizing for the SQL query panel. */
#sql-query{
    max-height: 100%;
}

#sql-query textarea{
    min-height: 200px !important;
}
674
 
675
  #sql-query span{
 
750
  #example-img-container {
751
  flex-direction: column;
752
  align-items: left;
753
+ }