armanddemasson commited on
Commit
c0fd277
·
1 Parent(s): c25f6b1

feat: updated common talk to data for talk to ipcc and drias

Browse files
climateqa/engine/talk_to_data/input_processing.py CHANGED
@@ -1,16 +1,18 @@
1
- from typing import Any
2
  import ast
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from geopy.geocoders import Nominatim
5
  from climateqa.engine.llm import get_llm
6
  import duckdb
7
-
 
8
  from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
9
  from climateqa.engine.talk_to_data.objects.location import Location
10
  from climateqa.engine.talk_to_data.objects.plot import Plot
11
  from climateqa.engine.talk_to_data.objects.states import State
 
12
 
13
- async def detect_location_with_openai(sentence):
14
  """
15
  Detects locations in a sentence using OpenAI's API via LangChain.
16
  """
@@ -49,21 +51,51 @@ def loc_to_coords(location: str) -> tuple[float, float]:
49
  coords = geolocator.geocode(location)
50
  return (coords.latitude, coords.longitude)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- def nearest_neighbour_sql(location: tuple, table: str) -> tuple[str, str]:
54
  long = round(location[1], 3)
55
  lat = round(location[0], 3)
 
56
 
57
- table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
58
-
59
- results = duckdb.sql(
60
- f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
61
- ).fetchdf()
 
 
 
 
 
 
62
 
63
  if len(results) == 0:
64
- return "", ""
65
- # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
66
- return results['latitude'].iloc[0], results['longitude'].iloc[0]
 
 
 
 
67
 
68
  async def detect_year_with_openai(sentence: str) -> str:
69
  """
@@ -136,43 +168,49 @@ async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot])
136
  plots_description += " - Description: " + plot["description"] + "\n"
137
 
138
  prompt = (
139
- f"You are helping to answer a quesiton with insightful visualizations."
140
- f"You are given an user question and a list of plots with their name and description."
141
- f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
142
- f"Write the most relevant tables to use. Answer only a python list of plot name."
 
 
143
  f"### Descriptions of the plots : {plots_description}"
144
- f"### User question : {user_question}"
145
- f"### Name of the plot : "
146
  )
147
- # prompt = (
148
- # f"You are helping to answer a question with insightful visualizations. "
149
- # f"Given a list of plots with their name and description: "
150
- # f"{plots_description} "
151
- # f"The user question is: {user_question}. "
152
- # f"Choose the most relevant plots to answer the question. "
153
- # f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
154
- # f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
155
- # )
156
 
157
  plot_names = ast.literal_eval(
158
  (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
159
  )
160
  return plot_names
161
 
162
- async def find_location(user_input: str, table: str) -> Location:
163
- print(f"---- Find location in table {table} ----")
164
  location = await detect_location_with_openai(user_input)
165
- output: Location = {'location' : location}
 
 
 
 
 
 
 
 
166
  if location:
167
  coords = loc_to_coords(location)
168
- neighbour = nearest_neighbour_sql(coords, table)
 
169
  output.update({
170
  "latitude": neighbour[0],
171
  "longitude": neighbour[1],
 
 
 
172
  })
 
173
  return output
174
 
175
- async def find_year(user_input: str) -> str:
176
  """Extracts year information from user input using LLM.
177
 
178
  This function uses an LLM to identify and extract year information from the
@@ -186,6 +224,8 @@ async def find_year(user_input: str) -> str:
186
  """
187
  print(f"---- Find year ---")
188
  year = await detect_year_with_openai(user_input)
 
 
189
  return year
190
 
191
  async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
@@ -198,7 +238,7 @@ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: l
198
  relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
199
  return relevant_tables
200
 
201
- async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
202
  """Perform the good method to retrieve the desired parameter
203
 
204
  Args:
@@ -210,7 +250,7 @@ async def find_param(state: State, param_name:str, table: str) -> dict[str, Any]
210
  dict[str, Any] | None:
211
  """
212
  if param_name == 'location':
213
- location = await find_location(state['user_input'], table)
214
  return location
215
  if param_name == 'year':
216
  year = await find_year(state['user_input'])
 
1
+ from typing import Any, Literal, Optional, cast
2
  import ast
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from geopy.geocoders import Nominatim
5
  from climateqa.engine.llm import get_llm
6
  import duckdb
7
+ import os
8
+ from climateqa.engine.talk_to_data.ipcc.config import IPCC_DATASET_URL
9
  from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
10
  from climateqa.engine.talk_to_data.objects.location import Location
11
  from climateqa.engine.talk_to_data.objects.plot import Plot
12
  from climateqa.engine.talk_to_data.objects.states import State
13
+ import time
14
 
15
+ async def detect_location_with_openai(sentence: str) -> str:
16
  """
