armanddemasson committed on
Commit
4df74e4
·
1 Parent(s): 3e75ed8

feat: implemented talk to drias v1

Browse files
app.py CHANGED
@@ -12,7 +12,7 @@ from climateqa.engine.reranker import get_reranker
12
  from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
- from climateqa.engine.talk_to_data.main import ask_vanna
16
  from climateqa.engine.talk_to_data.myVanna import MyVanna
17
 
18
  from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
@@ -84,8 +84,11 @@ vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), '
84
  db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
85
  vn.connect_to_sqlite(db_vanna_path)
86
 
87
- def ask_vanna_query(query):
88
- return ask_vanna(vn, db_vanna_path, query)
 
 
 
89
 
90
  async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
91
  print("chat cqa - message received")
@@ -122,20 +125,70 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
122
 
123
  return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def create_drias_tab():
126
- with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
127
- vanna_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here",elem_id="direct-question", interactive=True)
128
- with gr.Accordion("Details",elem_id = 'vanna-details', open=False) as vanna_details :
129
- vanna_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
130
- show_vanna_table = gr.Button("Show Table", elem_id="show-table")
131
- with Modal(visible=False) as vanna_table_modal:
132
- vanna_table = gr.DataFrame([], elem_id="vanna-table")
133
- close_vanna_modal = gr.Button("Close", elem_id="close-vanna-modal")
134
- close_vanna_modal.click(lambda: Modal(visible=False),None, [vanna_table_modal])
135
- show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
136
-
137
- vanna_display = gr.Plot()
138
- vanna_direct_question.submit(ask_vanna_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  # # UI Layout Components
141
  def cqa_tab(tab_name):
 
12
  from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
14
  from climateqa.chat import start_chat, chat_stream, finish_chat
15
+ from climateqa.engine.talk_to_data.main import ask_drias
16
  from climateqa.engine.talk_to_data.myVanna import MyVanna
17
 
18
  from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
 
84
  db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
85
  vn.connect_to_sqlite(db_vanna_path)
86
 
87
+ # def ask_vanna_query(query):
88
+ # return ask_vanna(vn, db_vanna_path, query)
89
+
90
def ask_drias_query(query, index_state):
    """Answer a DRIAS question typed in the UI.

    Thin adapter around ``ask_drias`` that supplies the module-level
    ``db_vanna_path`` so the Gradio callback only has to pass the question
    and the current history index.

    Args:
        query: Question entered by the user.
        index_state: Position in the result history to display.

    Returns:
        Whatever ``ask_drias`` returns, unchanged.
    """
    drias_outputs = ask_drias(db_vanna_path, query, index_state)
    return drias_outputs
92
 
93
  async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
94
  print("chat cqa - message received")
 
125
 
126
  return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
127
 
128
+ # def create_drias_tab():
129
+ # with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
130
+ # vanna_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here",elem_id="direct-question", interactive=True)
131
+ # with gr.Accordion("Details",elem_id = 'vanna-details', open=False) as vanna_details :
132
+ # vanna_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
133
+ # show_vanna_table = gr.Button("Show Table", elem_id="show-table")
134
+ # with Modal(visible=False) as vanna_table_modal:
135
+ # vanna_table = gr.DataFrame([], elem_id="vanna-table")
136
+ # close_vanna_modal = gr.Button("Close", elem_id="close-vanna-modal")
137
+ # close_vanna_modal.click(lambda: Modal(visible=False),None, [vanna_table_modal])
138
+ # show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
139
+
140
+ # vanna_display = gr.Plot()
141
+ # vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
142
def create_drias_tab():
    """Build the "Beta - Talk to DRIAS" tab and wire its events.

    The tab holds a question box, a "Details" accordion, a plot and two
    navigation buttons. Submitting a question calls ``ask_drias_query`` and
    stores every generated SQL query, dataframe and figure in ``gr.State``
    lists; the Previous/Next buttons then page through those lists.
    """
    with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
        drias_direct_question = gr.Textbox(
            label="Direct Question",
            placeholder="You can write direct question here",
            elem_id="direct-question",
            interactive=True,
        )

        # NOTE(review): the nesting of the table and plot relative to the
        # accordion is assumed (SQL query and table inside, plot outside) --
        # confirm against the intended layout.
        with gr.Accordion("Details", elem_id="vanna-details", open=False):
            drias_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
            drias_table = gr.DataFrame([], elem_id="vanna-table")
        drias_display = gr.Plot()

        # Navigation buttons
        prev_button = gr.Button("Previous")
        next_button = gr.Button("Next")

        # Per-session history of results and the position currently shown
        sql_queries_state = gr.State([])
        dataframes_state = gr.State([])
        plots_state = gr.State([])
        index_state = gr.State(0)  # To track the current position

        # Run the workflow when the question is submitted
        drias_direct_question.submit(
            ask_drias_query,
            inputs=[drias_direct_question, index_state],
            outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state],
        )

        def _show_at(index, sql_queries, dataframes, plots):
            """Return the history entry at ``index``, clamped to the valid range."""
            # Bound by the shortest list so a mismatch can never index out of range.
            size = min(len(sql_queries), len(dataframes), len(plots))
            if size == 0:
                # Nothing to page through yet: leave the display untouched
                # instead of raising IndexError on the empty history.
                return gr.update(), gr.update(), gr.update(), index
            index = max(0, min(index, size - 1))
            return sql_queries[index], dataframes[index], plots[index], index

        # Define functions to navigate history
        def show_previous(index, sql_queries, dataframes, plots):
            return _show_at(index - 1, sql_queries, dataframes, plots)

        def show_next(index, sql_queries, dataframes, plots):
            return _show_at(index + 1, sql_queries, dataframes, plots)

        prev_button.click(
            show_previous,
            inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
            outputs=[drias_sql_query, drias_table, drias_display, index_state],
        )

        next_button.click(
            show_next,
            inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
            outputs=[drias_sql_query, drias_table, drias_display, index_state],
        )
