Commit
·
c0fd277
1
Parent(s):
c25f6b1
feat: updated common talk to data for talk to ipcc and drias
Browse files
climateqa/engine/talk_to_data/input_processing.py
CHANGED
@@ -1,16 +1,18 @@
|
|
1 |
-
from typing import Any
|
2 |
import ast
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
from geopy.geocoders import Nominatim
|
5 |
from climateqa.engine.llm import get_llm
|
6 |
import duckdb
|
7 |
-
|
|
|
8 |
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
|
9 |
from climateqa.engine.talk_to_data.objects.location import Location
|
10 |
from climateqa.engine.talk_to_data.objects.plot import Plot
|
11 |
from climateqa.engine.talk_to_data.objects.states import State
|
|
|
12 |
|
13 |
-
async def detect_location_with_openai(sentence):
|
14 |
"""
|
15 |
Detects locations in a sentence using OpenAI's API via LangChain.
|
16 |
"""
|
@@ -49,21 +51,51 @@ def loc_to_coords(location: str) -> tuple[float, float]:
|
|
49 |
coords = geolocator.geocode(location)
|
50 |
return (coords.latitude, coords.longitude)
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
def nearest_neighbour_sql(location: tuple,
|
54 |
long = round(location[1], 3)
|
55 |
lat = round(location[0], 3)
|
|
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
if len(results) == 0:
|
64 |
-
return "", ""
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
67 |
|
68 |
async def detect_year_with_openai(sentence: str) -> str:
|
69 |
"""
|
@@ -136,43 +168,49 @@ async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot])
|
|
136 |
plots_description += " - Description: " + plot["description"] + "\n"
|
137 |
|
138 |
prompt = (
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
143 |
f"### Descriptions of the plots : {plots_description}"
|
144 |
-
f"### User question : {user_question}"
|
145 |
-
f"###
|
146 |
)
|
147 |
-
# prompt = (
|
148 |
-
# f"You are helping to answer a question with insightful visualizations. "
|
149 |
-
# f"Given a list of plots with their name and description: "
|
150 |
-
# f"{plots_description} "
|
151 |
-
# f"The user question is: {user_question}. "
|
152 |
-
# f"Choose the most relevant plots to answer the question. "
|
153 |
-
# f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
|
154 |
-
# f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
|
155 |
-
# )
|
156 |
|
157 |
plot_names = ast.literal_eval(
|
158 |
(await llm.ainvoke(prompt)).content.strip("```python\n").strip()
|
159 |
)
|
160 |
return plot_names
|
161 |
|
162 |
-
async def find_location(user_input: str,
|
163 |
-
print(f"---- Find location in
|
164 |
location = await detect_location_with_openai(user_input)
|
165 |
-
output: Location = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
if location:
|
167 |
coords = loc_to_coords(location)
|
168 |
-
|
|
|
169 |
output.update({
|
170 |
"latitude": neighbour[0],
|
171 |
"longitude": neighbour[1],
|
|
|
|
|
|
|
172 |
})
|
|
|
173 |
return output
|
174 |
|
175 |
-
async def find_year(user_input: str) -> str:
|
176 |
"""Extracts year information from user input using LLM.
|
177 |
|
178 |
This function uses an LLM to identify and extract year information from the
|
@@ -186,6 +224,8 @@ async def find_year(user_input: str) -> str:
|
|
186 |
"""
|
187 |
print(f"---- Find year ---")
|
188 |
year = await detect_year_with_openai(user_input)
|
|
|
|
|
189 |
return year
|
190 |
|
191 |
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
|
@@ -198,7 +238,7 @@ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: l
|
|
198 |
relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
|
199 |
return relevant_tables
|
200 |
|
201 |
-
async def find_param(state: State, param_name:str,
|
202 |
"""Perform the good method to retrieve the desired parameter
|
203 |
|
204 |
Args:
|
@@ -210,7 +250,7 @@ async def find_param(state: State, param_name:str, table: str) -> dict[str, Any]
|
|
210 |
dict[str, Any] | None:
|
211 |
"""
|
212 |
if param_name == 'location':
|
213 |
-
location = await find_location(state['user_input'],
|
214 |
return location
|
215 |
if param_name == 'year':
|
216 |
year = await find_year(state['user_input'])
|
|
|
1 |
+
from typing import Any, Literal, Optional, cast
|
2 |
import ast
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
from geopy.geocoders import Nominatim
|
5 |
from climateqa.engine.llm import get_llm
|
6 |
import duckdb
|
7 |
+
import os
|
8 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_DATASET_URL
|
9 |
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
|
10 |
from climateqa.engine.talk_to_data.objects.location import Location
|
11 |
from climateqa.engine.talk_to_data.objects.plot import Plot
|
12 |
from climateqa.engine.talk_to_data.objects.states import State
|
13 |
+
import time
|
14 |
|
15 |
+
async def detect_location_with_openai(sentence: str) -> str:
|
16 |
"""
|
17 |
Detects locations in a sentence using OpenAI's API via LangChain.
|
18 |
"""
|
|
|
51 |
coords = geolocator.geocode(location)
|
52 |
return (coords.latitude, coords.longitude)
|
53 |
|
54 |
+
def coords_to_country(coords: tuple[float, float]) -> tuple[str,str]:
    """Convert geographic coordinates to a country code and a country name.

    This function uses the Nominatim reverse geocoding service to look up
    the address of the given point and extracts the country fields from it.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str, str]: A tuple containing (country_code, country_name);
            the country code is upper-cased.

    Raises:
        AttributeError: If the coordinates cannot be reverse geocoded
        KeyError: If the geocoder response lacks the address or country fields
    """
    geolocator = Nominatim(user_agent="latlong_to_country")
    location = geolocator.reverse(coords)
    if location is None:
        # Keep the exception type callers already see from `None.raw`,
        # but make the failure explicit instead of an opaque attribute error.
        raise AttributeError(f"No address found for coordinates {coords}")
    address = location.raw['address']
    return address['country_code'].upper(), address['country']
|
73 |
|
74 |
+
def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Look up a dataset grid point lying close to the given coordinates.

    The coordinates are rounded to 3 decimals and a bounding box is queried
    around them: +/- 0.3 degrees on the DRIAS table, +/- 0.5 degrees on the
    IPCC coordinates table (used for any mode other than 'DRIAS').

    Args:
        location (tuple): A tuple containing (latitude, longitude)
        mode (Literal['DRIAS', 'IPCC']): Which dataset to search

    Returns:
        tuple: (latitude, longitude, admin1) of the first matching row.
            `admin1` is None when the queried table has no such column
            (DRIAS). When no row matches, ("", "", "") is returned.

    NOTE(review): the query has no ORDER BY, so the row returned is the first
    one inside the bounding box, not necessarily the closest point.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)

    if mode == 'DRIAS':
        table_path = "'hf://datasets/timeki/drias_db/mean_annual_temperature.parquet'"
        columns = "latitude, longitude"
        delta = 0.3
    else:
        table_path = f"'{IPCC_DATASET_URL}/coordinates.parquet'"
        columns = "latitude, longitude, admin1"
        delta = 0.5

    query = (
        f"SELECT {columns} FROM {table_path} "
        f"WHERE latitude BETWEEN {lat - delta} AND {lat + delta} "
        f"AND longitude BETWEEN {long - delta} AND {long + delta}"
    )

    conn = duckdb.connect()
    try:
        # Materialise the DataFrame before the connection is released.
        results = conn.sql(query).fetchdf()
    finally:
        conn.close()

    if len(results) == 0:
        return "", "", ""

    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
|
99 |
|
100 |
async def detect_year_with_openai(sentence: str) -> str:
|
101 |
"""
|
|
|
168 |
plots_description += " - Description: " + plot["description"] + "\n"
|
169 |
|
170 |
prompt = (
|
171 |
+
"You are helping to answer a question with insightful visualizations.\n"
|
172 |
+
"You are given a user question and a list of plots with their name and description.\n"
|
173 |
+
"Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
|
174 |
+
"Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
|
175 |
+
"For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
|
176 |
+
"Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
|
177 |
f"### Descriptions of the plots : {plots_description}"
|
178 |
+
f"### User question : {user_question}\n"
|
179 |
+
f"### Names of the plots : "
|
180 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
plot_names = ast.literal_eval(
|
183 |
(await llm.ainvoke(prompt)).content.strip("```python\n").strip()
|
184 |
)
|
185 |
return plot_names
|
186 |
|
187 |
+
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Extract a location from the user input and resolve it against a dataset.

    The location name is detected by the LLM helper. When one is found it is
    geocoded, its country is looked up, and a grid point of the selected
    dataset is fetched for it.

    Args:
        user_input (str): The raw user question
        mode (Literal['DRIAS', 'IPCC']): Dataset used for the grid point lookup

    Returns:
        Location: The detected location name plus coordinates, country and
            admin1 information; every field except `location` stays None
            when no location is detected.
    """
    print("---- Find location in user input ----")
    detected = await detect_location_with_openai(user_input)

    result: Location = Location(
        location=detected,
        longitude=None,
        latitude=None,
        country_code=None,
        country_name=None,
        admin1=None,
    )

    # Nothing detected: hand back the empty record untouched.
    if not detected:
        return result

    coords = loc_to_coords(detected)
    country_code, country_name = coords_to_country(coords)
    grid_lat, grid_long, grid_admin1 = nearest_neighbour_sql(coords, mode)

    result["latitude"] = grid_lat
    result["longitude"] = grid_long
    result["country_code"] = country_code
    result["country_name"] = country_name
    result["admin1"] = grid_admin1
    return result
|
212 |
|
213 |
+
async def find_year(user_input: str) -> str| None:
|
214 |
"""Extracts year information from user input using LLM.
|
215 |
|
216 |
This function uses an LLM to identify and extract year information from the
|
|
|
224 |
"""
|
225 |
print(f"---- Find year ---")
|
226 |
year = await detect_year_with_openai(user_input)
|
227 |
+
if year == "":
|
228 |
+
return None
|
229 |
return year
|
230 |
|
231 |
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
|
|
|
238 |
relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
|
239 |
return relevant_tables
|
240 |
|
241 |
+
async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
|
242 |
"""Perform the good method to retrieve the desired parameter
|
243 |
|
244 |
Args:
|
|
|
250 |
dict[str, Any] | None:
|
251 |
"""
|
252 |
if param_name == 'location':
|
253 |
+
location = await find_location(state['user_input'], mode)
|
254 |
return location
|
255 |
if param_name == 'year':
|
256 |
year = await find_year(state['user_input'])
|
climateqa/engine/talk_to_data/objects/location.py
CHANGED
@@ -1,7 +1,12 @@
|
|
|
|
1 |
from typing import Optional, TypedDict
|
2 |
|
3 |
|
|
|
4 |
class Location(TypedDict):
|
5 |
location: str
|
6 |
latitude: Optional[str]
|
7 |
longitude: Optional[str]
|
|
|
|
|
|
|
|
1 |
+
from token import OP
|
2 |
from typing import Optional, TypedDict
|
3 |
|
4 |
|
5 |
+
|
6 |
class Location(TypedDict):
    """Typed dictionary describing a location resolved from a user question."""
    # Location name as detected in the user input.
    location: str
    # Coordinates of the matched dataset grid point, None when unresolved.
    # NOTE(review): annotated as str, but the values are read from a DataFrame
    # and are presumably numeric — confirm against the callers.
    latitude: Optional[str]
    longitude: Optional[str]
    # Upper-cased country code and country name from reverse geocoding.
    country_code: Optional[str]
    country_name: Optional[str]
    # `admin1` column of the matched grid point; None when the table has no such column.
    admin1: Optional[str]
|
climateqa/engine/talk_to_data/objects/plot.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Callable, TypedDict
|
2 |
from plotly.graph_objects import Figure
|
3 |
|
4 |
class Plot(TypedDict):
|
@@ -18,4 +18,5 @@ class Plot(TypedDict):
|
|
18 |
description: str
|
19 |
params: list[str]
|
20 |
plot_function: Callable[..., Callable[..., Figure]]
|
21 |
-
sql_query: Callable[..., str]
|
|
|
|
1 |
+
from typing import Callable, TypedDict, Optional
|
2 |
from plotly.graph_objects import Figure
|
3 |
|
4 |
class Plot(TypedDict):
|
|
|
18 |
description: str
|
19 |
params: list[str]
|
20 |
plot_function: Callable[..., Callable[..., Figure]]
|
21 |
+
sql_query: Callable[..., str]
|
22 |
+
short_name: str
|
climateqa/engine/talk_to_data/prompt.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
query_prompt_template = """You are an expert SQL query generator. Given an input question, database schema, SQL dialect and relevant tables to answer the question, generate an optimized and syntactically correct SQL query which can provide useful insights to the question.
|
2 |
+
|
3 |
+
### Instructions:
|
4 |
+
1. **Use only relevant tables**: The following tables are relevant to answering the question: {relevant_tables}. Do not use any other tables.
|
5 |
+
2. **Relevant columns only**: Never select `*`. Only include necessary columns based on the input question.
|
6 |
+
3. **Schema Awareness**:
|
7 |
+
- Use only columns present in the given schema.
|
8 |
+
- **If a column name appears in multiple tables, always use the format `table_name.column_name` to avoid ambiguity.**
|
9 |
+
- Select only the column which are insightful for the question.
|
10 |
+
4. **Dialect Compliance**: Follow `{dialect}` syntax rules.
|
11 |
+
5. **Ordering**: Order the results by a relevant column if applicable (e.g., timestamp for recent records).
|
12 |
+
6. **Valid query**: Make sure the query is syntactically and functionally correct.
|
13 |
+
7. **Conditions** : For the common columns, the same condition should be applied to all the tables (e.g. latitude, longitude, model, year...)
|
14 |
+
9. **Join tables** : If you need to join table, you should join them with year feature.
|
15 |
+
10. **Model** : For each table, you need to add a condition on the model to be equal to {model}
|
16 |
+
|
17 |
+
### Provided Database Schema:
|
18 |
+
{table_info}
|
19 |
+
|
20 |
+
### Relevant Tables:
|
21 |
+
{relevant_tables}
|
22 |
+
|
23 |
+
**Question:** {input}
|
24 |
+
|
25 |
+
**SQL Query:**"""
|
26 |
+
|
27 |
+
plot_prompt_template = """You are a data visualization expert. Given an input question and an SQL Query, generate an insightful plot according to the question.
|
28 |
+
|
29 |
+
### Instructions
|
30 |
+
1. **Use only the column names provided**. The data will be provided as a Pandas DataFrame `df` with the columns present in the SELECT.
|
31 |
+
2. Generate the Python Plotly code to chart the results using `df` and the column names.
|
32 |
+
3. Make as complete a graph as possible to answer the question, and make it as easy to understand as possible.
|
33 |
+
4. **Response with only Python code**. Do not answer with any explanations -- just the code.
|
34 |
+
5. **Specific cases** :
|
35 |
+
- For a question about the evolution of something, it is also relevant to plot the data with also the sliding average for a period of 20 years for example.
|
36 |
+
|
37 |
+
### SQL Query:
|
38 |
+
{sql_query}
|
39 |
+
|
40 |
+
**Question:** {input}
|
41 |
+
|
42 |
+
**Python code:**
|
43 |
+
"""
|
44 |
+
|
climateqa/engine/talk_to_data/query.py
CHANGED
@@ -2,7 +2,7 @@ import asyncio
|
|
2 |
from concurrent.futures import ThreadPoolExecutor
|
3 |
import duckdb
|
4 |
import pandas as pd
|
5 |
-
|
6 |
|
7 |
def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
|
8 |
"""Retrieves the name of the indicator column within a table.
|
@@ -41,7 +41,12 @@ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
|
|
41 |
def _execute_query():
|
42 |
# Execute the query
|
43 |
con = duckdb.connect()
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
45 |
# return fetched data
|
46 |
return results
|
47 |
|
|
|
2 |
from concurrent.futures import ThreadPoolExecutor
|
3 |
import duckdb
|
4 |
import pandas as pd
|
5 |
+
import os
|
6 |
|
7 |
def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
|
8 |
"""Retrieves the name of the indicator column within a table.
|
|
|
41 |
def _execute_query():
    """Run the enclosing function's `sql_query` on a fresh DuckDB connection.

    A Hugging Face secret is registered first when the `HF_TOKEN` environment
    variable is set. Without a token no secret is created, rather than
    registering the literal string 'None' as the token.

    Returns:
        The query result fetched as a pandas DataFrame.
    """
    con = duckdb.connect()
    try:
        HF_TOKEN = os.getenv("HF_TOKEN")
        if HF_TOKEN:
            # Double single quotes so the token cannot break out of the SQL literal.
            escaped_token = HF_TOKEN.replace("'", "''")
            con.execute(f"""CREATE SECRET hf_token (
                TYPE huggingface,
                TOKEN '{escaped_token}'
            );""")
        # Execute the query and return the fetched data
        results = con.execute(sql_query).fetchdf()
        return results
    finally:
        # Release the connection even when the query fails.
        con.close()
|
52 |
|
style.css
CHANGED
@@ -656,12 +656,11 @@ a {
|
|
656 |
/* overflow-y: scroll; */
|
657 |
}
|
658 |
#sql-query{
|
659 |
-
max-height:
|
660 |
-
overflow-y:scroll;
|
661 |
}
|
662 |
|
663 |
#sql-query textarea{
|
664 |
-
min-height:
|
665 |
}
|
666 |
|
667 |
#sql-query span{
|
@@ -671,8 +670,11 @@ div#tab-vanna{
|
|
671 |
max-height: 100¨vh;
|
672 |
overflow-y: hidden;
|
673 |
}
|
|
|
|
|
|
|
674 |
#vanna-plot{
|
675 |
-
max-height:
|
676 |
}
|
677 |
|
678 |
#pagination-display{
|
@@ -681,20 +683,40 @@ div#tab-vanna{
|
|
681 |
font-size: 16px;
|
682 |
}
|
683 |
|
684 |
-
#table-names
|
685 |
-
|
686 |
}
|
687 |
-
|
688 |
-
|
|
|
689 |
}
|
690 |
|
691 |
-
#table-names
|
692 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
693 |
}
|
694 |
-
|
|
|
695 |
background-color: #f0f8ff;
|
696 |
}
|
697 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
698 |
/* DRIAS Data Table Styles */
|
699 |
#vanna-table {
|
700 |
height: 400px !important;
|
@@ -717,3 +739,13 @@ div#tab-vanna{
|
|
717 |
background: white;
|
718 |
z-index: 1;
|
719 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
656 |
/* overflow-y: scroll; */
|
657 |
}
|
658 |
#sql-query{
|
659 |
+
max-height: 100%;
|
|
|
660 |
}
|
661 |
|
662 |
#sql-query textarea{
|
663 |
+
min-height: 200px !important;
|
664 |
}
|
665 |
|
666 |
#sql-query span{
|
|
|
670 |
max-height: 100¨vh;
|
671 |
overflow-y: hidden;
|
672 |
}
|
673 |
+
#details button span{
|
674 |
+
font-weight: bold;
|
675 |
+
}
|
676 |
#vanna-plot{
|
677 |
+
max-height:1000px
|
678 |
}
|
679 |
|
680 |
#pagination-display{
|
|
|
683 |
font-size: 16px;
|
684 |
}
|
685 |
|
686 |
+
#table-names label:nth-child(odd) {
|
687 |
+
background-color: #f9f9f9;
|
688 |
}
|
689 |
+
|
690 |
+
#table-names label:nth-child(even) {
|
691 |
+
background-color: #e6f0ff;
|
692 |
}
|
693 |
|
694 |
+
#table-names label {
    display: block; /* Each option takes up the whole line */
    width: 100%; /* Each option fills the horizontal space */
    box-sizing: border-box;
    padding: 8px 12px;
    margin-bottom: 4px;
    border: 1px solid #ccc;
    border-radius: 6px;
    background-color: white;
    cursor: pointer;
    text-align: center;
}
|
706 |
+
|
707 |
+
#table-names label:hover {
|
708 |
background-color: #f0f8ff;
|
709 |
}
|
710 |
|
711 |
+
#table-names input[type="radio"] {
|
712 |
+
display: none;
|
713 |
+
}
|
714 |
+
|
715 |
+
#table-names input[type="radio"]:checked + label {
|
716 |
+
background-color: #d0eaff;
|
717 |
+
border-color: #2196f3;
|
718 |
+
}
|
719 |
+
|
720 |
/* DRIAS Data Table Styles */
|
721 |
#vanna-table {
|
722 |
height: 400px !important;
|
|
|
739 |
background: white;
|
740 |
z-index: 1;
|
741 |
}
|
742 |
+
|
743 |
+
.example-img{
|
744 |
+
height: 250px;
|
745 |
+
object-fit: contain;
|
746 |
+
}
|
747 |
+
|
748 |
+
/* Stack the example images vertically and align them on the start edge. */
#example-img-container {
    flex-direction: column;
    /* `left` is not a valid align-items value and the declaration was ignored;
       flex-start is the left edge of a column in left-to-right layouts. */
    align-items: flex-start;
}
|