timeki committed on
Commit
5fe1543
·
1 Parent(s): 45a9320

make ask drias asynchronous

Browse files
climateqa/engine/talk_to_data/main.py CHANGED
@@ -37,7 +37,7 @@ def ask_llm_column_names(sql_query: str, llm) -> list[str]:
37
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
38
  return columns_list
39
 
40
- def ask_drias(query: str, index_state: int = 0) -> tuple:
41
  """Main function to process a DRIAS query and return results.
42
 
43
  This function orchestrates the DRIAS workflow, processing a user query to generate
@@ -60,7 +60,7 @@ def ask_drias(query: str, index_state: int = 0) -> tuple:
60
  - table_list (list): List of table names used
61
  - error (str): Error message if any
62
  """
63
- final_state = drias_workflow(query)
64
  sql_queries = []
65
  result_dataframes = []
66
  figures = []
 
37
  columns_list = ast.literal_eval(columns.strip("```python\n").strip())
38
  return columns_list
39
 
40
+ async def ask_drias(query: str, index_state: int = 0) -> tuple:
41
  """Main function to process a DRIAS query and return results.
42
 
43
  This function orchestrates the DRIAS workflow, processing a user query to generate
 
60
  - table_list (list): List of table names used
61
  - error (str): Error message if any
62
  """
63
+ final_state = await drias_workflow(query)
64
  sql_queries = []
65
  result_dataframes = []
66
  figures = []
climateqa/engine/talk_to_data/sql_query.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  from typing import TypedDict
2
  import duckdb
3
  import pandas as pd
4
 
5
- def execute_sql_query(sql_query: str) -> pd.DataFrame:
6
  """Executes a SQL query on the DRIAS database and returns the results.
7
 
8
  This function connects to the DuckDB database containing DRIAS climate data
@@ -18,11 +20,16 @@ def execute_sql_query(sql_query: str) -> pd.DataFrame:
18
  Raises:
19
  duckdb.Error: If there is an error executing the SQL query
20
  """
21
- # Execute the query
22
- results = duckdb.sql(sql_query)
 
 
 
23
 
24
- # return fetched data
25
- return results.fetchdf()
 
 
26
 
27
 
28
  class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
 
1
+ import asyncio
2
+ from concurrent.futures import ThreadPoolExecutor
3
  from typing import TypedDict
4
  import duckdb
5
  import pandas as pd
6
 
7
+ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
8
  """Executes a SQL query on the DRIAS database and returns the results.
9
 
10
  This function connects to the DuckDB database containing DRIAS climate data
 
20
  Raises:
21
  duckdb.Error: If there is an error executing the SQL query
22
  """
23
+ def _execute_query():
24
+ # Execute the query
25
+ results = duckdb.sql(sql_query)
26
+ # return fetched data
27
+ return results.fetchdf()
28
 
29
+ # Run the query in a thread pool to avoid blocking
30
+ loop = asyncio.get_event_loop()
31
+ with ThreadPoolExecutor() as executor:
32
+ return await loop.run_in_executor(executor, _execute_query)
33
 
34
 
35
  class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
climateqa/engine/talk_to_data/utils.py CHANGED
@@ -9,7 +9,7 @@ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
9
  from langchain_core.prompts import ChatPromptTemplate
10
 
11
 
12
- def detect_location_with_openai(sentence):
13
  """
14
  Detects locations in a sentence using OpenAI's API via LangChain.
15
  """
@@ -22,7 +22,7 @@ def detect_location_with_openai(sentence):
22
  Sentence: "{sentence}"
23
  """
24
 
25
- response = llm.invoke(prompt)
26
  location_list = ast.literal_eval(response.content.strip("```python\n").strip())
27
  if location_list:
28
  return location_list[0]
@@ -40,7 +40,7 @@ class ArrayOutput(TypedDict):
40
  """
41
  array: Annotated[str, "Syntactically valid python array."]
42
 
43
- def detect_year_with_openai(sentence: str) -> str:
44
  """
45
  Detects years in a sentence using OpenAI's API via LangChain.
46
  """
@@ -56,7 +56,7 @@ def detect_year_with_openai(sentence: str) -> str:
56
  prompt = ChatPromptTemplate.from_template(prompt)
57
  structured_llm = llm.with_structured_output(ArrayOutput)
58
  chain = prompt | structured_llm
59
- response: ArrayOutput = chain.invoke({"sentence": sentence})
60
  years_list = eval(response['array'])
61
  if len(years_list) > 0:
62
  return years_list[0]
@@ -146,7 +146,7 @@ def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
146
  return results['latitude'].iloc[0], results['longitude'].iloc[0]
147
 
148
 