190
+
191
+
192
 
193
  # # UI Layout Components
194
  def cqa_tab(tab_name):
climateqa/engine/talk_to_data/main.py CHANGED
@@ -1,13 +1,7 @@
1
- from climateqa.engine.talk_to_data.myVanna import MyVanna
2
- from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates
3
- import sqlite3
4
- import os
5
- import pandas as pd
6
  from climateqa.engine.llm import get_llm
7
  import ast
8
 
9
-
10
-
11
  llm = get_llm(provider="openai")
12
 
13
  def ask_llm_to_add_table_names(sql_query, llm):
@@ -19,29 +13,47 @@ def ask_llm_column_names(sql_query, llm):
19
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
20
  return columns_list
21
 
22
- def ask_vanna(vn,db_vanna_path, query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- try :
25
- location = detect_location_with_openai(query)
26
- if location:
27
 
28
- coords = loc2coords(location)
29
- user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
30
 
31
- relevant_tables = detect_relevant_tables(user_input, llm)
32
- coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
33
- user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
34
-
35
- sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
36
-
37
- return sql_query, result_dataframe, figure
38
-
39
- else :
40
- empty_df = pd.DataFrame()
41
- empty_fig = None
42
- return "", empty_df, empty_fig
43
- except Exception as e:
44
- print(f"Error: {e}")
45
- empty_df = pd.DataFrame()
46
- empty_fig = None
47
- return "", empty_df, empty_fig
 
1
+ from climateqa.engine.talk_to_data.workflow import drias_workflow
 
 
 
 
2
  from climateqa.engine.llm import get_llm
3
  import ast
4
 
 
 
5
  llm = get_llm(provider="openai")
6
 
7
  def ask_llm_to_add_table_names(sql_query, llm):
 
13
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
14
  return columns_list
15
 
16
def ask_drias(db_drias_path: str, query: str, index_state: int):
    """Run the DRIAS workflow and flatten its results for the UI.

    Args:
        db_drias_path: Path to the DRIAS sqlite database.
        query: User question.
        index_state: Index of the result to display first; it is clamped to
            the range of the results actually produced.

    Returns:
        A 7-tuple ``(sql_query, dataframe, figure, sql_queries,
        result_dataframes, figures, index_state)``. The first three items are
        the entry at ``index_state``; the lists hold every result. When the
        workflow yields nothing, the first three items are ``""``, ``None``
        and ``None``, the lists are empty and the index is ``0``.
    """
    final_state = drias_workflow(db_drias_path, query)
    sql_queries = []
    result_dataframes = []
    figures = []

    for plot_state in final_state['plot_states'].values():
        for table_state in plot_state['table_states'].values():
            sql_query = table_state.get('sql_query')
            dataframe = table_state.get('dataframe')
            figure_builder = table_state.get('figure')

            # Keep the three lists index-aligned: an entry is recorded only
            # when its query, data and figure builder are all available.
            if sql_query is None or dataframe is None or figure_builder is None:
                continue

            sql_queries.append(sql_query)
            result_dataframes.append(dataframe)
            figures.append(figure_builder(dataframe))

    if not sql_queries:
        # No plot/table produced a result: return empty displays instead of
        # failing on an out-of-range index.
        return "", None, None, sql_queries, result_dataframes, figures, 0

    # The index may come from a previous, longer result set.
    index_state = min(max(index_state, 0), len(sql_queries) - 1)
    return (
        sql_queries[index_state],
        result_dataframes[index_state],
        figures[index_state],
        sql_queries,
        result_dataframes,
        figures,
        index_state,
    )
34
+
35
+ # def ask_vanna(vn,db_vanna_path, query):
36
 
37
+ # try :
38
+ # location = detect_location_with_openai(query)
39
+ # if location:
40
 
41
+ # coords = loc2coords(location)
42
+ # user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
43
 
44
+ # relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
45
+ # coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
46
+ # user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
47
+
48
+ # sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
49
+
50
+ # return sql_query, result_dataframe, figure
51
+ # else :
52
+ # empty_df = pd.DataFrame()
53
+ # empty_fig = None
54
+ # return "", empty_df, empty_fig
55
+ # except Exception as e:
56
+ # print(f"Error: {e}")
57
+ # empty_df = pd.DataFrame()
58
+ # empty_fig = None
59
+ # return "", empty_df, empty_fig
 
climateqa/engine/talk_to_data/plot.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, TypedDict
2
+ import pandas as pd
3
+ from plotly.graph_objects import Figure
4
+ import plotly.graph_objects as go
5
+
6
+ from climateqa.engine.talk_to_data.sql_query import indicator_per_year_at_location_query
7
+
8
+
9
class Plot(TypedDict):
    """Description of one plot type the workflow can produce."""

    # Name used to look the plot up (matched against the LLM's answer).
    name: str
    # Free-text description inserted into the LLM prompts.
    description: str
    # Names of the parameters to resolve before querying/plotting.
    params: list[str]
    # Factory: takes the resolved params, returns a function building the figure.
    plot_function: Callable[..., Callable[..., Figure]]
    # Builds the SQL query string for this plot.
    sql_query: Callable[..., str]
15
+
16
+
17
def plot_indicator_per_year_at_location(params: dict) -> Callable[..., Figure]:
    """Create the function that draws an indicator-per-year line plot.

    Args:
        params (dict): resolved parameters; must contain ``indicator_column``
            and ``model``, and ``location`` by the time the figure is built.
            ``model == "ALL"`` plots the per-year mean over all models,
            any other value plots only the rows of that model.

    Returns:
        Callable[..., Figure]: function taking the queried dataframe (columns
        ``year``, the indicator column and, for a single model, ``model``)
        and returning the Plotly figure.
    """
    indicator = params["indicator_column"]
    model = params["model"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    # The title suffix reflects what is actually plotted.
    model_label = "Model Average" if model == "ALL" else f"Model: {model}"

    def plot_data(df: pd.DataFrame) -> Figure:
        """Build the figure from the dataframe.

        Args:
            df (pd.DataFrame): pandas dataframe with the required data

        Returns:
            Figure: Plotly figure
        """
        if model == "ALL":
            series = df.groupby("year", as_index=False)[indicator].mean()
        else:
            series = df[df["model"] == model]

        # Transform to list to avoid pandas encoding
        indicators = series[indicator].astype(float).tolist()
        years = series["year"].astype(int).tolist()

        # Compute the 10-year rolling average
        sliding_averages = (
            series[indicator]
            .rolling(window=10, min_periods=5)
            .mean()
            .astype(float)
            .tolist()
        )

        fig = go.Figure()

        # Indicator per year plot
        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
        )

        # Sliding average dashed line
        fig.add_scatter(
            x=years,
            y=sliding_averages,
            mode="lines",
            name="10 years rolling average",
            line=dict(dash="dash"),
            marker=dict(color="#1f77b4"),
        )
        fig.update_layout(
            title=f"Plot of {indicator_label} in {params['location']} ({model_label})",
            xaxis_title="Year",
            yaxis_title=indicator_label,
            template="plotly_white",
        )
        return fig

    return plot_data
97
+
98
+
99
# Line plot of an indicator's yearly values at one location.
indicator_per_year_at_location: Plot = Plot(
    name="Indicator per year at location",
    description="Plot an evolution of the indicator at a certain location over the years",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)
106
+
107
+
108
def plot_indicator_number_of_days_per_year_at_location(params) -> Callable[..., Figure]:
    """Create the function that draws a number-of-days-per-year bar chart.

    Args:
        params (dict): resolved parameters; must contain ``indicator_column``
            and ``model``, and ``location`` by the time the figure is built.
            ``model == "ALL"`` plots the per-year mean over all models,
            any other value plots only the rows of that model.

    Returns:
        Callable[..., Figure]: function taking the queried dataframe and
        returning the Plotly bar chart.
    """
    indicator = params["indicator_column"]
    model = params["model"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    # The title suffix reflects what is actually plotted.
    model_label = "Model Average" if model == "ALL" else f"Model: {model}"

    def plot_data(df) -> Figure:
        """Build the bar chart from the dataframe."""
        if model == "ALL":
            series = df.groupby("year", as_index=False)[indicator].mean()
        else:
            series = df[df["model"] == model]

        # Transform to list to avoid pandas encoding
        indicators = series[indicator].astype(float).tolist()
        years = series["year"].astype(int).tolist()

        fig = go.Figure()

        # Bar plot
        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
            )
        )

        fig.update_layout(
            title=f"{indicator_label} in {params['location']} ({model_label})",
            xaxis_title="Year",
            yaxis_title=indicator_label,
            # A count of days in a year, so the axis is pinned to 0..366.
            yaxis=dict(range=[0, 366]),
            bargap=0.5,
            template="plotly_white",
        )

        return fig

    return plot_data