17
  Detects locations in a sentence using OpenAI's API via LangChain.
18
  """
 
51
  coords = geolocator.geocode(location)
52
  return (coords.latitude, coords.longitude)
53
 
54
+ def coords_to_country(coords: tuple[float, float]) -> tuple[str,str]:
55
+ """Converts geographic coordinates to a country name.
56
+
57
+ This function uses the Nominatim reverse geocoding service to convert
58
+ latitude and longitude coordinates to a country name.
59
+
60
+ Args:
61
+ coords (tuple[float, float]): A tuple containing (latitude, longitude)
62
+
63
+ Returns:
64
+ tuple[str,str]: A tuple containg (country_code, country_name, admin1)
65
+
66
+ Raises:
67
+ AttributeError: If the coordinates cannot be found
68
+ """
69
+ geolocator = Nominatim(user_agent="latlong_to_country")
70
+ location = geolocator.reverse(coords)
71
+ address = location.raw['address']
72
+ return address['country_code'].upper(), address['country']
73
 
74
+ def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
75
  long = round(location[1], 3)
76
  lat = round(location[0], 3)
77
+ conn = duckdb.connect()
78
 
79
+ if mode == 'DRIAS':
80
+ table_path = f"'hf://datasets/timeki/drias_db/mean_annual_temperature.parquet'"
81
+ results = conn.sql(
82
+ f"SELECT latitude, longitude FROM {table_path} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
83
+ ).fetchdf()
84
+ else:
85
+ table_path = f"'{IPCC_DATASET_URL}/coordinates.parquet'"
86
+ results = conn.sql(
87
+ f"SELECT latitude, longitude, admin1 FROM {table_path} WHERE latitude BETWEEN {lat - 0.5} AND {lat + 0.5} AND longitude BETWEEN {long - 0.5} AND {long + 0.5}"
88
+ ).fetchdf()
89
+
90
 
91
  if len(results) == 0:
92
+ return "", "", ""
93
+
94
+ if 'admin1' in results.columns:
95
+ admin1 = results['admin1'].iloc[0]
96
+ else:
97
+ admin1 = None
98
+ return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
99
 
100
  async def detect_year_with_openai(sentence: str) -> str:
101
  """
 
168
  plots_description += " - Description: " + plot["description"] + "\n"
169
 
170
  prompt = (
171
+ "You are helping to answer a question with insightful visualizations.\n"
172
+ "You are given a user question and a list of plots with their name and description.\n"
173
+ "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
174
+ "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
175
+ "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
176
+ "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
177
  f"### Descriptions of the plots : {plots_description}"
178
+ f"### User question : {user_question}\n"
179
+ f"### Names of the plots : "
180
  )
 
 
 
 
 
 
 
 
 
181
 
