armanddemasson committed on
Commit
7c38528
·
2 Parent(s): 8c7a7fe a967134

Merged in feat/improve_drias_exeuction_time (pull request #6)

Browse files
climateqa/engine/talk_to_data/sql_query.py CHANGED
@@ -22,9 +22,10 @@ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
22
  """
23
  def _execute_query():
24
  # Execute the query
25
- results = duckdb.sql(sql_query)
 
26
  # return fetched data
27
- return results.fetchdf()
28
 
29
  # Run the query in a thread pool to avoid blocking
30
  loop = asyncio.get_event_loop()
 
22
  """
23
  def _execute_query():
24
  # Execute the query
25
+ con = duckdb.connect()
26
+ results = con.sql(sql_query).fetchdf()
27
  # return fetched data
28
+ return results
29
 
30
  # Run the query in a thread pool to avoid blocking
31
  loop = asyncio.get_event_loop()
climateqa/engine/talk_to_data/{workflow.py → talk_to_drias.py} RENAMED
@@ -1,10 +1,12 @@
1
  import os
2
 
3
  from typing import Any, Callable, TypedDict, Optional
 
4
  import pandas as pd
5
-
6
  from plotly.graph_objects import Figure
7
  from climateqa.engine.llm import get_llm
 
8
  from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
9
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
10
  from climateqa.engine.talk_to_data.sql_query import execute_sql_query
@@ -17,6 +19,7 @@ from climateqa.engine.talk_to_data.utils import (
17
  detect_relevant_tables,
18
  )
19
 
 
20
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
21
 
22
  class TableState(TypedDict):
@@ -61,101 +64,6 @@ class State(TypedDict):
61
  plot_states: dict[str, PlotState]
62
  error: Optional[str]
63
 
64
- async def drias_workflow(user_input: str) -> State:
65
- """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
66
-
67
- Args:
68
- user_input (str): initial user input
69
-
70
- Returns:
71
- State: Final state with all the results
72
- """
73
- state: State = {
74
- 'user_input': user_input,
75
- 'plots': [],
76
- 'plot_states': {}
77
- }
78
-
79
- llm = get_llm(provider="openai")
80
-
81
- plots = await find_relevant_plots(state, llm)
82
- state['plots'] = plots
83
-
84
- if not state['plots']:
85
- state['error'] = 'There is no plot to answer to the question'
86
- return state
87
-
88
- have_relevant_table = False
89
- have_sql_query = False
90
- have_dataframe = False
91
- for plot_name in state['plots']:
92
-
93
- plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
94
- if plot is None:
95
- continue
96
-
97
- plot_state: PlotState = {
98
- 'plot_name': plot_name,
99
- 'tables': [],
100
- 'table_states': {}
101
- }
102
-
103
- plot_state['plot_name'] = plot_name
104
-
105
- relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
106
- if len(relevant_tables) > 0 :
107
- have_relevant_table = True
108
-
109
- plot_state['tables'] = relevant_tables
110
-
111
- params = {}
112
- for param_name in plot['params']:
113
- param = await find_param(state, param_name, relevant_tables[0])
114
- if param:
115
- params.update(param)
116
-
117
- for n, table in enumerate(plot_state['tables']):
118
- if n > 2:
119
- break
120
-
121
- table_state: TableState = {
122
- 'table_name': table,
123
- 'params': params,
124
- 'status': 'OK'
125
- }
126
-
127
- table_state["params"]['indicator_column'] = find_indicator_column(table)
128
-
129
- sql_query = plot['sql_query'](table, table_state['params'])
130
-
131
- if sql_query == "":
132
- table_state['status'] = 'ERROR'
133
- continue
134
- else :
135
- have_sql_query = True
136
-
137
- table_state['sql_query'] = sql_query
138
- df = await execute_sql_query(sql_query)
139
-
140
- if len(df) > 0:
141
- have_dataframe = True
142
-
143
- figure = plot['plot_function'](table_state['params'])
144
- table_state['dataframe'] = df
145
- table_state['figure'] = figure
146
- plot_state['table_states'][table] = table_state
147
-
148
- state['plot_states'][plot_name] = plot_state
149
-
150
- if not have_relevant_table:
151
- state['error'] = "There is no relevant table in the our database to answer your question"
152
- elif not have_sql_query:
153
- state['error'] = "There is no relevant sql query on our database that can help to answer your question"
154
- elif not have_dataframe:
155
- state['error'] = "There is no data in our table that can answer to your question"
156
-
157
- return state
158
-
159
  async def find_relevant_plots(state: State, llm) -> list[str]:
160
  print("---- Find relevant plots ----")
161
  relevant_plots = await detect_relevant_plots(state['user_input'], llm)
@@ -238,6 +146,128 @@ def find_indicator_column(table: str) -> str:
238
  return INDICATOR_COLUMNS_PER_TABLE[table]
239
 
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  # def make_write_query_node():
242
 
243
  # def write_query(state):
 
1
  import os
2
 
3
  from typing import Any, Callable, TypedDict, Optional
4
+ from numpy import sort
5
  import pandas as pd
6
+ import asyncio
7
  from plotly.graph_objects import Figure
8
  from climateqa.engine.llm import get_llm
9
+ from climateqa.engine.talk_to_data import sql_query
10
  from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
11
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
12
  from climateqa.engine.talk_to_data.sql_query import execute_sql_query
 
19
  detect_relevant_tables,
20
  )