161
+
162
+
163
# Bar chart of a day-count indicator per year at one location.
indicator_number_of_days_per_year_at_location: Plot = Plot(
    name="Indicator number of days per year at location",
    description="Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_number_of_days_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)


# Every plot type offered to the LLM when choosing visualizations.
PLOTS = [
    indicator_per_year_at_location,
    indicator_number_of_days_per_year_at_location,
]
climateqa/engine/talk_to_data/sql_query.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from typing import Any, TypedDict
3
+
4
+
5
class SqlQueryOutput(TypedDict):
    """Result of a SQL query: column labels plus fetched rows."""

    # Names of the selected columns, in SELECT order.
    labels: list[str]
    # Fetched rows, one entry per row.
    data: list[list[Any]]
8
+
9
+
10
def execute_sql_query(db_path: str, sql_query: str) -> "SqlQueryOutput":
    """Execute a SQL query on the sqlite database.

    Args:
        db_path (str): path to the sqlite database
        sql_query (str): sql query to execute

    Returns:
        SqlQueryOutput: labels of the selected columns and the fetched rows.
        Both are empty lists for statements that return no result set.

    Raises:
        sqlite3.Error: if the query cannot be executed.
    """
    # Connect to sqlite3 database
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Execute the query
        cursor.execute(sql_query)

        # Fetch labels of selected columns; ``description`` is None when the
        # statement produces no result set.
        labels = [desc[0] for desc in cursor.description or []]

        # Fetch data
        data = cursor.fetchall()
    finally:
        # Always release the connection, including when the query fails.
        conn.close()

    return {
        "labels": labels,
        "data": data,
    }
39
+
40
+
41
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
    """Parameters read by ``indicator_per_year_at_location_query``.

    All keys are optional at the type level (``total=False``).
    """

    table: str
    # Single column name; it is interpolated as-is into the SELECT clause.
    indicator_column: str
    # NOTE(review): annotated as str, but the values appear to come straight
    # from the database grid coordinates -- confirm the real type.
    latitude: str
    longitude: str
46
+
47
+
48
def indicator_per_year_at_location_query(
    table: str, params: "IndicatorPerYearAtLocationQueryParams"
) -> str:
    """SQL Query to get the evolution of an indicator per year at a certain location

    Args:
        table (str): sql table of the indicator
        params (IndicatorPerYearAtLocationQueryParams) : dictionary with the required params for the query

    Returns:
        str: the sql query

    Raises:
        ValueError: if ``indicator_column``, ``latitude`` or ``longitude`` is
            missing (or None) in ``params``.
    """
    indicator_column = params.get("indicator_column")
    latitude = params.get("latitude")
    longitude = params.get("longitude")

    # A missing value would otherwise be rendered as the text "None" and
    # produce a query that can only fail or match nothing.
    required = {
        "indicator_column": indicator_column,
        "latitude": latitude,
        "longitude": longitude,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(
            f"Missing parameter(s) for the SQL query on table {table}: {', '.join(missing)}"
        )

    # NOTE(review): identifiers and values are interpolated directly; callers
    # must only pass trusted table/column names.
    sql_query = f"SELECT year, {indicator_column}, model FROM {table} WHERE latitude = {latitude} and longitude={longitude} Order by Year"
    return sql_query
climateqa/engine/talk_to_data/utils.py CHANGED
@@ -1,10 +1,12 @@
1
  import re
2
- import openai
3
- import pandas as pd
4
  from geopy.geocoders import Nominatim
5
  import sqlite3
6
  import ast
7
  from climateqa.engine.llm import get_llm
 
 
8
 
9
  def detect_location_with_openai(sentence):
10
  """
@@ -26,67 +28,139 @@ def detect_location_with_openai(sentence):
26
  else:
27
  return ""
28
 
 
29
  def detectTable(sql_query):
30
  pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
31
  matches = re.findall(pattern, sql_query)
32
  return matches
33
 
34
 
35
-
36
- def loc2coords(location : str):
37
  geolocator = Nominatim(user_agent="city_to_latlong")
38
- location = geolocator.geocode(location)
39
- return (location.latitude, location.longitude)
40
 
41
 
42
- def coords2loc(coords : tuple):
43
  geolocator = Nominatim(user_agent="coords_to_city")
44
  try:
45
  location = geolocator.reverse(coords)
46
  return location.address
47
  except Exception as e:
48
  print(f"Error: {e}")
49
- return "Unknown Location"
50
 
51
 
52
- def nearestNeighbourSQL(db: str, location: tuple, table : str):
53
  conn = sqlite3.connect(db)
54
  long = round(location[1], 3)
55
  lat = round(location[0], 3)
56
- cursor = conn.cursor()
57
- cursor.execute(f"SELECT lat, lon FROM {table} WHERE lat BETWEEN {lat - 0.3} AND {lat + 0.3} AND lon BETWEEN {long - 0.3} AND {long + 0.3}")
 
 
 
58
  results = cursor.fetchall()
59
  return results[0]
60
 
61
- def detect_relevant_tables(user_question, llm):
62
- table_names_list = [
63
- "Frequency_of_rainy_days_index",
64
- "Winter_precipitation_total",
65
- "Summer_precipitation_total",
66
- "Annual_precipitation_total",
67
- # "Remarkable_daily_precipitation_total_(Q99)",
68
- "Frequency_of_remarkable_daily_precipitation",
69
- "Extreme_precipitation_intensity",
70
- "Mean_winter_temperature",
71
- "Mean_summer_temperature",
72
- "Number_of_tropical_nights",
73
- "Maximum_summer_temperature",
74
- "Number_of_days_with_Tx_above_30C",
75
- "Number_of_days_with_Tx_above_35C",
76
- "Drought_index"
77
- ]
78
  prompt = (
79
- f"You are helping to build a sql query to retrieve relevant data for a user question."
 
80
  f"The different tables are {table_names_list}."
81
  f"The user question is {user_question}. Write the relevant tables to use. Answer only a python list of table name."
82
  )
83
- table_names = ast.literal_eval(llm.invoke(prompt).content.strip("```python\n").strip())
 
 
84
  return table_names
85
 
 
86
  def replace_coordonates(coords, query, coords_tables):
87
  n = query.count(str(coords[0]))
88
 
89
  for i in range(n):
90
- query = query.replace(str(coords[0]), str(coords_tables[i][0]),1)
91
- query = query.replace(str(coords[1]), str(coords_tables[i][1]),1)
92
- return query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
+
3
+ from sympy import use
4
  from geopy.geocoders import Nominatim
5
  import sqlite3
6
  import ast
7
  from climateqa.engine.llm import get_llm
8
+ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
9
+
10
 
11
  def detect_location_with_openai(sentence):
12
  """
 
28
  else:
29
  return ""
30
 
31
+
32
def detectTable(sql_query):
    """Return every table reference that follows a ``FROM`` keyword.

    The match is case-insensitive and accepts bare, backtick-, double- or
    single-quoted identifiers, optionally dot-qualified.

    Args:
        sql_query: SQL text to scan.

    Returns:
        List of the matched references, in order of appearance.
    """
    from_clause = re.compile(
        r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    )
    return from_clause.findall(sql_query)
36
 
37
 
38
def loc2coords(location: str):
    """Geocode a place name with Nominatim.

    Args:
        location: Name of the place to look up.

    Returns:
        Tuple ``(latitude, longitude)`` of the geocoding result.

    Raises:
        ValueError: if the geocoder finds no match for ``location``.
    """
    geolocator = Nominatim(user_agent="city_to_latlong")
    coords = geolocator.geocode(location)
    if coords is None:
        # geocode() yields None when nothing matches; fail with a clear
        # message rather than an AttributeError on None.
        raise ValueError(f"Could not geocode location: {location!r}")
    return (coords.latitude, coords.longitude)
42
 
43
 
44
def coords2loc(coords: tuple):
    """Reverse-geocode coordinates into an address with Nominatim.

    Args:
        coords: Coordinates passed as-is to the reverse geocoder.

    Returns:
        The address of the match, or ``"Unknown Location"`` when the lookup
        fails for any reason (the error is printed).
    """
    reverse_geocoder = Nominatim(user_agent="coords_to_city")
    try:
        return reverse_geocoder.reverse(coords).address
    except Exception as e:
        print(f"Error: {e}")
    return "Unknown Location"
52
 
53
 
54
def nearestNeighbourSQL(db: str, location: tuple, table: str):
    """Find the grid point of ``table`` closest to a location.

    Only points within +/-0.3 degrees of the (rounded) latitude and longitude
    are considered; among them the one with the smallest squared
    latitude/longitude distance is returned.

    Args:
        db: Path to the sqlite database.
        location: ``(latitude, longitude)`` of the target.
        table: Table to search; must expose ``latitude``/``longitude`` columns.

    Returns:
        Row ``(latitude, longitude)`` of the nearest grid point.

    Raises:
        IndexError: if the table has no point inside the search box.
        sqlite3.Error: if the query cannot be executed.
    """
    lon = round(location[1], 3)
    lat = round(location[0], 3)

    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        # The table name cannot be bound as a parameter; the numeric values are.
        cursor.execute(
            f"SELECT latitude, longitude FROM {table} "
            "WHERE latitude BETWEEN ? AND ? AND longitude BETWEEN ? AND ? "
            "ORDER BY (latitude - ?) * (latitude - ?) + (longitude - ?) * (longitude - ?) "
            "LIMIT 1",
            (lat - 0.3, lat + 0.3, lon - 0.3, lon + 0.3, lat, lat, lon, lon),
        )
        nearest = cursor.fetchone()
    finally:
        # Release the connection even if the query fails.
        conn.close()

    if nearest is None:
        # Same exception type as indexing an empty result, with a usable message.
        raise IndexError(
            f"No grid point found in table {table} near latitude {lat}, longitude {lon}"
        )
    return nearest
65
 
66
+
67
def detect_relevant_tables(db: str, user_question: str, plot: "Plot", llm) -> list[str]:
    """Ask the LLM which database tables suit a plot and a user question.

    Args:
        db: Path to the sqlite database.
        user_question: Question asked by the user.
        plot: Plot definition; its ``description`` is inserted in the prompt.
        llm: Object exposing ``invoke(prompt)`` whose result has a ``content``
            string holding a Python list literal (optionally code-fenced).

    Returns:
        The table names chosen by the LLM, restricted to tables that actually
        exist in the database.

    Raises:
        ValueError, SyntaxError: if the LLM answer is not a Python literal.
    """
    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()

        # Get all table names as plain strings (fetchall yields 1-tuples).
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names_list = [row[0] for row in cursor.fetchall()]
    finally:
        conn.close()

    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"The different tables are {table_names_list}."
        f"The user question is {user_question}. Write the relevant tables to use. Answer only a python list of table name."
    )
    table_names = ast.literal_eval(
        llm.invoke(prompt).content.strip("```python\n").strip()
    )

    # The answer is later interpolated into SQL: drop anything that is not a
    # real table rather than trusting the LLM output.
    return [name for name in table_names if name in table_names_list]
85
 
86
+
87
def replace_coordonates(coords, query, coords_tables):
    """Swap the requested coordinates in ``query`` for per-table grid coordinates.

    For each occurrence of ``str(coords[0])`` found in the original query, the
    next remaining occurrence of ``str(coords[0])`` and of ``str(coords[1])``
    is replaced by the matching entry of ``coords_tables``.

    Args:
        coords: ``(latitude, longitude)`` as they appear in the query text.
        query: Text containing the coordinates.
        coords_tables: Sequence of ``(latitude, longitude)`` replacements,
            one per occurrence.

    Returns:
        The query with the coordinates replaced.
    """
    occurrences = query.count(str(coords[0]))

    idx = 0
    while idx < occurrences:
        new_lat = str(coords_tables[idx][0])
        query = query.replace(str(coords[0]), new_lat, 1)
        new_lon = str(coords_tables[idx][1])
        query = query.replace(str(coords[1]), new_lon, 1)
        idx += 1
    return query
94
+
95
+
96
def detect_relevant_plots(user_question: str, llm):
    """Ask the LLM which of the available plots best answer a question.

    Args:
        user_question: Question asked by the user.
        llm: Object exposing ``invoke(prompt)`` whose result has a ``content``
            string holding a Python list literal (optionally code-fenced).

    Returns:
        The parsed LLM answer, expected to be a list of plot names.

    Raises:
        ValueError, SyntaxError: if the answer is not a Python literal.
    """
    plots_description = "".join(
        "Name: " + plot["name"] + " - Description: " + plot["description"] + "\n"
        for plot in PLOTS
    )

    prompt = (
        f"You are helping to answer a question with insightful visualizations. "
        f"Given a list of plots with their name and description: "
        f"{plots_description} "
        f"The user question is: {user_question}. "
        f"Choose the most relevant plots to answer the question. "
        f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
        f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
    )

    response = llm.invoke(prompt).content
    # LLM output is untrusted: parse it as a literal instead of executing it.
    # Any surrounding code fence is dropped first, as done for table detection.
    return ast.literal_eval(response.strip("```python\n").strip())
114
+
115
+
116
+ # Next Version
117
+ # class QueryOutput(TypedDict):
118
+ # """Generated SQL query."""
119
+
120
+ # query: Annotated[str, ..., "Syntactically valid SQL query."]
121
+
122
+
123
+ # class PlotlyCodeOutput(TypedDict):
124
+ # """Generated Plotly code"""
125
+
126
+ # code: Annotated[str, ..., "Synatically valid Plotly python code."]
127
+ # def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
128
+ # """Generate SQL query to fetch information."""
129
+ # prompt_params = {
130
+ # "dialect": db.dialect,
131
+ # "table_info": db.get_table_info(),
132
+ # "input": user_input,
133
+ # "relevant_tables": relevant_tables,
134
+ # "model": "ALADIN63_CNRM-CM5",
135
+ # }
136
+
137
+ # prompt = ChatPromptTemplate.from_template(query_prompt_template)
138
+ # structured_llm = llm.with_structured_output(QueryOutput)
139
+ # chain = prompt | structured_llm
140
+ # result = chain.invoke(prompt_params)
141
+
142
+ # return result["query"]
143
+
144
+
145
+ # def fetch_data_from_sql_query(db: str, sql_query: str):
146
+ # conn = sqlite3.connect(db)
147
+ # cursor = conn.cursor()
148
+ # cursor.execute(sql_query)
149
+ # column_names = [desc[0] for desc in cursor.description]
150
+ # values = cursor.fetchall()
151
+ # return {"column_names": column_names, "data": values}
152
+
153
+
154
+ # def generate_chart_code(user_input: str, sql_query: list[str], llm):
155
+ # """ "Generate plotly python code for the chart based on the sql query and the user question"""
156
+
157
+ # class PlotlyCodeOutput(TypedDict):
158
+ # """Generated Plotly code"""
159
+
160
+ # code: Annotated[str, ..., "Synatically valid Plotly python code."]
161
+
162
+ # prompt = ChatPromptTemplate.from_template(plot_prompt_template)
163
+ # structured_llm = llm.with_structured_output(PlotlyCodeOutput)
164
+ # chain = prompt | structured_llm
165
+ # result = chain.invoke({"input": user_input, "sql_query": sql_query})
166
+ # return result["code"]
climateqa/engine/talk_to_data/workflow.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import Any, Callable, NotRequired, TypedDict
4
+ import pandas as pd
5
+
6
+ from plotly.graph_objects import Figure
7
+ from climateqa.engine.llm import get_llm
8
+ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
9
+ from climateqa.engine.talk_to_data.sql_query import execute_sql_query
10
+ from climateqa.engine.talk_to_data.utils import (
11
+ detect_relevant_plots,
12
+ loc2coords,
13
+ detect_location_with_openai,
14
+ nearestNeighbourSQL,
15
+ detect_relevant_tables,
16
+ )
17
+
18
# Directory two levels above the current working directory.
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))