149
- def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
150
  """Identifies relevant tables for a plot based on user input.
151
 
152
  This function uses an LLM to analyze the user's question and the plot
@@ -183,7 +183,7 @@ def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
183
  )
184
 
185
  table_names = ast.literal_eval(
186
- llm.invoke(prompt).content.strip("```python\n").strip()
187
  )
188
  return table_names
189
 
@@ -197,7 +197,7 @@ def replace_coordonates(coords, query, coords_tables):
197
  return query
198
 
199
 
200
- def detect_relevant_plots(user_question: str, llm):
201
  plots_description = ""
202
  for plot in PLOTS:
203
  plots_description += "Name: " + plot["name"]
@@ -223,7 +223,7 @@ def detect_relevant_plots(user_question: str, llm):
223
  # )
224
 
225
  plot_names = ast.literal_eval(
226
- llm.invoke(prompt).content.strip("```python\n").strip()
227
  )
228
  return plot_names
229
 
 
9
  from langchain_core.prompts import ChatPromptTemplate
10
 
11
 
12
+ async def detect_location_with_openai(sentence):
13
  """
14
  Detects locations in a sentence using OpenAI's API via LangChain.
15
  """
 
22
  Sentence: "{sentence}"
23
  """
24
 
25
+ response = await llm.ainvoke(prompt)
26
  location_list = ast.literal_eval(response.content.strip("```python\n").strip())
27
  if location_list:
28
  return location_list[0]
 
40
  """
41
  array: Annotated[str, "Syntactically valid python array."]
42
 
43
+ async def detect_year_with_openai(sentence: str) -> str:
44
  """
45
  Detects years in a sentence using OpenAI's API via LangChain.
46
  """
 
56
  prompt = ChatPromptTemplate.from_template(prompt)
57
  structured_llm = llm.with_structured_output(ArrayOutput)
58
  chain = prompt | structured_llm
59
+ response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
60
  years_list = eval(response['array'])
61
  if len(years_list) > 0:
62
  return years_list[0]
 
146
  return results['latitude'].iloc[0], results['longitude'].iloc[0]
147
 
148
 
149
+ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
150
  """Identifies relevant tables for a plot based on user input.
151
 
152
  This function uses an LLM to analyze the user's question and the plot
 
183
  )
184
 