182
  plot_names = ast.literal_eval(
183
  (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
184
  )
185
  return plot_names
186
 
187
+ async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
188
+ print(f"---- Find location in user input ----")
189
  location = await detect_location_with_openai(user_input)
190
+ output: Location = {
191
+ 'location' : location,
192
+ 'longitude' : None,
193
+ 'latitude' : None,
194
+ 'country_code' : None,
195
+ 'country_name' : None,
196
+ 'admin1' : None
197
+ }
198
+
199
  if location:
200
  coords = loc_to_coords(location)
201
+ country_code, country_name = coords_to_country(coords)
202
+ neighbour = nearest_neighbour_sql(coords, mode)
203
  output.update({
204
  "latitude": neighbour[0],
205
  "longitude": neighbour[1],
206
+ "country_code": country_code,
207
+ "country_name": country_name,
208
+ "admin1": neighbour[2]
209
  })
210
+ output = cast(Location, output)
211
  return output
212
 
213
+ async def find_year(user_input: str) -> str| None:
214
  """Extracts year information from user input using LLM.
215
 
216
  This function uses an LLM to identify and extract year information from the
 
224
  """
225
  print(f"---- Find year ---")
226
  year = await detect_year_with_openai(user_input)
227
+ if year == "":
228
+ return None
229
  return year
230
 
231
  async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
 
238
  relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
239
  return relevant_tables
240
 
241
+ async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
242
  """Perform the good method to retrieve the desired parameter
243
 
244
  Args:
 
250
  dict[str, Any] | None:
251
  """
252
  if param_name == 'location':
253
+ location = await find_location(state['user_input'], mode)
254
  return location
255
  if param_name == 'year':
256
  year = await find_year(state['user_input'])
climateqa/engine/talk_to_data/objects/location.py CHANGED
@@ -1,7 +1,12 @@
 
1
  from typing import Optional, TypedDict
2
 
3
 
 
4
  class Location(TypedDict):
5
  location: str
6
  latitude: Optional[str]
7
  longitude: Optional[str]
 
 
 
 
1
+ from token import OP
2
  from typing import Optional, TypedDict
3
 
4
 
5
+
6
  class Location(TypedDict):
7
  location: str
8
  latitude: Optional[str]
9
  longitude: Optional[str]
10
+ country_code: Optional[str]
11
+ country_name: Optional[str]
12
+ admin1: Optional[str]
climateqa/engine/talk_to_data/objects/plot.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Callable, TypedDict
2
  from plotly.graph_objects import Figure
3
 
4
  class Plot(TypedDict):
@@ -18,4 +18,5 @@ class Plot(TypedDict):
18
  description: str
19
  params: list[str]
20
  plot_function: Callable[..., Callable[..., Figure]]
21
- sql_query: Callable[..., str]
 
 
1
+ from typing import Callable, TypedDict, Optional
2
  from plotly.graph_objects import Figure
3
 
4
  class Plot(TypedDict):
 
18
  description: str
19
  params: list[str]
20
  plot_function: Callable[..., Callable[..., Figure]]
21
+ sql_query: Callable[..., str]
22
+ short_name: str
climateqa/engine/talk_to_data/prompt.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ query_prompt_template = """You are an expert SQL query generator. Given an input question, database schema, SQL dialect and relevant tables to answer the question, generate an optimized and syntactically correct SQL query which can provide useful insights to the question.
2
+
3
+ ### Instructions:
4
+ 1. **Use only relevant tables**: The following tables are relevant to answering the question: {relevant_tables}. Do not use any other tables.
5
+ 2. **Relevant columns only**: Never select `*`. Only include necessary columns based on the input question.
6
+ 3. **Schema Awareness**:
7
+ - Use only columns present in the given schema.
8
+ - **If a column name appears in multiple tables, always use the format `table_name.column_name` to avoid ambiguity.**
9
+ - Select only the column which are insightful for the question.
10
+ 4. **Dialect Compliance**: Follow `{dialect}` syntax rules.
11
+ 5. **Ordering**: Order the results by a relevant column if applicable (e.g., timestamp for recent records).
12
+ 6. **Valid query**: Make sure the query is syntactically and functionally correct.
13
+ 7. **Conditions** : For the common columns, the same condition should be applied to all the tables (e.g. latitude, longitude, model, year...)
14
+ 9. **Join tables** : If you need to join table, you should join them with year feature.
15
+ 10. **Model** : For each table, you need to add a condition on the model to be equal to {model}
16
+
17
+ ### Provided Database Schema:
18
+ {table_info}
19
+
20
+ ### Relevant Tables:
21
+ {relevant_tables}
22
+
23
+ **Question:** {input}
24
+
25
+ **SQL Query:**"""
26
+
27
+ plot_prompt_template = """You are a data visualization expert. Given an input question and an SQL Query, generate an insightful plot according to the question.
28
+
29
+ ### Instructions
30
+ 1. **Use only the column names provided**. The data will be provided as a Pandas DataFrame `df` with the columns present in the SELECT.
31
+ 2. Generate the Python Plotly code to chart the results using `df` and the column names.
32
+ 3. Make as complete a graph as possible to answer the question, and make it as easy to understand as possible.
33
+ 4. **Response with only Python code**. Do not answer with any explanations -- just the code.
34
+ 5. **Specific cases** :
35
+ - For a question about the evolution of something, it is also relevant to plot the data with also the sliding average for a period of 20 years for example.
36
+
37
+ ### SQL Query:
38
+ {sql_query}
39
+
40
+ **Question:** {input}
41
+
42
+ **Python code:**
43
+ """
44
+
climateqa/engine/talk_to_data/query.py CHANGED
@@ -2,7 +2,7 @@ import asyncio
2
  from concurrent.futures import ThreadPoolExecutor
3
  import duckdb
4
  import pandas as pd
5
-
6
 
7
  def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
8
  """Retrieves the name of the indicator column within a table.
@@ -41,7 +41,12 @@ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
41
  def _execute_query():
42
  # Execute the query
43
  con = duckdb.connect()
44
- results = con.sql(sql_query).fetchdf()
 
 
 
 
 
45
  # return fetched data
46
  return results
47
 
 
2
  from concurrent.futures import ThreadPoolExecutor
3
  import duckdb
4
  import pandas as pd
5
+ import os
6
 
7
  def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
8
  """Retrieves the name of the indicator column within a table.
 
41
  def _execute_query():
42
  # Execute the query
43
  con = duckdb.connect()
44
+ HF_TOKEN = os.getenv("HF_TOKEN")
45
+ con.execute(f"""CREATE SECRET hf_token (
46
+ TYPE huggingface,
47
+ TOKEN '{HF_TOKEN}'
48
+ );""")
49
+ results = con.execute(sql_query).fetchdf()
50
  # return fetched data
51
  return results
52
 
style.css CHANGED
@@ -656,12 +656,11 @@ a {
656
  /* overflow-y: scroll; */
657
  }
658
  #sql-query{
659
- max-height: 300px;
660
- overflow-y:scroll;
661
  }