# Location of the DRIAS sqlite database under ROOT_PATH.
DRIAS_DB_PATH = f"{ROOT_PATH}/data/drias/drias.db"
21
+
22
class TableState(TypedDict):
    """Everything computed for one table of one plot."""

    table_name: str
    # Resolved plot parameters (model, indicator column, location, ...).
    params: dict[str, Any]
    # SQL text executed against the table.
    sql_query: NotRequired[str]
    # Query result.
    dataframe: NotRequired[pd.DataFrame | None]
    # Function that builds the figure when given the dataframe.
    figure: NotRequired[Callable[..., Figure]]
28
+
29
class PlotState(TypedDict):
    """Everything computed for one selected plot."""

    plot_name: str
    # Tables chosen for this plot.
    tables: list[str]
    # Per-table results, keyed by table name.
    table_states: dict[str, TableState]
33
+
34
class State(TypedDict):
    """Overall state of one Talk to DRIAS run."""

    user_input: str
    # Names of the plots selected for the question.
    plots: list[str]
    # Per-plot results, keyed by plot name.
    plot_states: dict[str, PlotState]
38
+
39
def drias_workflow(db_drias_path: str, user_input: str) -> State:
    """Run Talk to DRIAS end to end for one question.

    Selects the relevant plots, then for each plot the relevant tables, and
    for each table resolves the plot parameters, builds and runs the SQL
    query, and prepares the figure builder.

    Args:
        db_drias_path (str): path to the drias database
        user_input (str): initial user input

    Returns:
        State: final state holding, per plot and per table, the sql query,
        the dataframe and the figure builder
    """
    llm = get_llm(provider="openai")

    state: State = {'user_input': user_input, 'plots': [], 'plot_states': {}}
    state['plots'] = find_relevant_plots(state, llm)

    if not state['plots']:
        return state

    for plot_name in state['plots']:
        # Find the associated plot object; unknown names are ignored
        plot = next((candidate for candidate in PLOTS if candidate['name'] == plot_name), None)
        if plot is None:
            continue

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': find_relevant_tables_per_plot(state, plot, db_drias_path, llm),
            'table_states': {},
        }

        for table in plot_state['tables']:
            # Start from the default model, then add every resolvable param.
            params: dict[str, Any] = {'model': 'ALL'}
            for param_name in plot['params']:
                found = find_param(state, param_name, table, db_drias_path)
                if found:
                    params.update(found)

            sql_query = plot['sql_query'](table, params)
            results = execute_sql_query(db_drias_path, sql_query)
            df = pd.DataFrame(results['data'], columns=results['labels'])
            figure = plot['plot_function'](params)

            table_state: TableState = {
                'table_name': table,
                'params': params,
                'sql_query': sql_query,
                'dataframe': df,
                'figure': figure,
            }
            plot_state['table_states'][table] = table_state

        state['plot_states'][plot_name] = plot_state

    return state