185
  table_names = ast.literal_eval(
186
+ (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
187
  )
188
  return table_names
189
 
 
197
  return query
198
 
199
 
200
+ async def detect_relevant_plots(user_question: str, llm):
201
  plots_description = ""
202
  for plot in PLOTS:
203
  plots_description += "Name: " + plot["name"]
 
223
  # )
224
 
225
  plot_names = ast.literal_eval(
226
+ (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
227
  )
228
  return plot_names
229
 
climateqa/engine/talk_to_data/workflow.py CHANGED
@@ -61,7 +61,7 @@ class State(TypedDict):
61
  plot_states: dict[str, PlotState]
62
  error: NotRequired[str]
63
 
64
- def drias_workflow(user_input: str) -> State:
65
  """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
66
 
67
  Args:
@@ -78,7 +78,7 @@ def drias_workflow(user_input: str) -> State:
78
 
79
  llm = get_llm(provider="openai")
80
 
81
- plots = find_relevant_plots(state, llm)
82
  state['plots'] = plots
83
 
84
  if not state['plots']:
@@ -102,7 +102,7 @@ def drias_workflow(user_input: str) -> State:
102
 
103
  plot_state['plot_name'] = plot_name
104
 
105
- relevant_tables = find_relevant_tables_per_plot(state, plot, llm)
106
  if len(relevant_tables) > 0 :
107
  have_relevant_table = True
108
 
@@ -110,7 +110,7 @@ def drias_workflow(user_input: str) -> State:
110
 
111
  params = {}
112
  for param_name in plot['params']:
113
- param = find_param(state, param_name, relevant_tables[0])
114
  if param:
115
  params.update(param)
116
 
@@ -135,7 +135,7 @@ def drias_workflow(user_input: str) -> State:
135
  have_sql_query = True
136
 
137
  table_state['sql_query'] = sql_query
138
- df = execute_sql_query(sql_query)
139
 
140
  if len(df) > 0:
141
  have_dataframe = True
@@ -154,22 +154,19 @@ def drias_workflow(user_input: str) -> State:
154
  elif not have_dataframe:
155
  state['error'] = "There is no data in our table that can answer to your question"
156
 
157
-
158
  return state
159
 
160
-
161
- def find_relevant_plots(state: State, llm) -> list[str]:
162
  print("---- Find relevant plots ----")
163
- relevant_plots = detect_relevant_plots(state['user_input'], llm)
164
  return relevant_plots
165
 
166
- def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
167
  print(f"---- Find relevant tables for {plot['name']} ----")
168
- relevant_tables = detect_relevant_tables(state['user_input'], plot, llm)
169
  return relevant_tables
170
 
171
-
172
- def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
173
  """Perform the good method to retrieve the desired parameter
174
 
175
  Args:
@@ -181,25 +178,21 @@ def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | Non
181
  dict[str, Any] | None:
182
  """
183
  if param_name == 'location':
184
- location = find_location(state['user_input'], table)
185
  return location
186
- # if param_name == 'indicator_column':
187
- # indicator_column = find_indicator_column(table)
188
- # return {'indicator_column': indicator_column}
189
  if param_name == 'year':
190
- year = find_year(state['user_input'])
191
  return {'year': year}
192
  return None
193
 
194
-
195
  class Location(TypedDict):
196
  location: str
197
  latitude: NotRequired[str]
198
  longitude: NotRequired[str]
199
 
200
- def find_location(user_input: str, table: str) -> Location:
201
  print(f"---- Find location in table {table} ----")
202
- location = detect_location_with_openai(user_input)
203
  output: Location = {'location' : location}
204
  if location:
205
  coords = loc2coords(location)
@@ -210,7 +203,7 @@ def find_location(user_input: str, table: str) -> Location:
210
  })
211
  return output
212
 
213
- def find_year(user_input: str) -> str:
214
  """Extracts year information from user input using LLM.
215
 
216
  This function uses an LLM to identify and extract year information from the
@@ -223,7 +216,7 @@ def find_year(user_input: str) -> str:
223
  str: The extracted year, or empty string if no year found
224
  """
225
  print(f"---- Find year ---")
226
- year = detect_year_with_openai(user_input)
227
  return year
228
 
229
  def find_indicator_column(table: str) -> str:
 
61
  plot_states: dict[str, PlotState]
62
  error: NotRequired[str]
63
 
64
+ async def drias_workflow(user_input: str) -> State:
65
  """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
66
 
67
  Args:
 
78
 
79
  llm = get_llm(provider="openai")
80
 
81
+ plots = await find_relevant_plots(state, llm)
82
  state['plots'] = plots
83
 
84
  if not state['plots']:
 
102
 
103
  plot_state['plot_name'] = plot_name
104
 
105
+ relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
106
  if len(relevant_tables) > 0 :
107
  have_relevant_table = True
108
 
 
110
 
111
  params = {}
112
  for param_name in plot['params']:
113
+ param = await find_param(state, param_name, relevant_tables[0])
114
  if param:
115
  params.update(param)
116
 
 
135
  have_sql_query = True
136
 
137
  table_state['sql_query'] = sql_query
138
+ df = await execute_sql_query(sql_query)
139
 
140
  if len(df) > 0:
141
  have_dataframe = True
 
154
  elif not have_dataframe:
155
  state['error'] = "There is no data in our table that can answer to your question"
156
 
 
157
  return state
158
 
159
+ async def find_relevant_plots(state: State, llm) -> list[str]:
 
160
  print("---- Find relevant plots ----")
161
+ relevant_plots = await detect_relevant_plots(state['user_input'], llm)
162
  return relevant_plots
163
 
164
+ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
165
  print(f"---- Find relevant tables for {plot['name']} ----")
166
+ relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
167
  return relevant_tables
168
 
169
+ async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
 
170
  """Perform the good method to retrieve the desired parameter
171
 
172
  Args:
 
178
  dict[str, Any] | None:
179
  """
180
  if param_name == 'location':
181
+ location = await find_location(state['user_input'], table)
182
  return location
 
 
 
183
  if param_name == 'year':
184
+ year = await find_year(state['user_input'])
185
  return {'year': year}
186
  return None
187
 
 
188
  class Location(TypedDict):
189
  location: str
190
  latitude: NotRequired[str]
191
  longitude: NotRequired[str]
192
 
193
+ async def find_location(user_input: str, table: str) -> Location:
194
  print(f"---- Find location in table {table} ----")
195
+ location = await detect_location_with_openai(user_input)
196
  output: Location = {'location' : location}
197
  if location:
198
  coords = loc2coords(location)
 
203
  })
204
  return output
205
 
206
+ async def find_year(user_input: str) -> str:
207
  """Extracts year information from user input using LLM.
208
 
209
  This function uses an LLM to identify and extract year information from the
 
216
  str: The extracted year, or empty string if no year found
217
  """
218
  print(f"---- Find year ---")
219
+ year = await detect_year_with_openai(user_input)
220
  return year
221
 
222
  def find_indicator_column(table: str) -> str:
front/tabs/tab_drias.py CHANGED
@@ -4,8 +4,8 @@ from climateqa.engine.talk_to_data.main import ask_drias
4
  from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
5
 
6
 
7
- def ask_drias_query(query: str, index_state: int):
8
- return ask_drias(query, index_state)
9
 
10
 
11
  def show_results(sql_queries_state, dataframes_state, plots_state):
 
4
  from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
5
 
6
 
7
+ async def ask_drias_query(query: str, index_state: int):
8
+ return await ask_drias(query, index_state)
9
 
10
 
11
  def show_results(sql_queries_state, dataframes_state, plots_state):