armanddemasson commited on
Commit
46c1e34
Β·
1 Parent(s): 14d7085

refactor: modularized talk to data

Browse files
climateqa/engine/talk_to_data/{config.py β†’ drias/config.py} RENAMED
@@ -1,3 +1,4 @@
 
1
  DRIAS_TABLES = [
2
  "total_winter_precipitation",
3
  "total_summer_precipiation",
@@ -15,7 +16,7 @@ DRIAS_TABLES = [
15
  "number_of_days_with_a_dry_ground",
16
  ]
17
 
18
- INDICATOR_COLUMNS_PER_TABLE = {
19
  "total_winter_precipitation": "total_winter_precipitation",
20
  "total_summer_precipiation": "total_summer_precipitation",
21
  "total_annual_precipitation": "total_annual_precipitation",
@@ -52,7 +53,7 @@ DRIAS_MODELS = [
52
  'CCLM4-8-17_HadGEM2-ES'
53
  ]
54
  # Mapping between indicator columns and their units
55
- INDICATOR_TO_UNIT = {
56
  "total_winter_precipitation": "mm",
57
  "total_summer_precipitation": "mm",
58
  "total_annual_precipitation": "mm",
 
1
+
2
  DRIAS_TABLES = [
3
  "total_winter_precipitation",
4
  "total_summer_precipiation",
 
16
  "number_of_days_with_a_dry_ground",
17
  ]
18
 
19
+ DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
20
  "total_winter_precipitation": "total_winter_precipitation",
21
  "total_summer_precipiation": "total_summer_precipitation",
22
  "total_annual_precipitation": "total_annual_precipitation",
 
53
  'CCLM4-8-17_HadGEM2-ES'
54
  ]
55
  # Mapping between indicator columns and their units
56
+ DRIAS_INDICATOR_TO_UNIT = {
57
  "total_winter_precipitation": "mm",
58
  "total_summer_precipitation": "mm",
59
  "total_annual_precipitation": "mm",
climateqa/engine/talk_to_data/{plot.py β†’ drias/plots.py} RENAMED
@@ -1,38 +1,15 @@
1
- from typing import Callable, TypedDict
2
- from matplotlib.figure import figaspect
 
3
  import pandas as pd
4
  from plotly.graph_objects import Figure
5
  import plotly.graph_objects as go
6
- import plotly.express as px
7
-
8
- from climateqa.engine.talk_to_data.sql_query import (
9
  indicator_for_given_year_query,
10
  indicator_per_year_at_location_query,
11
  )
12
- from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
13
-
14
-
15
-
16
-
17
- class Plot(TypedDict):
18
- """Represents a plot configuration in the DRIAS system.
19
-
20
- This class defines the structure for configuring different types of plots
21
- that can be generated from climate data.
22
-
23
- Attributes:
24
- name (str): The name of the plot type
25
- description (str): A description of what the plot shows
26
- params (list[str]): List of required parameters for the plot
27
- plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
28
- sql_query (Callable[..., str]): Function to generate the SQL query for the plot
29
- """
30
- name: str
31
- description: str
32
- params: list[str]
33
- plot_function: Callable[..., Callable[..., Figure]]
34
- sql_query: Callable[..., str]
35
-
36
 
37
  def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
38
  """Generates a function to plot indicator evolution over time at a location.
@@ -61,7 +38,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
61
  indicator = params["indicator_column"]
62
  location = params["location"]
63
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
64
- unit = INDICATOR_TO_UNIT.get(indicator, "")
65
 
66
  def plot_data(df: pd.DataFrame) -> Figure:
67
  """Generates the actual plot from the data.
@@ -184,7 +161,7 @@ def plot_indicator_number_of_days_per_year_at_location(
184
  indicator = params["indicator_column"]
185
  location = params["location"]
186
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
187
- unit = INDICATOR_TO_UNIT.get(indicator, "")
188
 
189
  def plot_data(df: pd.DataFrame) -> Figure:
190
  """Generate the figure thanks to the dataframe
@@ -266,7 +243,7 @@ def plot_distribution_of_indicator_for_given_year(
266
  indicator = params["indicator_column"]
267
  year = params["year"]
268
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
269
- unit = INDICATOR_TO_UNIT.get(indicator, "")
270
 
271
  def plot_data(df: pd.DataFrame) -> Figure:
272
  """Generate the figure thanks to the dataframe
@@ -347,7 +324,7 @@ def plot_map_of_france_of_indicator_for_given_year(
347
  indicator = params["indicator_column"]
348
  year = params["year"]
349
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
350
- unit = INDICATOR_TO_UNIT.get(indicator, "")
351
 
352
  def plot_data(df: pd.DataFrame) -> Figure:
353
  fig = go.Figure()
@@ -409,10 +386,9 @@ map_of_france_of_indicator_for_given_year: Plot = {
409
  "sql_query": indicator_for_given_year_query,
410
  }
411
 
412
-
413
- PLOTS = [
414
  indicator_evolution_at_location,
415
  indicator_number_of_days_per_year_at_location,
416
  distribution_of_indicator_for_given_year,
417
  map_of_france_of_indicator_for_given_year,
418
- ]
 
1
+ import os
2
+
3
+ from typing import Callable
4
  import pandas as pd
5
  from plotly.graph_objects import Figure
6
  import plotly.graph_objects as go
7
+ from climateqa.engine.talk_to_data.objects.plot import Plot
8
+ from climateqa.engine.talk_to_data.drias.queries import (
 
9
  indicator_for_given_year_query,
10
  indicator_per_year_at_location_query,
11
  )
12
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
15
  """Generates a function to plot indicator evolution over time at a location.
 
38
  indicator = params["indicator_column"]
39
  location = params["location"]
40
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
41
+ unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
42
 
43
  def plot_data(df: pd.DataFrame) -> Figure:
44
  """Generates the actual plot from the data.
 
161
  indicator = params["indicator_column"]
162
  location = params["location"]
163
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
164
+ unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
165
 
166
  def plot_data(df: pd.DataFrame) -> Figure:
167
  """Generate the figure thanks to the dataframe
 
243
  indicator = params["indicator_column"]
244
  year = params["year"]
245
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
246
+ unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
247
 
248
  def plot_data(df: pd.DataFrame) -> Figure:
249
  """Generate the figure thanks to the dataframe
 
324
  indicator = params["indicator_column"]
325
  year = params["year"]
326
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
327
+ unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
328
 
329
  def plot_data(df: pd.DataFrame) -> Figure:
330
  fig = go.Figure()
 
386
  "sql_query": indicator_for_given_year_query,
387
  }
388
 
389
+ DRIAS_PLOTS = [
 
390
  indicator_evolution_at_location,
391
  indicator_number_of_days_per_year_at_location,
392
  distribution_of_indicator_for_given_year,
393
  map_of_france_of_indicator_for_given_year,
394
+ ]
climateqa/engine/talk_to_data/{sql_query.py β†’ drias/queries.py} RENAMED
@@ -1,37 +1,4 @@
1
- import asyncio
2
- from concurrent.futures import ThreadPoolExecutor
3
  from typing import TypedDict
4
- import duckdb
5
- import pandas as pd
6
-
7
- async def execute_sql_query(sql_query: str) -> pd.DataFrame:
8
- """Executes a SQL query on the DRIAS database and returns the results.
9
-
10
- This function connects to the DuckDB database containing DRIAS climate data
11
- and executes the provided SQL query. It handles the database connection and
12
- returns the results as a pandas DataFrame.
13
-
14
- Args:
15
- sql_query (str): The SQL query to execute
16
-
17
- Returns:
18
- pd.DataFrame: A DataFrame containing the query results
19
-
20
- Raises:
21
- duckdb.Error: If there is an error executing the SQL query
22
- """
23
- def _execute_query():
24
- # Execute the query
25
- con = duckdb.connect()
26
- results = con.sql(sql_query).fetchdf()
27
- # return fetched data
28
- return results
29
-
30
- # Run the query in a thread pool to avoid blocking
31
- loop = asyncio.get_event_loop()
32
- with ThreadPoolExecutor() as executor:
33
- return await loop.run_in_executor(executor, _execute_query)
34
-
35
 
36
  class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
37
  """Parameters for querying an indicator's values over time at a location.
@@ -50,7 +17,6 @@ class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
50
  longitude: str
51
  model: str
52
 
53
-
54
  def indicator_per_year_at_location_query(
55
  table: str, params: IndicatorPerYearAtLocationQueryParams
56
  ) -> str:
 
 
 
1
  from typing import TypedDict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
4
  """Parameters for querying an indicator's values over time at a location.
 
17
  longitude: str
18
  model: str
19
 
 
20
  def indicator_per_year_at_location_query(
21
  table: str, params: IndicatorPerYearAtLocationQueryParams
22
  ) -> str:
climateqa/engine/talk_to_data/{utils.py β†’ input_processing.py} RENAMED
@@ -1,13 +1,14 @@
1
- import re
2
- from typing import Annotated, TypedDict
3
- import duckdb
4
- from geopy.geocoders import Nominatim
5
  import ast
6
- from climateqa.engine.llm import get_llm
7
- from climateqa.engine.talk_to_data.config import DRIAS_TABLES
8
- from climateqa.engine.talk_to_data.plot import PLOTS, Plot
9
  from langchain_core.prompts import ChatPromptTemplate
 
 
 
10
 
 
 
 
 
11
 
12
  async def detect_location_with_openai(sentence):
13
  """
@@ -29,63 +30,7 @@ async def detect_location_with_openai(sentence):
29
  else:
30
  return ""
31
 
32
- class ArrayOutput(TypedDict):
33
- """Represents the output of a function that returns an array.
34
-
35
- This class is used to type-hint functions that return arrays,
36
- ensuring consistent return types across the codebase.
37
-
38
- Attributes:
39
- array (str): A syntactically valid Python array string
40
- """
41
- array: Annotated[str, "Syntactically valid python array."]
42
-
43
- async def detect_year_with_openai(sentence: str) -> str:
44
- """
45
- Detects years in a sentence using OpenAI's API via LangChain.
46
- """
47
- llm = get_llm()
48
-
49
- prompt = """
50
- Extract all years mentioned in the following sentence.
51
- Return the result as a Python list. If no year are mentioned, return an empty list.
52
-
53
- Sentence: "{sentence}"
54
- """
55
-
56
- prompt = ChatPromptTemplate.from_template(prompt)
57
- structured_llm = llm.with_structured_output(ArrayOutput)
58
- chain = prompt | structured_llm
59
- response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
60
- years_list = eval(response['array'])
61
- if len(years_list) > 0:
62
- return years_list[0]
63
- else:
64
- return ""
65
-
66
-
67
- def detectTable(sql_query: str) -> list[str]:
68
- """Extracts table names from a SQL query.
69
-
70
- This function uses regular expressions to find all table names
71
- referenced in a SQL query's FROM clause.
72
-
73
- Args:
74
- sql_query (str): The SQL query to analyze
75
-
76
- Returns:
77
- list[str]: A list of table names found in the query
78
-
79
- Example:
80
- >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
81
- ['temperature_data']
82
- """
83
- pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
84
- matches = re.findall(pattern, sql_query)
85
- return matches
86
-
87
-
88
- def loc2coords(location: str) -> tuple[float, float]:
89
  """Converts a location name to geographic coordinates.
90
 
91
  This function uses the Nominatim geocoding service to convert
@@ -105,32 +50,7 @@ def loc2coords(location: str) -> tuple[float, float]:
105
  return (coords.latitude, coords.longitude)
106
 
107
 
108
- def coords2loc(coords: tuple[float, float]) -> str:
109
- """Converts geographic coordinates to a location name.
110
-
111
- This function uses the Nominatim reverse geocoding service to convert
112
- latitude and longitude coordinates to a human-readable location name.
113
-
114
- Args:
115
- coords (tuple[float, float]): A tuple containing (latitude, longitude)
116
-
117
- Returns:
118
- str: The address of the location, or "Unknown Location" if not found
119
-
120
- Example:
121
- >>> coords2loc((48.8566, 2.3522))
122
- 'Paris, France'
123
- """
124
- geolocator = Nominatim(user_agent="coords_to_city")
125
- try:
126
- location = geolocator.reverse(coords)
127
- return location.address
128
- except Exception as e:
129
- print(f"Error: {e}")
130
- return "Unknown Location"
131
-
132
-
133
- def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
134
  long = round(location[1], 3)
135
  lat = round(location[0], 3)
136
 
@@ -145,8 +65,31 @@ def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
145
  # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
146
  return results['latitude'].iloc[0], results['longitude'].iloc[0]
147
 
 
 
 
 
 
148
 
149
- async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  """Identifies relevant tables for a plot based on user input.
151
 
152
  This function uses an LLM to analyze the user's question and the plot
@@ -170,7 +113,6 @@ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[st
170
  ['mean_annual_temperature', 'mean_summer_temperature']
171
  """
172
  # Get all table names
173
- table_names_list = DRIAS_TABLES
174
 
175
  prompt = (
176
  f"You are helping to build a plot following this description : {plot['description']}."
@@ -187,19 +129,9 @@ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[st
187
  )
188
  return table_names
189
 
190
-
191
- def replace_coordonates(coords, query, coords_tables):
192
- n = query.count(str(coords[0]))
193
-
194
- for i in range(n):
195
- query = query.replace(str(coords[0]), str(coords_tables[i][0]), 1)
196
- query = query.replace(str(coords[1]), str(coords_tables[i][1]), 1)
197
- return query
198
-
199
-
200
- async def detect_relevant_plots(user_question: str, llm):
201
  plots_description = ""
202
- for plot in PLOTS:
203
  plots_description += "Name: " + plot["name"]
204
  plots_description += " - Description: " + plot["description"] + "\n"
205
 
@@ -227,55 +159,60 @@ async def detect_relevant_plots(user_question: str, llm):
227
  )
228
  return plot_names
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
- # Next Version
232
- # class QueryOutput(TypedDict):
233
- # """Generated SQL query."""
234
-
235
- # query: Annotated[str, ..., "Syntactically valid SQL query."]
236
-
237
-
238
- # class PlotlyCodeOutput(TypedDict):
239
- # """Generated Plotly code"""
240
-
241
- # code: Annotated[str, ..., "Synatically valid Plotly python code."]
242
- # def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
243
- # """Generate SQL query to fetch information."""
244
- # prompt_params = {
245
- # "dialect": db.dialect,
246
- # "table_info": db.get_table_info(),
247
- # "input": user_input,
248
- # "relevant_tables": relevant_tables,
249
- # "model": "ALADIN63_CNRM-CM5",
250
- # }
251
-
252
- # prompt = ChatPromptTemplate.from_template(query_prompt_template)
253
- # structured_llm = llm.with_structured_output(QueryOutput)
254
- # chain = prompt | structured_llm
255
- # result = chain.invoke(prompt_params)
256
-
257
- # return result["query"]
258
-
259
-
260
- # def fetch_data_from_sql_query(db: str, sql_query: str):
261
- # conn = sqlite3.connect(db)
262
- # cursor = conn.cursor()
263
- # cursor.execute(sql_query)
264
- # column_names = [desc[0] for desc in cursor.description]
265
- # values = cursor.fetchall()
266
- # return {"column_names": column_names, "data": values}
267
-
268
 
269
- # def generate_chart_code(user_input: str, sql_query: list[str], llm):
270
- # """ "Generate plotly python code for the chart based on the sql query and the user question"""
 
 
271
 
272
- # class PlotlyCodeOutput(TypedDict):
273
- # """Generated Plotly code"""
274
 
275
- # code: Annotated[str, ..., "Synatically valid Plotly python code."]
 
 
 
276
 
277
- # prompt = ChatPromptTemplate.from_template(plot_prompt_template)
278
- # structured_llm = llm.with_structured_output(PlotlyCodeOutput)
279
- # chain = prompt | structured_llm
280
- # result = chain.invoke({"input": user_input, "sql_query": sql_query})
281
- # return result["code"]
 
 
 
 
 
 
1
+ from typing import Any
 
 
 
2
  import ast
 
 
 
3
  from langchain_core.prompts import ChatPromptTemplate
4
+ from geopy.geocoders import Nominatim
5
+ from climateqa.engine.llm import get_llm
6
+ import duckdb
7
 
8
+ from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
9
+ from climateqa.engine.talk_to_data.objects.location import Location
10
+ from climateqa.engine.talk_to_data.objects.plot import Plot
11
+ from climateqa.engine.talk_to_data.objects.states import State
12
 
13
  async def detect_location_with_openai(sentence):
14
  """
 
30
  else:
31
  return ""
32
 
33
+ def loc_to_coords(location: str) -> tuple[float, float]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  """Converts a location name to geographic coordinates.
35
 
36
  This function uses the Nominatim geocoding service to convert
 
50
  return (coords.latitude, coords.longitude)
51
 
52
 
53
+ def nearest_neighbour_sql(location: tuple, table: str) -> tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  long = round(location[1], 3)
55
  lat = round(location[0], 3)
56
 
 
65
  # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
66
  return results['latitude'].iloc[0], results['longitude'].iloc[0]
67
 
68
+ async def detect_year_with_openai(sentence: str) -> str:
69
+ """
70
+ Detects years in a sentence using OpenAI's API via LangChain.
71
+ """
72
+ llm = get_llm()
73
 
74
+ prompt = """
75
+ Extract all years mentioned in the following sentence.
76
+ Return the result as a Python list. If no year are mentioned, return an empty list.
77
+
78
+ Sentence: "{sentence}"
79
+ """
80
+
81
+ prompt = ChatPromptTemplate.from_template(prompt)
82
+ structured_llm = llm.with_structured_output(ArrayOutput)
83
+ chain = prompt | structured_llm
84
+ response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
85
+ years_list = eval(response['array'])
86
+ if len(years_list) > 0:
87
+ return years_list[0]
88
+ else:
89
+ return ""
90
+
91
+
92
+ async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
93
  """Identifies relevant tables for a plot based on user input.
94
 
95
  This function uses an LLM to analyze the user's question and the plot
 
113
  ['mean_annual_temperature', 'mean_summer_temperature']
114
  """
115
  # Get all table names
 
116
 
117
  prompt = (
118
  f"You are helping to build a plot following this description : {plot['description']}."
 
129
  )
130
  return table_names
131
 
132
+ async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
 
 
 
 
 
 
 
 
 
 
133
  plots_description = ""
134
+ for plot in plot_list:
135
  plots_description += "Name: " + plot["name"]
136
  plots_description += " - Description: " + plot["description"] + "\n"
137
 
 
159
  )
160
  return plot_names
161
 
162
+ async def find_location(user_input: str, table: str) -> Location:
163
+ print(f"---- Find location in table {table} ----")
164
+ location = await detect_location_with_openai(user_input)
165
+ output: Location = {'location' : location}
166
+ if location:
167
+ coords = loc_to_coords(location)
168
+ neighbour = nearest_neighbour_sql(coords, table)
169
+ output.update({
170
+ "latitude": neighbour[0],
171
+ "longitude": neighbour[1],
172
+ })
173
+ return output
174
+
175
+ async def find_year(user_input: str) -> str:
176
+ """Extracts year information from user input using LLM.
177
+
178
+ This function uses an LLM to identify and extract year information from the
179
+ user's query, which is used to filter data in subsequent queries.
180
+
181
+ Args:
182
+ user_input (str): The user's query text
183
+
184
+ Returns:
185
+ str: The extracted year, or empty string if no year found
186
+ """
187
+ print(f"---- Find year ---")
188
+ year = await detect_year_with_openai(user_input)
189
+ return year
190
 
191
+ async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
192
+ print("---- Find relevant plots ----")
193
+ relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
194
+ return relevant_plots
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
+ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
197
+ print(f"---- Find relevant tables for {plot['name']} ----")
198
+ relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
199
+ return relevant_tables
200
 
201
+ async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
202
+ """Perform the good method to retrieve the desired parameter
203
 
204
+ Args:
205
+ state (State): state of the workflow
206
+ param_name (str): name of the desired parameter
207
+ table (str): name of the table
208
 
209
+ Returns:
210
+ dict[str, Any] | None:
211
+ """
212
+ if param_name == 'location':
213
+ location = await find_location(state['user_input'], table)
214
+ return location
215
+ if param_name == 'year':
216
+ year = await find_year(state['user_input'])
217
+ return {'year': year}
218
+ return None
climateqa/engine/talk_to_data/main.py CHANGED
@@ -1,43 +1,8 @@
1
- from climateqa.engine.talk_to_data.talk_to_drias import drias_workflow
2
  from climateqa.engine.llm import get_llm
3
  from climateqa.logging import log_drias_interaction_to_huggingface
4
  import ast
5
 
6
- llm = get_llm(provider="openai")
7
-
8
- def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
9
- """Adds table names to the SQL query result rows using LLM.
10
-
11
- This function modifies the SQL query to include the source table name in each row
12
- of the result set, making it easier to track which data comes from which table.
13
-
14
- Args:
15
- sql_query (str): The original SQL query to modify
16
- llm: The language model instance to use for generating the modified query
17
-
18
- Returns:
19
- str: The modified SQL query with table names included in the result rows
20
- """
21
- sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
22
- return sql_with_table_names
23
-
24
- def ask_llm_column_names(sql_query: str, llm) -> list[str]:
25
- """Extracts column names from a SQL query using LLM.
26
-
27
- This function analyzes a SQL query to identify which columns are being selected
28
- in the result set.
29
-
30
- Args:
31
- sql_query (str): The SQL query to analyze
32
- llm: The language model instance to use for column extraction
33
-
34
- Returns:
35
- list[str]: A list of column names being selected in the query
36
- """
37
- columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
38
- columns_list = ast.literal_eval(columns.strip("```python\n").strip())
39
- return columns_list
40
-
41
  async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tuple:
42
  """Main function to process a DRIAS query and return results.
43
 
@@ -85,34 +50,8 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
85
 
86
  sql_query = sql_queries[index_state]
87
  dataframe = result_dataframes[index_state]
88
- figure = figures[index_state](dataframe)
89
 
90
  log_drias_interaction_to_huggingface(query, sql_query, user_id)
91
 
92
- return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
93
-
94
- # def ask_vanna(vn,db_vanna_path, query):
95
-
96
- # try :
97
- # location = detect_location_with_openai(query)
98
- # if location:
99
-
100
- # coords = loc2coords(location)
101
- # user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
102
-
103
- # relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
104
- # coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
105
- # user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
106
-
107
- # sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
108
-
109
- # return sql_query, result_dataframe, figure
110
- # else :
111
- # empty_df = pd.DataFrame()
112
- # empty_fig = None
113
- # return "", empty_df, empty_fig
114
- # except Exception as e:
115
- # print(f"Error: {e}")
116
- # empty_df = pd.DataFrame()
117
- # empty_fig = None
118
- # return "", empty_df, empty_fig
 
1
+ from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
2
  from climateqa.engine.llm import get_llm
3
  from climateqa.logging import log_drias_interaction_to_huggingface
4
  import ast
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tuple:
7
  """Main function to process a DRIAS query and return results.
8
 
 
50
 
51
  sql_query = sql_queries[index_state]
52
  dataframe = result_dataframes[index_state]
53
+ figure = figures[index_state](dataframe)
54
 
55
  log_drias_interaction_to_huggingface(query, sql_query, user_id)
56
 
57
+ return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/objects/llm_outputs.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, TypedDict
2
+
3
+
4
+ class ArrayOutput(TypedDict):
5
+ """Represents the output of a function that returns an array.
6
+
7
+ This class is used to type-hint functions that return arrays,
8
+ ensuring consistent return types across the codebase.
9
+
10
+ Attributes:
11
+ array (str): A syntactically valid Python array string
12
+ """
13
+ array: Annotated[str, "Syntactically valid python array."]
climateqa/engine/talk_to_data/objects/location.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from typing import Optional, TypedDict
2
+
3
+
4
+ class Location(TypedDict):
5
+ location: str
6
+ latitude: Optional[str]
7
+ longitude: Optional[str]
climateqa/engine/talk_to_data/objects/plot.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, TypedDict
2
+ from plotly.graph_objects import Figure
3
+
4
+ class Plot(TypedDict):
5
+ """Represents a plot configuration in the DRIAS system.
6
+
7
+ This class defines the structure for configuring different types of plots
8
+ that can be generated from climate data.
9
+
10
+ Attributes:
11
+ name (str): The name of the plot type
12
+ description (str): A description of what the plot shows
13
+ params (list[str]): List of required parameters for the plot
14
+ plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
15
+ sql_query (Callable[..., str]): Function to generate the SQL query for the plot
16
+ """
17
+ name: str
18
+ description: str
19
+ params: list[str]
20
+ plot_function: Callable[..., Callable[..., Figure]]
21
+ sql_query: Callable[..., str]
climateqa/engine/talk_to_data/objects/states.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Optional, TypedDict
2
+ from plotly.graph_objects import Figure
3
+ import pandas as pd
4
+
5
+ class TableState(TypedDict):
6
+ """Represents the state of a table in the DRIAS workflow.
7
+
8
+ This class defines the structure for tracking the state of a table during the
9
+ data processing workflow, including its name, parameters, SQL query, and results.
10
+
11
+ Attributes:
12
+ table_name (str): The name of the table in the database
13
+ params (dict[str, Any]): Parameters used for querying the table
14
+ sql_query (str, optional): The SQL query used to fetch data
15
+ dataframe (pd.DataFrame | None, optional): The resulting data
16
+ figure (Callable[..., Figure], optional): Function to generate visualization
17
+ status (str): The current status of the table processing ('OK' or 'ERROR')
18
+ """
19
+ table_name: str
20
+ params: dict[str, Any]
21
+ sql_query: Optional[str]
22
+ dataframe: Optional[pd.DataFrame | None]
23
+ figure: Optional[Callable[..., Figure]]
24
+ status: str
25
+
26
+ class PlotState(TypedDict):
27
+ """Represents the state of a plot in the DRIAS workflow.
28
+
29
+ This class defines the structure for tracking the state of a plot during the
30
+ data processing workflow, including its name and associated tables.
31
+
32
+ Attributes:
33
+ plot_name (str): The name of the plot
34
+ tables (list[str]): List of tables used in the plot
35
+ table_states (dict[str, TableState]): States of the tables used in the plot
36
+ """
37
+ plot_name: str
38
+ tables: list[str]
39
+ table_states: dict[str, TableState]
40
+
41
+ class State(TypedDict):
42
+ user_input: str
43
+ plots: list[str]
44
+ plot_states: dict[str, PlotState]
45
+ error: Optional[str]
46
+
climateqa/engine/talk_to_data/query.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import duckdb
4
+ import pandas as pd
5
+
6
+
7
+ def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
8
+ """Retrieves the name of the indicator column within a table.
9
+
10
+ This function maps table names to their corresponding indicator columns
11
+ using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.
12
+
13
+ Args:
14
+ table (str): Name of the table in the database
15
+
16
+ Returns:
17
+ str: Name of the indicator column for the specified table
18
+
19
+ Raises:
20
+ KeyError: If the table name is not found in the mapping
21
+ """
22
+ print(f"---- Find indicator column in table {table} ----")
23
+ return indicator_columns_per_table[table]
24
+
25
+ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
26
+ """Executes a SQL query on the DRIAS database and returns the results.
27
+
28
+ This function connects to the DuckDB database containing DRIAS climate data
29
+ and executes the provided SQL query. It handles the database connection and
30
+ returns the results as a pandas DataFrame.
31
+
32
+ Args:
33
+ sql_query (str): The SQL query to execute
34
+
35
+ Returns:
36
+ pd.DataFrame: A DataFrame containing the query results
37
+
38
+ Raises:
39
+ duckdb.Error: If there is an error executing the SQL query
40
+ """
41
+ def _execute_query():
42
+ # Execute the query
43
+ con = duckdb.connect()
44
+ results = con.sql(sql_query).fetchdf()
45
+ # return fetched data
46
+ return results
47
+
48
+ # Run the query in a thread pool to avoid blocking
49
+ loop = asyncio.get_event_loop()
50
+ with ThreadPoolExecutor() as executor:
51
+ return await loop.run_in_executor(executor, _execute_query)
52
+
climateqa/engine/talk_to_data/{talk_to_drias.py β†’ workflow/drias.py} RENAMED
@@ -1,151 +1,17 @@
1
  import os
2
 
3
- from typing import Any, Callable, TypedDict, Optional
4
- from numpy import sort
5
- import pandas as pd
6
  import asyncio
7
- from plotly.graph_objects import Figure
8
  from climateqa.engine.llm import get_llm
9
- from climateqa.engine.talk_to_data import sql_query
10
- from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
11
- from climateqa.engine.talk_to_data.plot import PLOTS, Plot
12
- from climateqa.engine.talk_to_data.sql_query import execute_sql_query
13
- from climateqa.engine.talk_to_data.utils import (
14
- detect_relevant_plots,
15
- detect_year_with_openai,
16
- loc2coords,
17
- detect_location_with_openai,
18
- nearestNeighbourSQL,
19
- detect_relevant_tables,
20
- )
21
-
22
 
23
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
24
 
25
- class TableState(TypedDict):
26
- """Represents the state of a table in the DRIAS workflow.
27
-
28
- This class defines the structure for tracking the state of a table during the
29
- data processing workflow, including its name, parameters, SQL query, and results.
30
-
31
- Attributes:
32
- table_name (str): The name of the table in the database
33
- params (dict[str, Any]): Parameters used for querying the table
34
- sql_query (str, optional): The SQL query used to fetch data
35
- dataframe (pd.DataFrame | None, optional): The resulting data
36
- figure (Callable[..., Figure], optional): Function to generate visualization
37
- status (str): The current status of the table processing ('OK' or 'ERROR')
38
- """
39
- table_name: str
40
- params: dict[str, Any]
41
- sql_query: Optional[str]
42
- dataframe: Optional[pd.DataFrame | None]
43
- figure: Optional[Callable[..., Figure]]
44
- status: str
45
-
46
- class PlotState(TypedDict):
47
- """Represents the state of a plot in the DRIAS workflow.
48
-
49
- This class defines the structure for tracking the state of a plot during the
50
- data processing workflow, including its name and associated tables.
51
-
52
- Attributes:
53
- plot_name (str): The name of the plot
54
- tables (list[str]): List of tables used in the plot
55
- table_states (dict[str, TableState]): States of the tables used in the plot
56
- """
57
- plot_name: str
58
- tables: list[str]
59
- table_states: dict[str, TableState]
60
-
61
- class State(TypedDict):
62
- user_input: str
63
- plots: list[str]
64
- plot_states: dict[str, PlotState]
65
- error: Optional[str]
66
-
67
- async def find_relevant_plots(state: State, llm) -> list[str]:
68
- print("---- Find relevant plots ----")
69
- relevant_plots = await detect_relevant_plots(state['user_input'], llm)
70
- return relevant_plots
71
-
72
- async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
73
- print(f"---- Find relevant tables for {plot['name']} ----")
74
- relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
75
- return relevant_tables
76
-
77
- async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
78
- """Perform the good method to retrieve the desired parameter
79
-
80
- Args:
81
- state (State): state of the workflow
82
- param_name (str): name of the desired parameter
83
- table (str): name of the table
84
-
85
- Returns:
86
- dict[str, Any] | None:
87
- """
88
- if param_name == 'location':
89
- location = await find_location(state['user_input'], table)
90
- return location
91
- if param_name == 'year':
92
- year = await find_year(state['user_input'])
93
- return {'year': year}
94
- return None
95
-
96
- class Location(TypedDict):
97
- location: str
98
- latitude: Optional[str]
99
- longitude: Optional[str]
100
-
101
- async def find_location(user_input: str, table: str) -> Location:
102
- print(f"---- Find location in table {table} ----")
103
- location = await detect_location_with_openai(user_input)
104
- output: Location = {'location' : location}
105
- if location:
106
- coords = loc2coords(location)
107
- neighbour = nearestNeighbourSQL(coords, table)
108
- output.update({
109
- "latitude": neighbour[0],
110
- "longitude": neighbour[1],
111
- })
112
- return output
113
-
114
- async def find_year(user_input: str) -> str:
115
- """Extracts year information from user input using LLM.
116
-
117
- This function uses an LLM to identify and extract year information from the
118
- user's query, which is used to filter data in subsequent queries.
119
-
120
- Args:
121
- user_input (str): The user's query text
122
-
123
- Returns:
124
- str: The extracted year, or empty string if no year found
125
- """
126
- print(f"---- Find year ---")
127
- year = await detect_year_with_openai(user_input)
128
- return year
129
-
130
- def find_indicator_column(table: str) -> str:
131
- """Retrieves the name of the indicator column within a table.
132
-
133
- This function maps table names to their corresponding indicator columns
134
- using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.
135
-
136
- Args:
137
- table (str): Name of the table in the database
138
-
139
- Returns:
140
- str: Name of the indicator column for the specified table
141
-
142
- Raises:
143
- KeyError: If the table name is not found in the mapping
144
- """
145
- print(f"---- Find indicator column in table {table} ----")
146
- return INDICATOR_COLUMNS_PER_TABLE[table]
147
-
148
-
149
  async def process_table(
150
  table: str,
151
  params: dict[str, Any],
@@ -173,7 +39,7 @@ async def process_table(
173
  'figure': None
174
  }
175
 
176
- table_state['params']['indicator_column'] = find_indicator_column(table)
177
  sql_query = plot['sql_query'](table, table_state['params'])
178
 
179
  if sql_query == "":
@@ -187,6 +53,7 @@ async def process_table(
187
 
188
  return table_state
189
 
 
190
  async def drias_workflow(user_input: str) -> State:
191
  """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
192
 
@@ -205,7 +72,7 @@ async def drias_workflow(user_input: str) -> State:
205
 
206
  llm = get_llm(provider="openai")
207
 
208
- plots = await find_relevant_plots(state, llm)
209
 
210
  state['plots'] = plots
211
 
@@ -219,7 +86,7 @@ async def drias_workflow(user_input: str) -> State:
219
 
220
  for plot_name in state['plots']:
221
 
222
- plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
223
  if plot is None:
224
  continue
225
 
@@ -231,7 +98,7 @@ async def drias_workflow(user_input: str) -> State:
231
 
232
  plot_state['plot_name'] = plot_name
233
 
234
- relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
235
 
236
  if len(relevant_tables) > 0 :
237
  have_relevant_table = True
@@ -267,51 +134,3 @@ async def drias_workflow(user_input: str) -> State:
267
  state['error'] = "There is no data in our table that can answer to your question"
268
 
269
  return state
270
-
271
- # def make_write_query_node():
272
-
273
- # def write_query(state):
274
- # print("---- Write query ----")
275
- # for table in state["tables"]:
276
- # sql_query = QUERIES[state[table]['query_type']](
277
- # table=table,
278
- # indicator_column=state[table]["columns"],
279
- # longitude=state[table]["longitude"],
280
- # latitude=state[table]["latitude"],
281
- # )
282
- # state[table].update({"sql_query": sql_query})
283
-
284
- # return state
285
-
286
- # return write_query
287
-
288
- # def make_fetch_data_node(db_path):
289
-
290
- # def fetch_data(state):
291
- # print("---- Fetch data ----")
292
- # for table in state["tables"]:
293
- # results = execute_sql_query(db_path, state[table]['sql_query'])
294
- # state[table].update(results)
295
-
296
- # return state
297
-
298
- # return fetch_data
299
-
300
-
301
-
302
- ## V2
303
-
304
-
305
- # def make_fetch_data_node(db_path: str, llm):
306
- # def fetch_data(state):
307
- # print("---- Fetch data ----")
308
- # db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
309
- # output = {}
310
- # sql_query = write_sql_query(state["query"], db, state["tables"], llm)
311
- # # TO DO : Add query checker
312
- # print(f"SQL query : {sql_query}")
313
- # output["sql_query"] = sql_query
314
- # output.update(fetch_data_from_sql_query(db_path, sql_query))
315
- # return output
316
-
317
- # return fetch_data
 
1
  import os
2
 
3
+ from typing import Any
 
 
4
  import asyncio
 
5
  from climateqa.engine.llm import get_llm
6
+ from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
7
+ from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
8
+ from climateqa.engine.talk_to_data.objects.plot import Plot
9
+ from climateqa.engine.talk_to_data.objects.states import PlotState, State, TableState
10
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE
11
+ from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
 
 
 
 
 
 
 
12
 
13
  ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  async def process_table(
16
  table: str,
17
  params: dict[str, Any],
 
39
  'figure': None
40
  }
41
 
42
+ table_state['params']['indicator_column'] = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
43
  sql_query = plot['sql_query'](table, table_state['params'])
44
 
45
  if sql_query == "":
 
53
 
54
  return table_state
55
 
56
+
57
  async def drias_workflow(user_input: str) -> State:
58
  """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
59
 
 
72
 
73
  llm = get_llm(provider="openai")
74
 
75
+ plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)
76
 
77
  state['plots'] = plots
78
 
 
86
 
87
  for plot_name in state['plots']:
88
 
89
+ plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None) # Find the associated plot object
90
  if plot is None:
91
  continue
92
 
 
98
 
99
  plot_state['plot_name'] = plot_name
100
 
101
+ relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
102
 
103
  if len(relevant_tables) > 0 :
104
  have_relevant_table = True
 
134
  state['error'] = "There is no data in our table that can answer to your question"
135
 
136
  return state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
front/tabs/tab_drias.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import pandas as pd
5
 
6
  from climateqa.engine.talk_to_data.main import ask_drias
7
- from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
8
 
9
  class DriasUIElements(TypedDict):
10
  tab: gr.Tab
 
4
  import pandas as pd
5
 
6
  from climateqa.engine.talk_to_data.main import ask_drias
7
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_MODELS, DRIAS_UI_TEXT
8
 
9
  class DriasUIElements(TypedDict):
10
  tab: gr.Tab