21
 
22
+
23
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
24
 
25
  class TableState(TypedDict):
 
64
  plot_states: dict[str, PlotState]
65
  error: Optional[str]
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  async def find_relevant_plots(state: State, llm) -> list[str]:
68
  print("---- Find relevant plots ----")
69
  relevant_plots = await detect_relevant_plots(state['user_input'], llm)
 
146
  return INDICATOR_COLUMNS_PER_TABLE[table]
147
 
148
 
149
async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Build the state of a single table: SQL query, dataframe and figure.

    The plot's SQL builder is called for the given table; when it yields a
    query, the query is executed and the plot function is called to produce
    the figure.

    Args:
        table (str): The name of the table to process
        params (dict[str, Any]): Parameters shared by all tables of the plot;
            a copy is taken so the caller's dict is left untouched
        plot (Plot): The plot object providing the 'sql_query' builder and
            the 'plot_function'

    Returns:
        TableState: The state of the processed table. Its status is 'ERROR'
        (with no query, dataframe or figure) when the SQL builder returns an
        empty string, 'OK' otherwise.
    """
    # Work on a private copy: each table gets its own indicator column.
    table_params = params.copy()
    table_params['indicator_column'] = find_indicator_column(table)

    outcome: TableState = {
        'table_name': table,
        'params': table_params,
        'status': 'OK',
        'dataframe': None,
        'sql_query': None,
        'figure': None,
    }

    query = plot['sql_query'](table, table_params)
    if query == "":
        # An empty string from the SQL builder means no query is available.
        outcome['status'] = 'ERROR'
        return outcome

    outcome['sql_query'] = query
    outcome['dataframe'] = await execute_sql_query(query)
    outcome['figure'] = plot['plot_function'](table_params)
    return outcome
189
+
190
async def drias_workflow(user_input: str) -> State:
    """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated

    For every relevant plot, the relevant tables are looked up, the plot
    parameters are resolved once (against the first relevant table), and the
    first few tables are processed concurrently with ``process_table``.

    Args:
        user_input (str): initial user input

    Returns:
        State: Final state with all the results. ``state['error']`` is an
        empty string on success, otherwise a message describing the first
        stage (plots, tables, sql query, data) that produced nothing across
        all plots.
    """
    # Upper bound on the number of tables queried for a single plot.
    max_tables_per_plot = 3

    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': '',
    }

    llm = get_llm(provider="openai")

    state['plots'] = await find_relevant_plots(state, llm)

    if not state['plots']:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # These flags accumulate over ALL plots: one successful plot is enough
    # to consider the corresponding stage as reached.
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:

        # Find the associated plot object
        plot = next((p for p in PLOTS if p['name'] == plot_name), None)
        if plot is None:
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': relevant_tables,
            'table_states': {},
        }

        if not relevant_tables:
            # Nothing to query for this plot. Record the empty plot state and
            # move on instead of indexing relevant_tables[0] below.
            state['plot_states'][plot_name] = plot_state
            continue

        have_relevant_table = True

        # Parameters are resolved once per plot, using the first table.
        params = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        # Query the selected tables concurrently.
        tasks = [
            process_table(table, params, plot)
            for table in relevant_tables[:max_tables_per_plot]
        ]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state
        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

        state['plot_states'][plot_name] = plot_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
270
+
271
  # def make_write_query_node():
272
 
273
  # def write_query(state):