106
+
107
+
108
def find_relevant_plots(state: State, llm) -> list[str]:
    """Return the plot names the LLM selects for the user's question.

    Args:
        state (State): state of the workflow; only ``user_input`` is read
        llm: language model handed to ``detect_relevant_plots``

    Returns:
        list[str]: result of ``detect_relevant_plots``
    """
    print("---- Find relevant plots ----")
    return detect_relevant_plots(state['user_input'], llm)
112
+
113
def find_relevant_tables_per_plot(state: State, plot: Plot, db_path: str, llm) -> list[str]:
    """Return the tables the LLM selects for one plot and the user's question.

    Args:
        state (State): state of the workflow; only ``user_input`` is read
        plot (Plot): plot definition the tables are chosen for
        db_path (str): path to the database
        llm: language model handed to ``detect_relevant_tables``

    Returns:
        list[str]: result of ``detect_relevant_tables``
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    return detect_relevant_tables(db_path, state['user_input'], plot, llm)
117
+
118
+
119
def find_param(state: State, param_name: str, table: str, db_path: str) -> dict[str, Any] | None:
    """Resolve one plot parameter with the method suited to it.

    Args:
        state (State): state of the workflow
        param_name (str): name of the desired parameter
        table (str): name of the table
        db_path (str): path to the database

    Returns:
        dict[str, Any] | None: the location dictionary for ``'location'``,
        ``{'indicator_column': ...}`` for ``'indicator_column'``, and None
        for any other parameter name
    """
    if param_name == 'indicator_column':
        return {'indicator_column': find_indicator_column(table)}
    if param_name == 'location':
        return find_location(state['user_input'], table, db_path)
    return None
138
+
139
+
140
+ class Location(TypedDict):
141
+ location: str
142
+ latitude: NotRequired[str]
143
+ longitude: NotRequired[str]
144
+
145
def find_location(user_input: str, table: str, db_path: str) -> Location:
    """Detect the location in the question and map it to a grid point of the table.

    Args:
        user_input (str): user question
        table (str): name of the table searched for the neighbouring point
        db_path (str): path to the database

    Returns:
        Location: the detected location; ``latitude`` and ``longitude`` of the
        neighbouring grid point are added only when a location was detected
    """
    print(f"---- Find location in table {table} ----")
    location = detect_location_with_openai(user_input)
    if not location:
        return {'location': location}

    coords = loc2coords(location)
    neighbour = nearestNeighbourSQL(db_path, coords, table)
    return {
        'location': location,
        "latitude": neighbour[0],
        "longitude": neighbour[1],
    }
157
+
158
def find_indicator_column(table: str) -> str:
    """Retrieve the name of the indicator column within the table in the database

    Args:
        table (str): name of the table

    Returns:
        str: name of the indicator column

    Raises:
        KeyError: if the table has no known indicator column
    """

    print(f"---- Find indicator column in table {table} ----")
    indicator_columns_per_table = {
        "total_winter_precipitation": "total_winter_precipitation",
        "total_summer_precipitation": "total_summer_precipitation",
        # Misspelled key kept for backward compatibility.
        # NOTE(review): confirm which spelling the database table really uses.
        "total_summer_precipiation": "total_summer_precipitation",
        "total_annual_precipitation": "total_annual_precipitation",
        "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
        "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
        "extreme_precipitation_intensity": "extreme_precipitation_intensity",
        "mean_winter_temperature": "mean_winter_temperature",
        "mean_summer_temperature": "mean_summer_temperature",
        "mean_annual_temperature": "mean_annual_temperature",
        "number_of_tropical_nights": "number_tropical_nights",
        "maximum_summer_temperature": "maximum_summer_temperature",
        "number_of_days_with_TX_above_30": "number_of_days_with_tx_above_30",
        "number_of_days_with_TX_above_35": "number_of_days_with_tx_above_35",
        "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
    }
    return indicator_columns_per_table[table]
186
+
187
+ # def make_write_query_node():
188
+
189
+ # def write_query(state):
190
+ # print("---- Write query ----")
191
+ # for table in state["tables"]:
192
+ # sql_query = QUERIES[state[table]['query_type']](
193
+ # table=table,
194
+ # indicator_column=state[table]["columns"],
195
+ # longitude=state[table]["longitude"],
196
+ # latitude=state[table]["latitude"],
197
+ # )
198
+ # state[table].update({"sql_query": sql_query})
199
+
200
+ # return state
201
+
202
+ # return write_query
203
+
204
+ # def make_fetch_data_node(db_path):
205
+
206
+ # def fetch_data(state):
207
+ # print("---- Fetch data ----")
208
+ # for table in state["tables"]:
209
+ # results = execute_sql_query(db_path, state[table]['sql_query'])
210
+ # state[table].update(results)
211
+
212
+ # return state
213
+
214
+ # return fetch_data
215
+
216
+
217
+
218
+ ## V2
219
+
220
+
221
+ # def make_fetch_data_node(db_path: str, llm):
222
+ # def fetch_data(state):
223
+ # print("---- Fetch data ----")
224
+ # db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
225
+ # output = {}
226
+ # sql_query = write_sql_query(state["query"], db, state["tables"], llm)
227
+ # # TO DO : Add query checker
228
+ # print(f"SQL query : {sql_query}")
229
+ # output["sql_query"] = sql_query
230
+ # output.update(fetch_data_from_sql_query(db_path, sql_query))
231
+ # return output
232
+
233
+ # return fetch_data