armanddemasson committed on
Commit
705ccec
·
1 Parent(s): 8c7a7fe

feat: added multithreading to run sql queries in talk to drias

Browse files
climateqa/engine/talk_to_data/sql_query.py CHANGED
@@ -22,9 +22,10 @@ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
22
  """
23
  def _execute_query():
24
  # Execute the query
25
- results = duckdb.sql(sql_query)
 
26
  # return fetched data
27
- return results.fetchdf()
28
 
29
  # Run the query in a thread pool to avoid blocking
30
  loop = asyncio.get_event_loop()
 
22
  """
23
  def _execute_query():
24
  # Execute the query
25
+ con = duckdb.connect()
26
+ results = con.sql(sql_query).fetchdf()
27
  # return fetched data
28
+ return results
29
 
30
  # Run the query in a thread pool to avoid blocking
31
  loop = asyncio.get_event_loop()
climateqa/engine/talk_to_data/{workflow.py → talk_to_drias.py} RENAMED
@@ -1,10 +1,12 @@
1
  import os
2
 
3
  from typing import Any, Callable, TypedDict, Optional
 
4
  import pandas as pd
5
-
6
  from plotly.graph_objects import Figure
7
  from climateqa.engine.llm import get_llm
 
8
  from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
9
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
10
  from climateqa.engine.talk_to_data.sql_query import execute_sql_query
@@ -17,6 +19,7 @@ from climateqa.engine.talk_to_data.utils import (
17
  detect_relevant_tables,
18
  )
19
 
 
20
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
21
 
22
  class TableState(TypedDict):
@@ -61,101 +64,6 @@ class State(TypedDict):
61
  plot_states: dict[str, PlotState]
62
  error: Optional[str]
63
 
64
- async def drias_workflow(user_input: str) -> State:
65
- """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
66
-
67
- Args:
68
- user_input (str): initial user input
69
-
70
- Returns:
71
- State: Final state with all the results
72
- """
73
- state: State = {
74
- 'user_input': user_input,
75
- 'plots': [],
76
- 'plot_states': {}
77
- }
78
-
79
- llm = get_llm(provider="openai")
80
-
81
- plots = await find_relevant_plots(state, llm)
82
- state['plots'] = plots
83
-
84
- if not state['plots']:
85
- state['error'] = 'There is no plot to answer to the question'
86
- return state
87
-
88
- have_relevant_table = False
89
- have_sql_query = False
90
- have_dataframe = False
91
- for plot_name in state['plots']:
92
-
93
- plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
94
- if plot is None:
95
- continue
96
-
97
- plot_state: PlotState = {
98
- 'plot_name': plot_name,
99
- 'tables': [],
100
- 'table_states': {}
101
- }
102
-
103
- plot_state['plot_name'] = plot_name
104
-
105
- relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
106
- if len(relevant_tables) > 0 :
107
- have_relevant_table = True
108
-
109
- plot_state['tables'] = relevant_tables
110
-
111
- params = {}
112
- for param_name in plot['params']:
113
- param = await find_param(state, param_name, relevant_tables[0])
114
- if param:
115
- params.update(param)
116
-
117
- for n, table in enumerate(plot_state['tables']):
118
- if n > 2:
119
- break
120
-
121
- table_state: TableState = {
122
- 'table_name': table,
123
- 'params': params,
124
- 'status': 'OK'
125
- }
126
-
127
- table_state["params"]['indicator_column'] = find_indicator_column(table)
128
-
129
- sql_query = plot['sql_query'](table, table_state['params'])
130
-
131
- if sql_query == "":
132
- table_state['status'] = 'ERROR'
133
- continue
134
- else :
135
- have_sql_query = True
136
-
137
- table_state['sql_query'] = sql_query
138
- df = await execute_sql_query(sql_query)
139
-
140
- if len(df) > 0:
141
- have_dataframe = True
142
-
143
- figure = plot['plot_function'](table_state['params'])
144
- table_state['dataframe'] = df
145
- table_state['figure'] = figure
146
- plot_state['table_states'][table] = table_state
147
-
148
- state['plot_states'][plot_name] = plot_state
149
-
150
- if not have_relevant_table:
151
- state['error'] = "There is no relevant table in the our database to answer your question"
152
- elif not have_sql_query:
153
- state['error'] = "There is no relevant sql query on our database that can help to answer your question"
154
- elif not have_dataframe:
155
- state['error'] = "There is no data in our table that can answer to your question"
156
-
157
- return state
158
-
159
  async def find_relevant_plots(state: State, llm) -> list[str]:
