Merge hf-origin/main into main
Browse files- .gitattributes +1 -1
- app.py +8 -0
- climateqa/engine/talk_to_data/main.py +1 -1
- climateqa/engine/talk_to_data/myVanna.py +13 -0
- climateqa/engine/talk_to_data/plot.py +418 -0
- climateqa/engine/talk_to_data/sql_query.py +114 -0
- climateqa/engine/talk_to_data/talk_to_drias.py +317 -0
- climateqa/engine/talk_to_data/utils.py +281 -0
- climateqa/engine/talk_to_data/vanna_class.py +325 -0
- requirements.txt +1 -1
- style.css +10 -1
.gitattributes
CHANGED
@@ -45,4 +45,4 @@ documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
|
|
45 |
climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
|
46 |
climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
|
47 |
data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
|
48 |
-
front/assets/*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
45 |
climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
|
46 |
climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
|
47 |
data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
|
48 |
+
front/assets/*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -16,7 +16,10 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
|
|
16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
from front.tabs.tab_drias import create_drias_tab
|
|
|
19 |
from front.tabs.tab_ipcc import create_ipcc_tab
|
|
|
|
|
20 |
from front.utils import process_figures
|
21 |
from gradio_modal import Modal
|
22 |
|
@@ -533,8 +536,13 @@ def main_ui():
|
|
533 |
with gr.Tabs():
|
534 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
535 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
|
|
536 |
drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
|
537 |
ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
|
|
|
|
|
|
|
|
|
538 |
create_about_tab()
|
539 |
|
540 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
|
|
16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
from front.tabs.tab_drias import create_drias_tab
|
19 |
+
<<<<<<< HEAD
|
20 |
from front.tabs.tab_ipcc import create_ipcc_tab
|
21 |
+
=======
|
22 |
+
>>>>>>> hf-origin/main
|
23 |
from front.utils import process_figures
|
24 |
from gradio_modal import Modal
|
25 |
|
|
|
536 |
with gr.Tabs():
|
537 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
538 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
539 |
+
<<<<<<< HEAD
|
540 |
drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
|
541 |
ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
|
542 |
+
=======
|
543 |
+
create_drias_tab(share_client=share_client, user_id=user_id)
|
544 |
+
|
545 |
+
>>>>>>> hf-origin/main
|
546 |
create_about_tab()
|
547 |
|
548 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
climateqa/engine/talk_to_data/main.py
CHANGED
@@ -121,4 +121,4 @@ async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None)
|
|
121 |
|
122 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
123 |
|
124 |
-
return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
|
|
|
121 |
|
122 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
123 |
|
124 |
+
return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
|
climateqa/engine/talk_to_data/myVanna.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dotenv import load_dotenv
|
2 |
+
from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
|
3 |
+
from vanna.openai import OpenAI_Chat
|
4 |
+
import os
|
5 |
+
|
6 |
+
load_dotenv()
|
7 |
+
|
8 |
+
OPENAI_API_KEY = os.getenv('THEO_API_KEY')
|
9 |
+
|
10 |
+
class MyVanna(MyCustomVectorDB, OpenAI_Chat):
|
11 |
+
def __init__(self, config=None):
|
12 |
+
MyCustomVectorDB.__init__(self, config=config)
|
13 |
+
OpenAI_Chat.__init__(self, config=config)
|
climateqa/engine/talk_to_data/plot.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, TypedDict
|
2 |
+
from matplotlib.figure import figaspect
|
3 |
+
import pandas as pd
|
4 |
+
from plotly.graph_objects import Figure
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
import plotly.express as px
|
7 |
+
|
8 |
+
from climateqa.engine.talk_to_data.sql_query import (
|
9 |
+
indicator_for_given_year_query,
|
10 |
+
indicator_per_year_at_location_query,
|
11 |
+
)
|
12 |
+
from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class Plot(TypedDict):
|
18 |
+
"""Represents a plot configuration in the DRIAS system.
|
19 |
+
|
20 |
+
This class defines the structure for configuring different types of plots
|
21 |
+
that can be generated from climate data.
|
22 |
+
|
23 |
+
Attributes:
|
24 |
+
name (str): The name of the plot type
|
25 |
+
description (str): A description of what the plot shows
|
26 |
+
params (list[str]): List of required parameters for the plot
|
27 |
+
plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
|
28 |
+
sql_query (Callable[..., str]): Function to generate the SQL query for the plot
|
29 |
+
"""
|
30 |
+
name: str
|
31 |
+
description: str
|
32 |
+
params: list[str]
|
33 |
+
plot_function: Callable[..., Callable[..., Figure]]
|
34 |
+
sql_query: Callable[..., str]
|
35 |
+
|
36 |
+
|
37 |
+
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
38 |
+
"""Generates a function to plot indicator evolution over time at a location.
|
39 |
+
|
40 |
+
This function creates a line plot showing how a climate indicator changes
|
41 |
+
over time at a specific location. It handles temperature, precipitation,
|
42 |
+
and other climate indicators.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
params (dict): Dictionary containing:
|
46 |
+
- indicator_column (str): The column name for the indicator
|
47 |
+
- location (str): The location to plot
|
48 |
+
- model (str): The climate model to use
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
52 |
+
|
53 |
+
Example:
|
54 |
+
>>> plot_func = plot_indicator_evolution_at_location({
|
55 |
+
... 'indicator_column': 'mean_temperature',
|
56 |
+
... 'location': 'Paris',
|
57 |
+
... 'model': 'ALL'
|
58 |
+
... })
|
59 |
+
>>> fig = plot_func(df)
|
60 |
+
"""
|
61 |
+
indicator = params["indicator_column"]
|
62 |
+
location = params["location"]
|
63 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
64 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
65 |
+
|
66 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
67 |
+
"""Generates the actual plot from the data.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
df (pd.DataFrame): DataFrame containing the data to plot
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
Figure: A plotly Figure object showing the indicator evolution
|
74 |
+
"""
|
75 |
+
fig = go.Figure()
|
76 |
+
if df['model'].nunique() != 1:
|
77 |
+
df_avg = df.groupby("year", as_index=False)[indicator].mean()
|
78 |
+
|
79 |
+
# Transform to list to avoid pandas encoding
|
80 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
81 |
+
years = df_avg["year"].astype(int).tolist()
|
82 |
+
|
83 |
+
# Compute the 10-year rolling average
|
84 |
+
rolling_window = 10
|
85 |
+
sliding_averages = (
|
86 |
+
df_avg[indicator]
|
87 |
+
.rolling(window=rolling_window, min_periods=rolling_window)
|
88 |
+
.mean()
|
89 |
+
.astype(float)
|
90 |
+
.tolist()
|
91 |
+
)
|
92 |
+
model_label = "Model Average"
|
93 |
+
|
94 |
+
# Only add rolling average if we have enough data points
|
95 |
+
if len([x for x in sliding_averages if pd.notna(x)]) > 0:
|
96 |
+
# Sliding average dashed line
|
97 |
+
fig.add_scatter(
|
98 |
+
x=years,
|
99 |
+
y=sliding_averages,
|
100 |
+
mode="lines",
|
101 |
+
name="10 years rolling average",
|
102 |
+
line=dict(dash="dash"),
|
103 |
+
marker=dict(color="#d62728"),
|
104 |
+
hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
105 |
+
)
|
106 |
+
|
107 |
+
else:
|
108 |
+
df_model = df
|
109 |
+
|
110 |
+
# Transform to list to avoid pandas encoding
|
111 |
+
indicators = df_model[indicator].astype(float).tolist()
|
112 |
+
years = df_model["year"].astype(int).tolist()
|
113 |
+
|
114 |
+
# Compute the 10-year rolling average
|
115 |
+
rolling_window = 10
|
116 |
+
sliding_averages = (
|
117 |
+
df_model[indicator]
|
118 |
+
.rolling(window=rolling_window, min_periods=rolling_window)
|
119 |
+
.mean()
|
120 |
+
.astype(float)
|
121 |
+
.tolist()
|
122 |
+
)
|
123 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
124 |
+
|
125 |
+
# Only add rolling average if we have enough data points
|
126 |
+
if len([x for x in sliding_averages if pd.notna(x)]) > 0:
|
127 |
+
# Sliding average dashed line
|
128 |
+
fig.add_scatter(
|
129 |
+
x=years,
|
130 |
+
y=sliding_averages,
|
131 |
+
mode="lines",
|
132 |
+
name="10 years rolling average",
|
133 |
+
line=dict(dash="dash"),
|
134 |
+
marker=dict(color="#d62728"),
|
135 |
+
hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
136 |
+
)
|
137 |
+
|
138 |
+
# Indicator per year plot
|
139 |
+
fig.add_scatter(
|
140 |
+
x=years,
|
141 |
+
y=indicators,
|
142 |
+
name=f"Yearly {indicator_label}",
|
143 |
+
mode="lines",
|
144 |
+
marker=dict(color="#1f77b4"),
|
145 |
+
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
146 |
+
)
|
147 |
+
fig.update_layout(
|
148 |
+
title=f"Plot of {indicator_label} in {location} ({model_label})",
|
149 |
+
xaxis_title="Year",
|
150 |
+
yaxis_title=f"{indicator_label} ({unit})",
|
151 |
+
template="plotly_white",
|
152 |
+
)
|
153 |
+
return fig
|
154 |
+
|
155 |
+
return plot_data
|
156 |
+
|
157 |
+
|
158 |
+
indicator_evolution_at_location: Plot = {
|
159 |
+
"name": "Indicator evolution at location",
|
160 |
+
"description": "Plot an evolution of the indicator at a certain location",
|
161 |
+
"params": ["indicator_column", "location", "model"],
|
162 |
+
"plot_function": plot_indicator_evolution_at_location,
|
163 |
+
"sql_query": indicator_per_year_at_location_query,
|
164 |
+
}
|
165 |
+
|
166 |
+
|
167 |
+
def plot_indicator_number_of_days_per_year_at_location(
|
168 |
+
params: dict,
|
169 |
+
) -> Callable[..., Figure]:
|
170 |
+
"""Generates a function to plot the number of days per year for an indicator.
|
171 |
+
|
172 |
+
This function creates a bar chart showing the frequency of certain climate
|
173 |
+
events (like days above a temperature threshold) per year at a specific location.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
params (dict): Dictionary containing:
|
177 |
+
- indicator_column (str): The column name for the indicator
|
178 |
+
- location (str): The location to plot
|
179 |
+
- model (str): The climate model to use
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
183 |
+
"""
|
184 |
+
indicator = params["indicator_column"]
|
185 |
+
location = params["location"]
|
186 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
187 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
188 |
+
|
189 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
190 |
+
"""Generate the figure thanks to the dataframe
|
191 |
+
|
192 |
+
Args:
|
193 |
+
df (pd.DataFrame): pandas dataframe with the required data
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
Figure: Plotly figure
|
197 |
+
"""
|
198 |
+
fig = go.Figure()
|
199 |
+
if df['model'].nunique() != 1:
|
200 |
+
df_avg = df.groupby("year", as_index=False)[indicator].mean()
|
201 |
+
|
202 |
+
# Transform to list to avoid pandas encoding
|
203 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
204 |
+
years = df_avg["year"].astype(int).tolist()
|
205 |
+
model_label = "Model Average"
|
206 |
+
|
207 |
+
else:
|
208 |
+
df_model = df
|
209 |
+
# Transform to list to avoid pandas encoding
|
210 |
+
indicators = df_model[indicator].astype(float).tolist()
|
211 |
+
years = df_model["year"].astype(int).tolist()
|
212 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
213 |
+
|
214 |
+
|
215 |
+
# Bar plot
|
216 |
+
fig.add_trace(
|
217 |
+
go.Bar(
|
218 |
+
x=years,
|
219 |
+
y=indicators,
|
220 |
+
width=0.5,
|
221 |
+
marker=dict(color="#1f77b4"),
|
222 |
+
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
223 |
+
)
|
224 |
+
)
|
225 |
+
|
226 |
+
fig.update_layout(
|
227 |
+
title=f"{indicator_label} in {location} ({model_label})",
|
228 |
+
xaxis_title="Year",
|
229 |
+
yaxis_title=f"{indicator_label} ({unit})",
|
230 |
+
yaxis=dict(range=[0, max(indicators)]),
|
231 |
+
bargap=0.5,
|
232 |
+
template="plotly_white",
|
233 |
+
)
|
234 |
+
|
235 |
+
return fig
|
236 |
+
|
237 |
+
return plot_data
|
238 |
+
|
239 |
+
|
240 |
+
indicator_number_of_days_per_year_at_location: Plot = {
|
241 |
+
"name": "Indicator number of days per year at location",
|
242 |
+
"description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
|
243 |
+
"params": ["indicator_column", "location", "model"],
|
244 |
+
"plot_function": plot_indicator_number_of_days_per_year_at_location,
|
245 |
+
"sql_query": indicator_per_year_at_location_query,
|
246 |
+
}
|
247 |
+
|
248 |
+
|
249 |
+
def plot_distribution_of_indicator_for_given_year(
|
250 |
+
params: dict,
|
251 |
+
) -> Callable[..., Figure]:
|
252 |
+
"""Generates a function to plot the distribution of an indicator for a year.
|
253 |
+
|
254 |
+
This function creates a histogram showing the distribution of a climate
|
255 |
+
indicator across different locations for a specific year.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
params (dict): Dictionary containing:
|
259 |
+
- indicator_column (str): The column name for the indicator
|
260 |
+
- year (str): The year to plot
|
261 |
+
- model (str): The climate model to use
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
265 |
+
"""
|
266 |
+
indicator = params["indicator_column"]
|
267 |
+
year = params["year"]
|
268 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
269 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
270 |
+
|
271 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
272 |
+
"""Generate the figure thanks to the dataframe
|
273 |
+
|
274 |
+
Args:
|
275 |
+
df (pd.DataFrame): pandas dataframe with the required data
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
Figure: Plotly figure
|
279 |
+
"""
|
280 |
+
fig = go.Figure()
|
281 |
+
if df['model'].nunique() != 1:
|
282 |
+
df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
|
283 |
+
indicator
|
284 |
+
].mean()
|
285 |
+
|
286 |
+
# Transform to list to avoid pandas encoding
|
287 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
288 |
+
model_label = "Model Average"
|
289 |
+
|
290 |
+
else:
|
291 |
+
df_model = df
|
292 |
+
|
293 |
+
# Transform to list to avoid pandas encoding
|
294 |
+
indicators = df_model[indicator].astype(float).tolist()
|
295 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
296 |
+
|
297 |
+
|
298 |
+
fig.add_trace(
|
299 |
+
go.Histogram(
|
300 |
+
x=indicators,
|
301 |
+
opacity=0.8,
|
302 |
+
histnorm="percent",
|
303 |
+
marker=dict(color="#1f77b4"),
|
304 |
+
hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
|
305 |
+
)
|
306 |
+
)
|
307 |
+
|
308 |
+
fig.update_layout(
|
309 |
+
title=f"Distribution of {indicator_label} in {year} ({model_label})",
|
310 |
+
xaxis_title=f"{indicator_label} ({unit})",
|
311 |
+
yaxis_title="Frequency (%)",
|
312 |
+
plot_bgcolor="rgba(0, 0, 0, 0)",
|
313 |
+
showlegend=False,
|
314 |
+
)
|
315 |
+
|
316 |
+
return fig
|
317 |
+
|
318 |
+
return plot_data
|
319 |
+
|
320 |
+
|
321 |
+
distribution_of_indicator_for_given_year: Plot = {
|
322 |
+
"name": "Distribution of an indicator for a given year",
|
323 |
+
"description": "Plot an histogram of the distribution for a given year of the values of an indicator",
|
324 |
+
"params": ["indicator_column", "model", "year"],
|
325 |
+
"plot_function": plot_distribution_of_indicator_for_given_year,
|
326 |
+
"sql_query": indicator_for_given_year_query,
|
327 |
+
}
|
328 |
+
|
329 |
+
|
330 |
+
def plot_map_of_france_of_indicator_for_given_year(
|
331 |
+
params: dict,
|
332 |
+
) -> Callable[..., Figure]:
|
333 |
+
"""Generates a function to plot a map of France for an indicator.
|
334 |
+
|
335 |
+
This function creates a choropleth map of France showing the spatial
|
336 |
+
distribution of a climate indicator for a specific year.
|
337 |
+
|
338 |
+
Args:
|
339 |
+
params (dict): Dictionary containing:
|
340 |
+
- indicator_column (str): The column name for the indicator
|
341 |
+
- year (str): The year to plot
|
342 |
+
- model (str): The climate model to use
|
343 |
+
|
344 |
+
Returns:
|
345 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
346 |
+
"""
|
347 |
+
indicator = params["indicator_column"]
|
348 |
+
year = params["year"]
|
349 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
350 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
351 |
+
|
352 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
353 |
+
fig = go.Figure()
|
354 |
+
if df['model'].nunique() != 1:
|
355 |
+
df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
|
356 |
+
indicator
|
357 |
+
].mean()
|
358 |
+
|
359 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
360 |
+
latitudes = df_avg["latitude"].astype(float).tolist()
|
361 |
+
longitudes = df_avg["longitude"].astype(float).tolist()
|
362 |
+
model_label = "Model Average"
|
363 |
+
|
364 |
+
else:
|
365 |
+
df_model = df
|
366 |
+
|
367 |
+
# Transform to list to avoid pandas encoding
|
368 |
+
indicators = df_model[indicator].astype(float).tolist()
|
369 |
+
latitudes = df_model["latitude"].astype(float).tolist()
|
370 |
+
longitudes = df_model["longitude"].astype(float).tolist()
|
371 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
372 |
+
|
373 |
+
|
374 |
+
fig.add_trace(
|
375 |
+
go.Scattermapbox(
|
376 |
+
lat=latitudes,
|
377 |
+
lon=longitudes,
|
378 |
+
mode="markers",
|
379 |
+
marker=dict(
|
380 |
+
size=10,
|
381 |
+
color=indicators, # Color mapped to values
|
382 |
+
colorscale="Turbo", # Color scale (can be 'Plasma', 'Jet', etc.)
|
383 |
+
cmin=min(indicators), # Minimum color range
|
384 |
+
cmax=max(indicators), # Maximum color range
|
385 |
+
showscale=True, # Show colorbar
|
386 |
+
),
|
387 |
+
text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
|
388 |
+
hoverinfo="text" # Only show the custom text on hover
|
389 |
+
)
|
390 |
+
)
|
391 |
+
|
392 |
+
fig.update_layout(
|
393 |
+
mapbox_style="open-street-map", # Use OpenStreetMap
|
394 |
+
mapbox_zoom=3,
|
395 |
+
mapbox_center={"lat": 46.6, "lon": 2.0},
|
396 |
+
coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
|
397 |
+
title=f"{indicator_label} in {year} in France ({model_label}) " # Title
|
398 |
+
)
|
399 |
+
return fig
|
400 |
+
|
401 |
+
return plot_data
|
402 |
+
|
403 |
+
|
404 |
+
map_of_france_of_indicator_for_given_year: Plot = {
|
405 |
+
"name": "Map of France of an indicator for a given year",
|
406 |
+
"description": "Heatmap on the map of France of the values of an in indicator for a given year",
|
407 |
+
"params": ["indicator_column", "year", "model"],
|
408 |
+
"plot_function": plot_map_of_france_of_indicator_for_given_year,
|
409 |
+
"sql_query": indicator_for_given_year_query,
|
410 |
+
}
|
411 |
+
|
412 |
+
|
413 |
+
PLOTS = [
|
414 |
+
indicator_evolution_at_location,
|
415 |
+
indicator_number_of_days_per_year_at_location,
|
416 |
+
distribution_of_indicator_for_given_year,
|
417 |
+
map_of_france_of_indicator_for_given_year,
|
418 |
+
]
|
climateqa/engine/talk_to_data/sql_query.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from concurrent.futures import ThreadPoolExecutor
|
3 |
+
from typing import TypedDict
|
4 |
+
import duckdb
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
|
8 |
+
"""Executes a SQL query on the DRIAS database and returns the results.
|
9 |
+
|
10 |
+
This function connects to the DuckDB database containing DRIAS climate data
|
11 |
+
and executes the provided SQL query. It handles the database connection and
|
12 |
+
returns the results as a pandas DataFrame.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
sql_query (str): The SQL query to execute
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
pd.DataFrame: A DataFrame containing the query results
|
19 |
+
|
20 |
+
Raises:
|
21 |
+
duckdb.Error: If there is an error executing the SQL query
|
22 |
+
"""
|
23 |
+
def _execute_query():
|
24 |
+
# Execute the query
|
25 |
+
con = duckdb.connect()
|
26 |
+
results = con.sql(sql_query).fetchdf()
|
27 |
+
# return fetched data
|
28 |
+
return results
|
29 |
+
|
30 |
+
# Run the query in a thread pool to avoid blocking
|
31 |
+
loop = asyncio.get_event_loop()
|
32 |
+
with ThreadPoolExecutor() as executor:
|
33 |
+
return await loop.run_in_executor(executor, _execute_query)
|
34 |
+
|
35 |
+
|
36 |
+
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
|
37 |
+
"""Parameters for querying an indicator's values over time at a location.
|
38 |
+
|
39 |
+
This class defines the parameters needed to query climate indicator data
|
40 |
+
for a specific location over multiple years.
|
41 |
+
|
42 |
+
Attributes:
|
43 |
+
indicator_column (str): The column name for the climate indicator
|
44 |
+
latitude (str): The latitude coordinate of the location
|
45 |
+
longitude (str): The longitude coordinate of the location
|
46 |
+
model (str): The climate model to use (optional)
|
47 |
+
"""
|
48 |
+
indicator_column: str
|
49 |
+
latitude: str
|
50 |
+
longitude: str
|
51 |
+
model: str
|
52 |
+
|
53 |
+
|
54 |
+
def indicator_per_year_at_location_query(
|
55 |
+
table: str, params: IndicatorPerYearAtLocationQueryParams
|
56 |
+
) -> str:
|
57 |
+
"""SQL Query to get the evolution of an indicator per year at a certain location
|
58 |
+
|
59 |
+
Args:
|
60 |
+
table (str): sql table of the indicator
|
61 |
+
params (IndicatorPerYearAtLocationQueryParams) : dictionary with the required params for the query
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
str: the sql query
|
65 |
+
"""
|
66 |
+
indicator_column = params.get("indicator_column")
|
67 |
+
latitude = params.get("latitude")
|
68 |
+
longitude = params.get("longitude")
|
69 |
+
|
70 |
+
if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
|
71 |
+
return ""
|
72 |
+
|
73 |
+
table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
|
74 |
+
|
75 |
+
sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
|
76 |
+
|
77 |
+
return sql_query
|
78 |
+
|
79 |
+
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
|
80 |
+
"""Parameters for querying an indicator's values across locations for a year.
|
81 |
+
|
82 |
+
This class defines the parameters needed to query climate indicator data
|
83 |
+
across different locations for a specific year.
|
84 |
+
|
85 |
+
Attributes:
|
86 |
+
indicator_column (str): The column name for the climate indicator
|
87 |
+
year (str): The year to query
|
88 |
+
model (str): The climate model to use (optional)
|
89 |
+
"""
|
90 |
+
indicator_column: str
|
91 |
+
year: str
|
92 |
+
model: str
|
93 |
+
|
94 |
+
def indicator_for_given_year_query(
|
95 |
+
table:str, params: IndicatorForGivenYearQueryParams
|
96 |
+
) -> str:
|
97 |
+
"""SQL Query to get the values of an indicator with their latitudes, longitudes and models for a given year
|
98 |
+
|
99 |
+
Args:
|
100 |
+
table (str): sql table of the indicator
|
101 |
+
params (IndicatorForGivenYearQueryParams): dictionarry with the required params for the query
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
str: the sql query
|
105 |
+
"""
|
106 |
+
indicator_column = params.get("indicator_column")
|
107 |
+
year = params.get('year')
|
108 |
+
if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
|
109 |
+
return ""
|
110 |
+
|
111 |
+
table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
|
112 |
+
|
113 |
+
sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
|
114 |
+
return sql_query
|
climateqa/engine/talk_to_data/talk_to_drias.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from typing import Any, Callable, TypedDict, Optional
|
4 |
+
from numpy import sort
|
5 |
+
import pandas as pd
|
6 |
+
import asyncio
|
7 |
+
from plotly.graph_objects import Figure
|
8 |
+
from climateqa.engine.llm import get_llm
|
9 |
+
from climateqa.engine.talk_to_data import sql_query
|
10 |
+
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
|
11 |
+
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
12 |
+
from climateqa.engine.talk_to_data.sql_query import execute_sql_query
|
13 |
+
from climateqa.engine.talk_to_data.utils import (
|
14 |
+
detect_relevant_plots,
|
15 |
+
detect_year_with_openai,
|
16 |
+
loc2coords,
|
17 |
+
detect_location_with_openai,
|
18 |
+
nearestNeighbourSQL,
|
19 |
+
detect_relevant_tables,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
24 |
+
|
25 |
+
class TableState(TypedDict):
|
26 |
+
"""Represents the state of a table in the DRIAS workflow.
|
27 |
+
|
28 |
+
This class defines the structure for tracking the state of a table during the
|
29 |
+
data processing workflow, including its name, parameters, SQL query, and results.
|
30 |
+
|
31 |
+
Attributes:
|
32 |
+
table_name (str): The name of the table in the database
|
33 |
+
params (dict[str, Any]): Parameters used for querying the table
|
34 |
+
sql_query (str, optional): The SQL query used to fetch data
|
35 |
+
dataframe (pd.DataFrame | None, optional): The resulting data
|
36 |
+
figure (Callable[..., Figure], optional): Function to generate visualization
|
37 |
+
status (str): The current status of the table processing ('OK' or 'ERROR')
|
38 |
+
"""
|
39 |
+
table_name: str
|
40 |
+
params: dict[str, Any]
|
41 |
+
sql_query: Optional[str]
|
42 |
+
dataframe: Optional[pd.DataFrame | None]
|
43 |
+
figure: Optional[Callable[..., Figure]]
|
44 |
+
status: str
|
45 |
+
|
46 |
+
class PlotState(TypedDict):
|
47 |
+
"""Represents the state of a plot in the DRIAS workflow.
|
48 |
+
|
49 |
+
This class defines the structure for tracking the state of a plot during the
|
50 |
+
data processing workflow, including its name and associated tables.
|
51 |
+
|
52 |
+
Attributes:
|
53 |
+
plot_name (str): The name of the plot
|
54 |
+
tables (list[str]): List of tables used in the plot
|
55 |
+
table_states (dict[str, TableState]): States of the tables used in the plot
|
56 |
+
"""
|
57 |
+
plot_name: str
|
58 |
+
tables: list[str]
|
59 |
+
table_states: dict[str, TableState]
|
60 |
+
|
61 |
+
class State(TypedDict):
|
62 |
+
user_input: str
|
63 |
+
plots: list[str]
|
64 |
+
plot_states: dict[str, PlotState]
|
65 |
+
error: Optional[str]
|
66 |
+
|
67 |
+
async def find_relevant_plots(state: State, llm) -> list[str]:
|
68 |
+
print("---- Find relevant plots ----")
|
69 |
+
relevant_plots = await detect_relevant_plots(state['user_input'], llm)
|
70 |
+
return relevant_plots
|
71 |
+
|
72 |
+
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
|
73 |
+
print(f"---- Find relevant tables for {plot['name']} ----")
|
74 |
+
relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
|
75 |
+
return relevant_tables
|
76 |
+
|
77 |
+
async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
    """Dispatch to the retrieval method matching the requested parameter.

    Args:
        state (State): state of the workflow (only 'user_input' is read)
        param_name (str): name of the desired parameter; 'location' and 'year'
            are the only recognised values
        table (str): name of the table (used for grid lookup of locations)

    Returns:
        dict[str, Any] | None: the resolved parameter(s) keyed by name, or
        None when *param_name* is not recognised.
    """
    if param_name == 'location':
        location = await find_location(state['user_input'], table)
        return location
    if param_name == 'year':
        year = await find_year(state['user_input'])
        return {'year': year}
    return None
|
95 |
+
|
96 |
+
class Location(TypedDict):
    """A location detected in the user input, resolved to grid coordinates.

    Attributes:
        location (str): Detected location name ('' when none was found)
        latitude (Optional[str]): Nearest grid-point latitude in the table
        longitude (Optional[str]): Nearest grid-point longitude in the table
    """
    # NOTE(review): find_location builds this dict without latitude/longitude
    # when no location is detected; consider declaring total=False (or marking
    # the coordinate keys NotRequired) to match that usage.
    location: str
    latitude: Optional[str]
    longitude: Optional[str]
|
100 |
+
|
101 |
+
async def find_location(user_input: str, table: str) -> Location:
    """Detect a location in *user_input* and snap it to the table's grid.

    Uses the LLM to extract a location name, geocodes it with Nominatim, then
    looks up the closest (latitude, longitude) pair present in *table*.

    Args:
        user_input (str): the user's question text
        table (str): name of the table whose grid is searched

    Returns:
        Location: always contains 'location'; 'latitude'/'longitude' are only
        present when a location name was detected.
    """
    print(f"---- Find location in table {table} ----")
    location = await detect_location_with_openai(user_input)
    output: Location = {'location' : location}
    if location:
        coords = loc2coords(location)
        neighbour = nearestNeighbourSQL(coords, table)
        # neighbour is ("", "") when no grid point lies near the coordinates
        output.update({
            "latitude": neighbour[0],
            "longitude": neighbour[1],
        })
    return output
|
113 |
+
|
114 |
+
async def find_year(user_input: str) -> str:
    """Extracts year information from user input using LLM.

    This function uses an LLM to identify and extract year information from the
    user's query, which is used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text

    Returns:
        str: The extracted year, or empty string if no year found
    """
    print(f"---- Find year ---")
    year = await detect_year_with_openai(user_input)
    return year
|
129 |
+
|
130 |
+
def find_indicator_column(table: str) -> str:
    """Retrieves the name of the indicator column within a table.

    This function maps table names to their corresponding indicator columns
    using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.

    Args:
        table (str): Name of the table in the database

    Returns:
        str: Name of the indicator column for the specified table

    Raises:
        KeyError: If the table name is not found in the mapping
    """
    print(f"---- Find indicator column in table {table} ----")
    return INDICATOR_COLUMNS_PER_TABLE[table]
|
147 |
+
|
148 |
+
|
149 |
+
async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Processes a table to extract relevant data and generate visualizations.

    Builds the table's SQL query from the plot configuration, executes it, and
    attaches the resulting dataframe and figure factory to the table state.

    Args:
        table (str): The name of the table to process
        params (dict[str, Any]): Parameters used for querying the table
            (copied, so the caller's dict is not mutated)
        plot (Plot): The plot object providing 'sql_query' and 'plot_function'

    Returns:
        TableState: The state of the processed table; status is 'ERROR' only
        when no SQL query could be built for these parameters.
    """
    table_state: TableState = {
        'table_name': table,
        'params': params.copy(),
        'status': 'OK',
        'dataframe': None,
        'sql_query': None,
        'figure': None
    }

    table_state['params']['indicator_column'] = find_indicator_column(table)
    sql_query = plot['sql_query'](table, table_state['params'])

    # An empty query string signals that the plot cannot be built for these
    # parameters (e.g. missing location/year).
    if sql_query == "":
        table_state['status'] = 'ERROR'
        return table_state
    table_state['sql_query'] = sql_query
    df = await execute_sql_query(sql_query)

    # NOTE(review): status stays 'OK' even when the query returns no rows;
    # drias_workflow compensates via its have_dataframe flag — confirm intended.
    table_state['dataframe'] = df
    table_state['figure'] = plot['plot_function'](table_state['params'])

    return table_state
|
189 |
+
|
190 |
+
async def drias_workflow(user_input: str) -> State:
    """Performs the complete Talk-To-DRIAS workflow: from user input to SQL
    queries, dataframes and generated figures.

    Args:
        user_input (str): initial user input

    Returns:
        State: final state with all the results; state['error'] is set when no
        plot, table, query or data could be produced.

    Fixes:
    - the success flags (have_sql_query / have_dataframe) were re-initialised
      to False inside the plot loop, so only the *last* plot decided the error
      message; they now accumulate across all plots.
    - parameter lookup indexed relevant_tables[0] unconditionally, raising
      IndexError when no table was relevant; it is now guarded.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    plots = await find_relevant_plots(state, llm)
    state['plots'] = plots

    if len(state['plots']) < 1:
        state['error'] = 'There is no plot to answer to the question'
        return state

    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:

        # Find the associated plot object
        plot = next((p for p in PLOTS if p['name'] == plot_name), None)
        if plot is None:
            continue

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': [],
            'table_states': {}
        }

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

        if len(relevant_tables) > 0 :
            have_relevant_table = True

        plot_state['tables'] = relevant_tables

        # Resolve the plot's parameters against the first relevant table;
        # guarded so an empty table list no longer raises IndexError.
        params = {}
        if relevant_tables:
            for param_name in plot['params']:
                param = await find_param(state, param_name, relevant_tables[0])
                if param:
                    params.update(param)

        # Process at most 3 tables per plot, concurrently.
        tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state; the have_* flags accumulate over
        # every plot instead of being reset each iteration.
        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

        state['plot_states'][plot_name] = plot_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
|
270 |
+
|
271 |
+
# def make_write_query_node():
|
272 |
+
|
273 |
+
# def write_query(state):
|
274 |
+
# print("---- Write query ----")
|
275 |
+
# for table in state["tables"]:
|
276 |
+
# sql_query = QUERIES[state[table]['query_type']](
|
277 |
+
# table=table,
|
278 |
+
# indicator_column=state[table]["columns"],
|
279 |
+
# longitude=state[table]["longitude"],
|
280 |
+
# latitude=state[table]["latitude"],
|
281 |
+
# )
|
282 |
+
# state[table].update({"sql_query": sql_query})
|
283 |
+
|
284 |
+
# return state
|
285 |
+
|
286 |
+
# return write_query
|
287 |
+
|
288 |
+
# def make_fetch_data_node(db_path):
|
289 |
+
|
290 |
+
# def fetch_data(state):
|
291 |
+
# print("---- Fetch data ----")
|
292 |
+
# for table in state["tables"]:
|
293 |
+
# results = execute_sql_query(db_path, state[table]['sql_query'])
|
294 |
+
# state[table].update(results)
|
295 |
+
|
296 |
+
# return state
|
297 |
+
|
298 |
+
# return fetch_data
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
## V2
|
303 |
+
|
304 |
+
|
305 |
+
# def make_fetch_data_node(db_path: str, llm):
|
306 |
+
# def fetch_data(state):
|
307 |
+
# print("---- Fetch data ----")
|
308 |
+
# db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
|
309 |
+
# output = {}
|
310 |
+
# sql_query = write_sql_query(state["query"], db, state["tables"], llm)
|
311 |
+
# # TO DO : Add query checker
|
312 |
+
# print(f"SQL query : {sql_query}")
|
313 |
+
# output["sql_query"] = sql_query
|
314 |
+
# output.update(fetch_data_from_sql_query(db_path, sql_query))
|
315 |
+
# return output
|
316 |
+
|
317 |
+
# return fetch_data
|
climateqa/engine/talk_to_data/utils.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import Annotated, TypedDict
|
3 |
+
import duckdb
|
4 |
+
from geopy.geocoders import Nominatim
|
5 |
+
import ast
|
6 |
+
from climateqa.engine.llm import get_llm
|
7 |
+
from climateqa.engine.talk_to_data.config import DRIAS_TABLES
|
8 |
+
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
9 |
+
from langchain_core.prompts import ChatPromptTemplate
|
10 |
+
|
11 |
+
|
12 |
+
async def detect_location_with_openai(sentence):
    """
    Detects locations in a sentence using OpenAI's API via LangChain.

    Returns the first detected location name, or "" when none is found.

    Fix: the original stripped the markdown code fence with str.strip, whose
    argument is a *set* of characters — it could also eat legitimate leading
    or trailing characters of the payload. The fences are now removed
    explicitly with removeprefix/removesuffix.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    raw = response.content.strip()
    raw = raw.removeprefix("```python").removeprefix("```").removesuffix("```").strip()
    # literal_eval only parses Python literals — never executes model output.
    location_list = ast.literal_eval(raw)
    if location_list:
        return location_list[0]
    else:
        return ""
|
31 |
+
|
32 |
+
class ArrayOutput(TypedDict):
    """Represents the output of a function that returns an array.

    This class is used to type-hint functions that return arrays,
    ensuring consistent return types across the codebase.

    Attributes:
        array (str): A syntactically valid Python array string
    """
    # The string is parsed back into a Python list by the caller.
    array: Annotated[str, "Syntactically valid python array."]
|
42 |
+
|
43 |
+
async def detect_year_with_openai(sentence: str) -> str:
    """
    Detects years in a sentence using OpenAI's API via LangChain.

    Returns the first year found, or "" when none is mentioned.

    Fix: the model's structured output was parsed with `eval`, which executes
    arbitrary code; `ast.literal_eval` only parses Python literals.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    years_list = ast.literal_eval(response['array'])
    if len(years_list) > 0:
        return years_list[0]
    else:
        return ""
|
65 |
+
|
66 |
+
|
67 |
+
def detectTable(sql_query: str) -> list[str]:
    """Extract the table names referenced by the FROM clauses of a SQL query.

    Handles bare identifiers as well as back-tick, double-quote and
    single-quote quoted names, including dotted (schema-qualified) paths.

    Args:
        sql_query (str): The SQL query to analyze

    Returns:
        list[str]: every table reference found, in order of appearance

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    from_target = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    return re.findall(from_target, sql_query)
|
86 |
+
|
87 |
+
|
88 |
+
def loc2coords(location: str) -> tuple[float, float]:
    """Converts a location name to geographic coordinates.

    This function uses the Nominatim geocoding service to convert
    a location name (e.g., city name) to its latitude and longitude.

    Args:
        location (str): The name of the location to geocode

    Returns:
        tuple[float, float]: A tuple containing (latitude, longitude)

    Raises:
        AttributeError: If the location cannot be found (geocode returns None,
            so the attribute access below fails)
    """
    geolocator = Nominatim(user_agent="city_to_latlong")
    coords = geolocator.geocode(location)
    return (coords.latitude, coords.longitude)
|
106 |
+
|
107 |
+
|
108 |
+
def coords2loc(coords: tuple[float, float]) -> str:
    """Converts geographic coordinates to a location name.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to a human-readable location name.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        str: The address of the location, or "Unknown Location" if not found

    Example:
        >>> coords2loc((48.8566, 2.3522))
        'Paris, France'
    """
    geolocator = Nominatim(user_agent="coords_to_city")
    try:
        location = geolocator.reverse(coords)
        return location.address
    # Best-effort: any geocoding failure (network, rate limit, no result)
    # degrades to a placeholder instead of propagating.
    except Exception as e:
        print(f"Error: {e}")
        return "Unknown Location"
|
131 |
+
|
132 |
+
|
133 |
+
def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
    """Return the grid point of *table* closest to the given coordinates.

    Searches a +/- 0.3 degree bounding box around the (rounded) coordinates in
    the table's parquet file on the Hugging Face hub.

    Args:
        location (tuple): (latitude, longitude) of the reference point
        table (str): table name; lower-cased to form the parquet path

    Returns:
        tuple: (latitude, longitude) of the nearest grid point, or ("", "")
        when no point lies inside the bounding box.

    Fix: the original returned the *first* row of the bounding box, not the
    nearest one; results are now ordered by squared Euclidean distance.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)

    table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"

    results = duckdb.sql(
        f"SELECT latitude, longitude FROM {table} "
        f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
        f"AND longitude BETWEEN {long - 0.3} AND {long + 0.3} "
        f"ORDER BY (latitude - {lat}) * (latitude - {lat}) "
        f"+ (longitude - {long}) * (longitude - {long}) "
        f"LIMIT 1"
    ).fetchdf()

    if len(results) == 0:
        return "", ""
    # NOTE(review): the annotated return type says str but matched rows yield
    # numeric values — callers appear to interpolate them into SQL either way.
    return results['latitude'].iloc[0], results['longitude'].iloc[0]
|
147 |
+
|
148 |
+
|
149 |
+
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function uses an LLM to analyze the user's question and the plot
    description to determine which tables in the DRIAS database would be
    most relevant for generating the requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis

    Returns:
        list[str]: A list of table names that are relevant for the plot

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    # Get all table names
    table_names_list = DRIAS_TABLES

    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    # NOTE(review): strip() with a multi-character argument removes a character
    # *set* from both ends; safe here only because the payload is bracketed.
    table_names = ast.literal_eval(
        (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
    )
    return table_names
|
189 |
+
|
190 |
+
|
191 |
+
def replace_coordonates(coords, query, coords_tables):
    """Substitute every occurrence of the reference coordinate pair in *query*.

    The i-th occurrence of str(coords[0]) / str(coords[1]) is replaced with
    the i-th pair from *coords_tables*, so each table keeps its own grid point.

    Args:
        coords: reference (latitude, longitude) pair appearing in the query
        query: SQL text containing the reference coordinates
        coords_tables: per-table replacement (latitude, longitude) pairs

    Returns:
        The query with coordinates rewritten.
    """
    lat_token = str(coords[0])
    lon_token = str(coords[1])
    occurrences = query.count(lat_token)

    for i in range(occurrences):
        query = query.replace(lat_token, str(coords_tables[i][0]), 1)
        query = query.replace(lon_token, str(coords_tables[i][1]), 1)
    return query
|
198 |
+
|
199 |
+
|
200 |
+
async def detect_relevant_plots(user_question: str, llm):
    """Select the plot kinds relevant to a user question.

    Builds a catalogue of the available plots (name + description), asks the
    LLM to pick the relevant ones, and parses its answer as a Python list.

    Args:
        user_question (str): The user's question about climate data
        llm: The language model instance to use

    Returns:
        list[str]: names of the relevant plots

    Fixes: prompt typos ("quesiton" -> "question") and a copy-paste error
    asking for "tables" where plots were meant; removed the dead commented-out
    alternative prompt.
    """
    plots_description = ""
    for plot in PLOTS:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"

    prompt = (
        f"You are helping to answer a question with insightful visualizations."
        f"You are given an user question and a list of plots with their name and description."
        f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
        f"Write the most relevant plots to use. Answer only a python list of plot name."
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}"
        f"### Name of the plot : "
    )

    plot_names = ast.literal_eval(
        (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
    )
    return plot_names
|
229 |
+
|
230 |
+
|
231 |
+
# Next Version
|
232 |
+
# class QueryOutput(TypedDict):
|
233 |
+
# """Generated SQL query."""
|
234 |
+
|
235 |
+
# query: Annotated[str, ..., "Syntactically valid SQL query."]
|
236 |
+
|
237 |
+
|
238 |
+
# class PlotlyCodeOutput(TypedDict):
|
239 |
+
# """Generated Plotly code"""
|
240 |
+
|
241 |
+
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
242 |
+
# def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
|
243 |
+
# """Generate SQL query to fetch information."""
|
244 |
+
# prompt_params = {
|
245 |
+
# "dialect": db.dialect,
|
246 |
+
# "table_info": db.get_table_info(),
|
247 |
+
# "input": user_input,
|
248 |
+
# "relevant_tables": relevant_tables,
|
249 |
+
# "model": "ALADIN63_CNRM-CM5",
|
250 |
+
# }
|
251 |
+
|
252 |
+
# prompt = ChatPromptTemplate.from_template(query_prompt_template)
|
253 |
+
# structured_llm = llm.with_structured_output(QueryOutput)
|
254 |
+
# chain = prompt | structured_llm
|
255 |
+
# result = chain.invoke(prompt_params)
|
256 |
+
|
257 |
+
# return result["query"]
|
258 |
+
|
259 |
+
|
260 |
+
# def fetch_data_from_sql_query(db: str, sql_query: str):
|
261 |
+
# conn = sqlite3.connect(db)
|
262 |
+
# cursor = conn.cursor()
|
263 |
+
# cursor.execute(sql_query)
|
264 |
+
# column_names = [desc[0] for desc in cursor.description]
|
265 |
+
# values = cursor.fetchall()
|
266 |
+
# return {"column_names": column_names, "data": values}
|
267 |
+
|
268 |
+
|
269 |
+
# def generate_chart_code(user_input: str, sql_query: list[str], llm):
|
270 |
+
# """ "Generate plotly python code for the chart based on the sql query and the user question"""
|
271 |
+
|
272 |
+
# class PlotlyCodeOutput(TypedDict):
|
273 |
+
# """Generated Plotly code"""
|
274 |
+
|
275 |
+
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
276 |
+
|
277 |
+
# prompt = ChatPromptTemplate.from_template(plot_prompt_template)
|
278 |
+
# structured_llm = llm.with_structured_output(PlotlyCodeOutput)
|
279 |
+
# chain = prompt | structured_llm
|
280 |
+
# result = chain.invoke({"input": user_input, "sql_query": sql_query})
|
281 |
+
# return result["code"]
|
climateqa/engine/talk_to_data/vanna_class.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from vanna.base import VannaBase
|
2 |
+
from pinecone import Pinecone
|
3 |
+
from climateqa.engine.embeddings import get_embeddings_function
|
4 |
+
import pandas as pd
|
5 |
+
import hashlib
|
6 |
+
|
7 |
+
class MyCustomVectorDB(VannaBase):
|
8 |
+
|
9 |
+
"""
|
10 |
+
VectorDB class for storing and retrieving vectors from Pinecone.
|
11 |
+
|
12 |
+
args :
|
13 |
+
config (dict) : Configuration dictionary containing the Pinecone API key and the index name :
|
14 |
+
- pc_api_key (str) : Pinecone API key
|
15 |
+
- index_name (str) : Pinecone index name
|
16 |
+
- top_k (int) : Number of top results to return (default = 2)
|
17 |
+
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, config):
    """Initialise the Pinecone-backed vector store.

    Args:
        config (dict): must contain 'pc_api_key' (Pinecone API key) and
            'index_name' (Pinecone index name); may contain 'top_k'
            (number of results to return, default 2).

    Raises:
        Exception: when the API key or index name is missing.

    Fix: dict.get never raises, so the original try/except (with a bare
    except) could not catch missing keys; the presence of both values is now
    validated explicitly.
    """
    super().__init__(config=config)
    self.api_key = config.get('pc_api_key')
    self.index_name = config.get('index_name')
    if self.api_key is None or self.index_name is None:
        raise Exception("Please provide the Pinecone API key and the index name")

    self.pc = Pinecone(api_key=self.api_key)
    self.index = self.pc.Index(self.index_name)
    self.top_k = config.get('top_k', 2)
    self.embeddings = get_embeddings_function()
|
32 |
+
|
33 |
+
|
34 |
+
def check_embedding(self, id, namespace):
    """Return True when a vector with this id already exists in *namespace*."""
    response = self.index.fetch(ids=[id], namespace=namespace)
    return response['vectors'] != {}
|
39 |
+
|
40 |
+
def generate_hash_id(self, data: str) -> str:
    """
    Generate a unique hash ID for the given data.

    Args:
        data (str): The input data to hash (e.g., a concatenated string of user attributes).

    Returns:
        str: A deterministic SHA-256 digest of *data* as a hexadecimal string.
    """
    return hashlib.sha256(data.encode('utf-8')).hexdigest()
|
56 |
+
|
57 |
+
def add_ddl(self, ddl: str, **kwargs) -> str:
    """Embed a DDL statement and store it in the 'ddl' namespace.

    Returns the deterministic id of the record; skips the upsert when a
    vector with that id is already present.
    """
    record_id = self.generate_hash_id(ddl) + '_ddl'

    if self.check_embedding(record_id, 'ddl'):
        print(f"DDL having id {record_id} already exists")
        return record_id

    vector = self.embeddings.embed_query(ddl)
    self.index.upsert(
        vectors=[(record_id, vector, {'ddl': ddl})],
        namespace='ddl',
    )
    return record_id
|
70 |
+
|
71 |
+
def add_documentation(self, doc: str, **kwargs) -> str:
    """Embed a documentation snippet and store it in the 'documentation' namespace.

    Returns the deterministic id of the record; skips the upsert when a
    vector with that id is already present.
    """
    record_id = self.generate_hash_id(doc) + '_doc'

    if self.check_embedding(record_id, 'documentation'):
        print(f"Documentation having id {record_id} already exists")
        return record_id

    vector = self.embeddings.embed_query(doc)
    self.index.upsert(
        vectors=[(record_id, vector, {'doc': doc})],
        namespace='documentation',
    )
    return record_id
|
84 |
+
|
85 |
+
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
    """Embed a question/SQL pair and store it in the 'question_sql' namespace.

    The id is derived from the question only; the embedding covers the
    concatenated question + sql text. Skips the upsert on duplicates.
    """
    record_id = self.generate_hash_id(question) + '_sql'

    if self.check_embedding(record_id, 'question_sql'):
        print(f"Question-SQL pair having id {record_id} already exists")
        return record_id

    vector = self.embeddings.embed_query(question + sql)
    self.index.upsert(
        vectors=[(record_id, vector, {'question': question, 'sql': sql})],
        namespace='question_sql',
    )
    return record_id
|
98 |
+
|
99 |
+
def get_related_ddl(self, question: str, **kwargs) -> list:
    """Return the top_k DDL statements most similar to *question*."""
    query_vector = self.embeddings.embed_query(question)
    matches = self.index.query(
        vector=query_vector,
        top_k=self.top_k,
        namespace='ddl',
        include_metadata=True
    )['matches']
    return [hit['metadata']['ddl'] for hit in matches]
|
108 |
+
|
109 |
+
def get_related_documentation(self, question: str, **kwargs) -> list:
    """Return the top_k documentation snippets most similar to *question*."""
    query_vector = self.embeddings.embed_query(question)
    matches = self.index.query(
        vector=query_vector,
        top_k=self.top_k,
        namespace='documentation',
        include_metadata=True
    )['matches']
    return [hit['metadata']['doc'] for hit in matches]
|
118 |
+
|
119 |
+
def get_similar_question_sql(self, question: str, **kwargs) -> list:
    """Return (question, sql) pairs most similar to *question* (top_k of them)."""
    query_vector = self.embeddings.embed_query(question)
    matches = self.index.query(
        vector=query_vector,
        top_k=self.top_k,
        namespace='question_sql',
        include_metadata=True
    )['matches']
    return [(hit['metadata']['question'], hit['metadata']['sql']) for hit in matches]
|
128 |
+
|
129 |
+
def get_training_data(self, **kwargs) -> pd.DataFrame:
    """Collect all stored training records from every namespace.

    Iterates the 'ddl', 'documentation' and 'question_sql' namespaces and
    gathers up to 10000 records' metadata from each into one DataFrame.

    Returns:
        pd.DataFrame: one row per stored record (columns vary by namespace).
    """
    # NOTE(review): Pinecone's query API normally requires a query vector or
    # id; this vector-less call looks fragile — confirm it runs against the
    # deployed client version.
    list_of_data = []

    namespaces = ['ddl', 'documentation', 'question_sql']

    for namespace in namespaces:

        data = self.index.query(
            top_k=10000,
            namespace=namespace,
            include_metadata=True,
            include_values=False
        )

        for match in data['matches']:
            list_of_data.append(match['metadata'])

    return pd.DataFrame(list_of_data)
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
def remove_training_data(self, id: str, **kwargs) -> bool:
    """Delete a stored training record by its id.

    Args:
        id (str): record id; its suffix ('_ddl', '_sql' or '_doc') identifies
            the namespace it was stored in by the add_* methods.

    Returns:
        bool: True when the id matched a known record kind, False otherwise.

    Fixes:
    - `self.Index` (capital I) raised AttributeError; the attribute is
      `self.index`.
    - the delete namespaces ("_ddl"/"_sql"/"_doc") did not match the
      namespaces the add_* methods store into ('ddl', 'question_sql',
      'documentation'), so nothing was ever deleted.
    """
    if id.endswith("_ddl"):
        self.index.delete(ids=[id], namespace="ddl")
        return True
    if id.endswith("_sql"):
        self.index.delete(ids=[id], namespace="question_sql")
        return True
    if id.endswith("_doc"):
        self.index.delete(ids=[id], namespace="documentation")
        return True
    return False
|
164 |
+
|
165 |
+
def generate_embedding(self, text, **kwargs):
    """Generate an embedding for *text* (VannaBase hook).

    Left unimplemented: the add_*/get_* methods in this class call
    `self.embeddings.embed_query` directly, so this hook appears unused
    here — returns None.
    """
    # Implement the method here
    pass
|
168 |
+
|
169 |
+
|
170 |
+
def get_sql_prompt(
|
171 |
+
self,
|
172 |
+
initial_prompt : str,
|
173 |
+
question: str,
|
174 |
+
question_sql_list: list,
|
175 |
+
ddl_list: list,
|
176 |
+
doc_list: list,
|
177 |
+
**kwargs,
|
178 |
+
):
|
179 |
+
"""
|
180 |
+
Example:
|
181 |
+
```python
|
182 |
+
vn.get_sql_prompt(
|
183 |
+
question="What are the top 10 customers by sales?",
|
184 |
+
question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
|
185 |
+
ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
|
186 |
+
doc_list=["The customers table contains information about customers and their sales."],
|
187 |
+
)
|
188 |
+
|
189 |
+
```
|
190 |
+
|
191 |
+
This method is used to generate a prompt for the LLM to generate SQL.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
question (str): The question to generate SQL for.
|
195 |
+
question_sql_list (list): A list of questions and their corresponding SQL statements.
|
196 |
+
ddl_list (list): A list of DDL statements.
|
197 |
+
doc_list (list): A list of documentation.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
any: The prompt for the LLM to generate SQL.
|
201 |
+
"""
|
202 |
+
|
203 |
+
if initial_prompt is None:
|
204 |
+
initial_prompt = f"You are a {self.dialect} expert. " + \
|
205 |
+
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
206 |
+
|
207 |
+
initial_prompt = self.add_ddl_to_prompt(
|
208 |
+
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
209 |
+
)
|
210 |
+
|
211 |
+
if self.static_documentation != "":
|
212 |
+
doc_list.append(self.static_documentation)
|
213 |
+
|
214 |
+
initial_prompt = self.add_documentation_to_prompt(
|
215 |
+
initial_prompt, doc_list, max_tokens=self.max_tokens
|
216 |
+
)
|
217 |
+
|
218 |
+
# initial_prompt = self.add_sql_to_prompt(
|
219 |
+
# initial_prompt, question_sql_list, max_tokens=self.max_tokens
|
220 |
+
# )
|
221 |
+
|
222 |
+
|
223 |
+
initial_prompt += (
|
224 |
+
"===Response Guidelines \n"
|
225 |
+
"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
226 |
+
"2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
|
227 |
+
"3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
|
228 |
+
"4. Please use the most relevant table(s). \n"
|
229 |
+
"5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
230 |
+
f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
231 |
+
f"7. Add a description of the table in the result of the sql query, if relevant. \n"
|
232 |
+
"8 Make sure to include the relevant KPI in the SQL query. The query should return impactfull data \n"
|
233 |
+
# f"8. If a set of latitude,longitude is provided, make a intermediate query to find the nearest value in the table and replace the coordinates in the sql query. \n"
|
234 |
+
# "7. Add a description of the table in the result of the sql query."
|
235 |
+
# "7. If the question is about a specific latitude, longitude, query an interval of 0.3 and keep only the first set of coordinate. \n"
|
236 |
+
# "7. Table names should be included in the result of the sql query. Use for example Mean_winter_temperature AS table_name in the query \n"
|
237 |
+
)
|
238 |
+
|
239 |
+
|
240 |
+
message_log = [self.system_message(initial_prompt)]
|
241 |
+
|
242 |
+
for example in question_sql_list:
|
243 |
+
if example is None:
|
244 |
+
print("example is None")
|
245 |
+
else:
|
246 |
+
if example is not None and "question" in example and "sql" in example:
|
247 |
+
message_log.append(self.user_message(example["question"]))
|
248 |
+
message_log.append(self.assistant_message(example["sql"]))
|
249 |
+
|
250 |
+
message_log.append(self.user_message(question))
|
251 |
+
|
252 |
+
return message_log
|
253 |
+
|
254 |
+
|
255 |
+
# def get_sql_prompt(
|
256 |
+
# self,
|
257 |
+
# initial_prompt : str,
|
258 |
+
# question: str,
|
259 |
+
# question_sql_list: list,
|
260 |
+
# ddl_list: list,
|
261 |
+
# doc_list: list,
|
262 |
+
# **kwargs,
|
263 |
+
# ):
|
264 |
+
# """
|
265 |
+
# Example:
|
266 |
+
# ```python
|
267 |
+
# vn.get_sql_prompt(
|
268 |
+
# question="What are the top 10 customers by sales?",
|
269 |
+
# question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
|
270 |
+
# ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
|
271 |
+
# doc_list=["The customers table contains information about customers and their sales."],
|
272 |
+
# )
|
273 |
+
|
274 |
+
# ```
|
275 |
+
|
276 |
+
# This method is used to generate a prompt for the LLM to generate SQL.
|
277 |
+
|
278 |
+
# Args:
|
279 |
+
# question (str): The question to generate SQL for.
|
280 |
+
# question_sql_list (list): A list of questions and their corresponding SQL statements.
|
281 |
+
# ddl_list (list): A list of DDL statements.
|
282 |
+
# doc_list (list): A list of documentation.
|
283 |
+
|
284 |
+
# Returns:
|
285 |
+
# any: The prompt for the LLM to generate SQL.
|
286 |
+
# """
|
287 |
+
|
288 |
+
# if initial_prompt is None:
|
289 |
+
# initial_prompt = f"You are a {self.dialect} expert. " + \
|
290 |
+
# "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
291 |
+
|
292 |
+
# initial_prompt = self.add_ddl_to_prompt(
|
293 |
+
# initial_prompt, ddl_list, max_tokens=self.max_tokens
|
294 |
+
# )
|
295 |
+
|
296 |
+
# if self.static_documentation != "":
|
297 |
+
# doc_list.append(self.static_documentation)
|
298 |
+
|
299 |
+
# initial_prompt = self.add_documentation_to_prompt(
|
300 |
+
# initial_prompt, doc_list, max_tokens=self.max_tokens
|
301 |
+
# )
|
302 |
+
|
303 |
+
# initial_prompt += (
|
304 |
+
# "===Response Guidelines \n"
|
305 |
+
# "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
306 |
+
# "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
|
307 |
+
# "3. If the provided context is insufficient, please explain why it can't be generated. \n"
|
308 |
+
# "4. Please use the most relevant table(s). \n"
|
309 |
+
# "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
310 |
+
# f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
311 |
+
# )
|
312 |
+
|
313 |
+
# message_log = [self.system_message(initial_prompt)]
|
314 |
+
|
315 |
+
# for example in question_sql_list:
|
316 |
+
# if example is None:
|
317 |
+
# print("example is None")
|
318 |
+
# else:
|
319 |
+
# if example is not None and "question" in example and "sql" in example:
|
320 |
+
# message_log.append(self.user_message(example["question"]))
|
321 |
+
# message_log.append(self.assistant_message(example["sql"]))
|
322 |
+
|
323 |
+
# message_log.append(self.user_message(question))
|
324 |
+
|
325 |
+
# return message_log
|
requirements.txt
CHANGED
@@ -26,4 +26,4 @@ duckdb==1.2.1
|
|
26 |
openai==1.61.1
|
27 |
pydantic==2.9.2
|
28 |
pydantic-settings==2.2.1
|
29 |
-
geojson==3.2.0
|
|
|
26 |
openai==1.61.1
|
27 |
pydantic==2.9.2
|
28 |
pydantic-settings==2.2.1
|
29 |
+
geojson==3.2.0
|
style.css
CHANGED
@@ -656,11 +656,20 @@ a {
|
|
656 |
/* overflow-y: scroll; */
|
657 |
}
|
658 |
#sql-query{
|
|
|
659 |
max-height: 100%;
|
660 |
}
|
661 |
|
662 |
#sql-query textarea{
|
663 |
min-height: 200px !important;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
664 |
}
|
665 |
|
666 |
#sql-query span{
|
@@ -741,4 +750,4 @@ div#tab-vanna{
|
|
741 |
#example-img-container {
|
742 |
flex-direction: column;
|
743 |
align-items: left;
|
744 |
-
}
|
|
|
656 |
/* overflow-y: scroll; */
|
657 |
}
|
658 |
#sql-query{
|
659 |
+
<<<<<<< HEAD
|
660 |
max-height: 100%;
|
661 |
}
|
662 |
|
663 |
#sql-query textarea{
|
664 |
min-height: 200px !important;
|
665 |
+
=======
|
666 |
+
max-height: 300px;
|
667 |
+
overflow-y:scroll;
|
668 |
+
}
|
669 |
+
|
670 |
+
#sql-query textarea{
|
671 |
+
min-height: 100px !important;
|
672 |
+
>>>>>>> hf-origin/main
|
673 |
}
|
674 |
|
675 |
#sql-query span{
|
|
|
750 |
#example-img-container {
|
751 |
flex-direction: column;
|
752 |
align-items: left;
|
753 |
+
}
|