662
 
663
  #sql-query textarea{
664
- min-height: 100px !important;
665
  }
666
 
667
  #sql-query span{
@@ -671,8 +670,11 @@ div#tab-vanna{
671
  max-height: 100¨vh;
672
  overflow-y: hidden;
673
  }
 
 
 
674
  #vanna-plot{
675
- max-height:500px
676
  }
677
 
678
  #pagination-display{
@@ -681,20 +683,40 @@ div#tab-vanna{
681
  font-size: 16px;
682
  }
683
 
684
- #table-names table{
685
- overflow: hidden;
686
  }
687
- #table-names thead{
688
- display: none;
 
689
  }
690
 
691
- #table-names tr{
692
- cursor:pointer
 
 
 
 
 
 
 
 
 
693
  }
694
- #table-names tr:hover{
 
695
  background-color: #f0f8ff;
696
  }
697
 
 
 
 
 
 
 
 
 
 
698
  /* DRIAS Data Table Styles */
699
  #vanna-table {
700
  height: 400px !important;
@@ -717,3 +739,13 @@ div#tab-vanna{
717
  background: white;
718
  z-index: 1;
719
  }
 
 
 
 
 
 
 
 
 
 
 
656
  /* overflow-y: scroll; */
657
  }
658
  #sql-query{
659
+ max-height: 100%;
 
660
  }
661
 
662
  #sql-query textarea{
663
+ min-height: 200px !important;
664
  }
665
 
666
  #sql-query span{
 
670
  max-height: 100¨vh;
671
  overflow-y: hidden;
672
  }
673
+ #details button span{
674
+ font-weight: bold;
675
+ }
676
  #vanna-plot{
677
+ max-height:1000px
678
  }
679
 
680
  #pagination-display{
 
683
  font-size: 16px;
684
  }
685
 
686
+ #table-names label:nth-child(odd) {
687
+ background-color: #f9f9f9;
688
  }
689
+
690
+ #table-names label:nth-child(even) {
691
+ background-color: #e6f0ff;
692
  }
693
 
694
+ #table-names label {
695
+ display: block; /* Chaque option prend toute la ligne */
696
+ width: 100%; /* Chaque option remplit l'espace horizontal */
697
+ box-sizing: border-box;
698
+ padding: 8px 12px;
699
+ margin-bottom: 4px;
700
+ border: 1px solid #ccc;
701
+ border-radius: 6px;
702
+ background-color: white;
703
+ cursor: pointer;
704
+ text-align: center;
705
  }
706
+
707
+ #table-names label:hover {
708
  background-color: #f0f8ff;
709
  }
710
 
711
+ #table-names input[type="radio"] {
712
+ display: none;
713
+ }
714
+
715
+ #table-names input[type="radio"]:checked + label {
716
+ background-color: #d0eaff;
717
+ border-color: #2196f3;
718
+ }
719
+
720
  /* DRIAS Data Table Styles */
721
  #vanna-table {
722
  height: 400px !important;
 
739
  background: white;
740
  z-index: 1;
741
  }
742
+
743
+ .example-img{
744
+ height: 250px;
745
+ object-fit: contain;
746
+ }
747
+
748
+ #example-img-container {
749
+ flex-direction: column;
750
+ align-items: left;
751
+ }