160
  print("---- Find relevant plots ----")
161
  relevant_plots = await detect_relevant_plots(state['user_input'], llm)
@@ -238,6 +146,130 @@ def find_indicator_column(table: str) -> str:
238
  return INDICATOR_COLUMNS_PER_TABLE[table]
239
 
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  # def make_write_query_node():
242
 
243
  # def write_query(state):
 
1
  import os
2
 
3
  from typing import Any, Callable, TypedDict, Optional
4
+ from numpy import sort
5
  import pandas as pd
6
+ import asyncio
7
  from plotly.graph_objects import Figure
8
  from climateqa.engine.llm import get_llm
9
+ from climateqa.engine.talk_to_data import sql_query
10
  from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
11
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
12
  from climateqa.engine.talk_to_data.sql_query import execute_sql_query
 
19
  detect_relevant_tables,
20
  )
21
 
22
+
23
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
24
 
25
  class TableState(TypedDict):
 
64
  plot_states: dict[str, PlotState]
65
  error: Optional[str]
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  async def find_relevant_plots(state: State, llm) -> list[str]:
68
  print("---- Find relevant plots ----")
69
  relevant_plots = await detect_relevant_plots(state['user_input'], llm)
 
146
  return INDICATOR_COLUMNS_PER_TABLE[table]
147
 
148
 
149
async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Build the state of a single table: SQL query, query result and figure.

    The SQL query is produced by ``plot['sql_query']`` from the table name and
    the table-specific parameters. An empty query marks the table as ``'ERROR'``
    and nothing is executed. Otherwise the query is run through
    ``execute_sql_query`` and the value returned by ``plot['plot_function']``
    is stored as the figure.

    Args:
        table (str): Name of the table to process.
        params (dict[str, Any]): Parameters shared by the caller; they are
            copied, so the caller's dict is never modified.
        plot (Plot): Plot definition providing the ``'sql_query'`` and
            ``'plot_function'`` callables.

    Returns:
        TableState: State holding the table name, the parameters (including
        ``'indicator_column'``), the status, and - when a query was built -
        the SQL query, the resulting dataframe and the figure.
    """
    # Shallow copy: the indicator column is specific to this table and must
    # not leak into the dict owned by the caller.
    table_params = params.copy()
    table_params['indicator_column'] = find_indicator_column(table)

    query = plot['sql_query'](table, table_params)

    result: TableState = {
        'table_name': table,
        'params': table_params,
        'status': 'OK',
        'dataframe': None,
        'sql_query': None,
        'figure': None,
    }

    # Guard clause: an empty string means no query could be built.
    if query == "":
        result['status'] = 'ERROR'
        return result

    result['sql_query'] = query
    print(query)

    result['dataframe'] = await execute_sql_query(query)
    result['figure'] = plot['plot_function'](table_params)
    return result
190
+
191
async def drias_workflow(user_input: str) -> State:
    """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated

    For every relevant plot that exists in ``PLOTS``, the relevant tables are
    looked up, the plot parameters are resolved against the first relevant
    table, and at most three tables are processed concurrently with
    ``process_table``.

    Args:
        user_input (str): initial user input

    Returns:
        State: Final state with all the results. ``state['error']`` stays an
        empty string unless no plot, no relevant table, no SQL query or no
        data could be found across all plots.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    state['plots'] = await find_relevant_plots(state, llm)

    if not state['plots']:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # These flags are accumulated over ALL plots. They must not be reset per
    # plot, otherwise the final error message would only reflect the last plot
    # and hide results successfully produced by earlier ones.
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:

        plot = next((p for p in PLOTS if p['name'] == plot_name), None)  # Find the associated plot object
        if plot is None:
            continue

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': [],
            'table_states': {}
        }

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
        plot_state['tables'] = relevant_tables

        if not relevant_tables:
            # Nothing to query for this plot. Record the empty plot state and
            # move on instead of indexing relevant_tables[0] below, which
            # would raise IndexError.
            state['plot_states'][plot_name] = plot_state
            continue

        have_relevant_table = True

        # Parameters are resolved once per plot, against the first relevant table.
        params = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        # Process at most the first three tables, concurrently.
        tasks = [process_table(table, params, plot) for table in relevant_tables[:3]]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state
        for table_state in results:
            print(table_state)
            if table_state['sql_query']:
                have_sql_query = True
            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

        state['plot_states'][plot_name] = plot_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
272
+
273
  # def make_write_query_node():
274
 
275
  # def write_query(state):