Merged in feature/talk_to_ipcc (pull request #17)
Browse files- app.py +3 -2
- climateqa/engine/talk_to_data/ipcc/config.py +83 -0
- climateqa/engine/talk_to_data/ipcc/plots.py +193 -0
- climateqa/engine/talk_to_data/ipcc/queries.py +116 -0
- climateqa/engine/talk_to_data/main.py +77 -17
- climateqa/engine/talk_to_data/objects/states.py +7 -35
- climateqa/engine/talk_to_data/workflow/ipcc.py +157 -0
- front/tabs/tab_ipcc.py +289 -0
app.py
CHANGED
@@ -16,6 +16,7 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
|
|
16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
from front.tabs.tab_drias import create_drias_tab
|
|
|
19 |
from front.utils import process_figures
|
20 |
from gradio_modal import Modal
|
21 |
|
@@ -532,8 +533,8 @@ def main_ui():
|
|
532 |
with gr.Tabs():
|
533 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
534 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
535 |
-
create_drias_tab(share_client=share_client, user_id=user_id)
|
536 |
-
|
537 |
create_about_tab()
|
538 |
|
539 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
|
|
16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
from front.tabs.tab_drias import create_drias_tab
|
19 |
+
from front.tabs.tab_ipcc import create_ipcc_tab
|
20 |
from front.utils import process_figures
|
21 |
from gradio_modal import Modal
|
22 |
|
|
|
533 |
with gr.Tabs():
|
534 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
535 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
536 |
+
drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
|
537 |
+
ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
|
538 |
create_about_tab()
|
539 |
|
540 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
climateqa/engine/talk_to_data/ipcc/config.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base location of the IPCC Atlas parquet files.
IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"

# Tables that can be queried (one per climate indicator).
IPCC_TABLES = ["mean_temperature", "total_precipitation"]

# Table name -> name of the column holding the indicator values.
IPCC_INDICATOR_COLUMNS_PER_TABLE = {
    "mean_temperature": "mean_temperature",
    "total_precipitation": "total_precipitation",
}

# Indicator column -> unit shown in plot labels.
IPCC_INDICATOR_TO_UNIT = {
    "mean_temperature": "°C",
    "total_precipitation": "mm/day",
}

# Scenario labels; the evolution plot draws one trace per entry, in this order.
IPCC_SCENARIO = ["historical", "ssp126", "ssp245", "ssp370", "ssp585"]

# No climate model is listed for now.
IPCC_MODELS = []

# Names of the parameters the workflow looks for in the user question.
IPCC_PLOT_PARAMETERS = ["year", "location"]

# Country codes that the SQL builders read from "<code>_macro.parquet" files
# (filtered / joined on admin1) instead of "<code>.parquet".
MACRO_COUNTRIES = [
    "JP", "IN", "MH", "PT", "ID", "SJ",
    "MX", "CN", "GL", "PN", "AR", "AQ",
    "PF", "BR", "SH", "GS", "ZA", "NZ",
    "TF", "CL", "CA", "AU", "US", "RU",
]
|
57 |
+
|
58 |
+
# Markdown shown to the user as the introduction of the "Talk to IPCC" tab.
# Fixes over the previous text: "mediane" -> "median", and the unbalanced
# bold markup "*IPCC - ATLAS**" -> "**IPCC - ATLAS**".
IPCC_UI_TEXT = """
Hi, I'm **Talk to IPCC**, designed to answer your questions using [**IPCC - ATLAS**](https://interactive-atlas.ipcc.ch/regional-information#eyJ0eXBlIjoiQVRMQVMiLCJjb21tb25zIjp7ImxhdCI6OTc3MiwibG5nIjo0MDA2OTIsInpvb20iOjQsInByb2oiOiJFUFNHOjU0MDMwIiwibW9kZSI6ImNvbXBsZXRlX2F0bGFzIn0sInByaW1hcnkiOnsic2NlbmFyaW8iOiJzc3A1ODUiLCJwZXJpb2QiOiIyIiwic2Vhc29uIjoieWVhciIsImRhdGFzZXQiOiJDTUlQNiIsInZhcmlhYmxlIjoidGFzIiwidmFsdWVUeXBlIjoiQU5PTUFMWSIsImhhdGNoaW5nIjoiU0lNUExFIiwicmVnaW9uU2V0IjoiYXI2IiwiYmFzZWxpbmUiOiJwcmVJbmR1c3RyaWFsIiwicmVnaW9uc1NlbGVjdGVkIjpbXX0sInBsb3QiOnsiYWN0aXZlVGFiIjoicGx1bWUiLCJtYXNrIjoibm9uZSIsInNjYXR0ZXJZTWFnIjpudWxsLCJzY2F0dGVyWVZhciI6bnVsbCwic2hvd2luZyI6ZmFsc2V9fQ==) data.
I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.

You can ask me anything about these climate indicators: **temperature** or **precipitation**.
You can specify **location** and/or **year**.
By default, we take the **median of each climate model**.

For example, you can ask:
- What will the temperature be like in Paris?
- What will be the total rainfall in the USA in 2030?
- How will the average temperature evolve in China ?

**Example of indicators in the data**:
- Mean temperature
- Total precipitation

⚠️ **Limitations**:
- You can't ask anything that isn't related to **IPCC - ATLAS** data.
- You can not ask about **several locations at the same time**.
- If you specify a year **before 1850 or over 2100**, there will be **no data**.
- You **cannot compare two models**.

🛈 **Information**
Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
"""
|
climateqa/engine/talk_to_data/ipcc/plots.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
from plotly.graph_objects import Figure
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
import pandas as pd
|
5 |
+
import geojson
|
6 |
+
|
7 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
|
8 |
+
from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
|
9 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
10 |
+
|
11 |
+
def plot_indicator_evolution_at_location_historical_and_projections(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """
    Build the plotting function for the yearly evolution of an indicator at one location.

    The returned function draws one line per entry of IPCC_SCENARIO. Every
    non-historical line is prefixed with the last historical point (when one
    exists) so that projections visually continue the historical curve.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the dataframe column holding the indicator.
            - location (str): Location name, only used in the figure title.

    Returns:
        Callable[[pd.DataFrame], Figure]: Function taking a DataFrame with the
        columns 'year', 'scenario' and the indicator column, and returning the
        Plotly line figure.
    """
    indicator = params["indicator_column"]
    location = params["location"]
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
    indicator_label = " ".join(map(str.capitalize, indicator.split("_")))

    def plot_data(df: pd.DataFrame) -> Figure:
        ordered = df.sort_values(by='year')
        # One (year, value, scenario) triple per row, in chronological order.
        rows = list(zip(
            ordered['year'].astype(int).tolist(),
            ordered[indicator].astype(float).tolist(),
            ordered['scenario'].astype(str).tolist(),
        ))

        # Last historical point, used as the common starting point of projections.
        historical_points = [(yr, val) for yr, val, scen in rows if scen == 'historical']
        if historical_points:
            anchor_year, anchor_value = historical_points[-1]
        else:
            anchor_year, anchor_value = None, None

        fig = go.Figure()
        for scenario in IPCC_SCENARIO:
            xs = [yr for yr, _, scen in rows if scen == scenario]
            ys = [val for _, val, scen in rows if scen == scenario]
            is_projection = scenario != 'historical'
            if is_projection and anchor_value is not None:
                xs.insert(0, anchor_year)
                ys.insert(0, anchor_value)
            trace = go.Scatter(x=xs, y=ys, mode='lines', name=scenario)
            fig.add_trace(trace)

        fig.update_layout(
            title=f'Yearly Evolution of {indicator_label} in {location} (Historical + SSP Scenarios)',
            xaxis_title='Year',
            yaxis_title=f'{indicator_label} ({unit})',
            legend_title='Scenario',
            height=800,
        )
        return fig

    return plot_data
|
68 |
+
|
69 |
+
# Plot descriptor: ties the yearly-evolution figure factory to the SQL builder
# that fetches its data, together with a description of the questions it answers.
indicator_evolution_at_location_historical_and_projections: Plot = {
    "name": "Indicator Evolution at Location (Historical + Projections)",
    "description": (
        "Shows how a climate indicator (e.g., rainfall, temperature) "
        "changes over time at a specific location, "
        "including historical data and future projections. "
        "Useful for questions about the value or trend of an indicator "
        "at a location for any year, "
        "such as 'What will be the total rainfall in China in 2050?' "
        "or 'How does rainfall evolve in China over time?'. "
        "Parameters: indicator_column (the climate variable), "
        "location (e.g., country, city)."
    ),
    "params": ["indicator_column", "location"],
    "plot_function": plot_indicator_evolution_at_location_historical_and_projections,
    "sql_query": indicator_per_year_at_location_query,
    "short_name": "Indicator Evolution",
}
|
83 |
+
|
84 |
+
def plot_choropleth_map_of_country_indicator_for_specific_year(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """
    Returns a function that generates a choropleth map (heatmap) showing the spatial distribution
    of a climate indicator (e.g., temperature, rainfall) for a specific year.

    Each row of the dataframe is drawn as a square cell of side 1 (in coordinate
    units) centred on the row's (longitude, latitude) and coloured by the indicator value.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the climate indicator column to plot.
            - year (str or int, optional): Year shown in the title. A missing, None or
              empty value falls back to 2050, like the default of the SQL builder.
            - country_name (str): Name of the country (title only).
            - location (str): Location (country or region) for the map (title only).

    Returns:
        Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame with the columns
        'latitude', 'longitude' and the indicator column, and returns the Plotly figure.
        An empty DataFrame yields an empty map instead of raising.

    Raises:
        KeyError: If 'indicator_column', 'country_name' or 'location' is missing from params.
    """
    indicator = params["indicator_column"]
    # `or` also covers an explicit None / "" year, which is how
    # indicator_for_given_year_query picks its default; a plain .get default
    # would put "None" in the title while the data is for 2050.
    year = params.get('year') or 2050
    country_name = params['country_name']
    location = params['location']
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")

    # Diverging blue -> white -> red colour scale (does not depend on the data,
    # so it is built once here rather than on every plot call).
    custom_colorscale = [
        [0.0, "rgb(5, 48, 97)"],
        [0.10, "rgb(33, 102, 172)"],
        [0.20, "rgb(67, 147, 195)"],
        [0.30, "rgb(146, 197, 222)"],
        [0.40, "rgb(209, 229, 240)"],
        [0.50, "rgb(247, 247, 247)"],
        [0.60, "rgb(253, 219, 199)"],
        [0.75, "rgb(244, 165, 130)"],
        [0.85, "rgb(214, 96, 77)"],
        [0.90, "rgb(178, 24, 43)"],
        [1.0, "rgb(103, 0, 31)"],
    ]

    def plot_data(df: pd.DataFrame) -> Figure:
        indicators = df[indicator].astype(float).tolist()
        latitudes = df["latitude"].astype(float).tolist()
        longitudes = df["longitude"].astype(float).tolist()

        # One closed square polygon per data point; the feature id is the row
        # position so that `locations` below can reference it.
        features = [
            geojson.Feature(
                geometry=geojson.Polygon([[
                    [lon - 0.5, lat - 0.5],
                    [lon + 0.5, lat - 0.5],
                    [lon + 0.5, lat + 0.5],
                    [lon - 0.5, lat + 0.5],
                    [lon - 0.5, lat - 0.5],
                ]]),
                properties={"value": val},
                id=str(idx),
            )
            for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators))
        ]
        geojson_data = geojson.FeatureCollection(features)

        fig = go.Figure(go.Choroplethmapbox(
            geojson=geojson_data,
            locations=[str(idx) for idx in range(len(indicators))],
            featureidkey="id",
            z=indicators,
            colorscale=custom_colorscale,
            # default=None keeps an empty dataframe from raising ValueError;
            # None lets Plotly pick the bounds itself.
            zmin=min(indicators, default=None),
            zmax=max(indicators, default=None),
            marker_opacity=0.7,
            marker_line_width=0,
            colorbar_title=f"{indicator_label} ({unit})",
            # Hover text showing the indicator value of each cell
            text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],
            hoverinfo="text",
        ))

        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox_zoom=3,
            height=800,
            # Centre on the middle row of the data (or (0, 0) when there is none).
            mapbox_center={
                "lat": latitudes[len(latitudes)//2] if latitudes else 0,
                "lon": longitudes[len(longitudes)//2] if longitudes else 0,
            },
            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
            title=f"{indicator_label} in {year} in {location} ({country_name})",
        )
        return fig

    return plot_data
|
174 |
+
|
175 |
+
# Plot descriptor: ties the choropleth figure factory to the SQL builder that
# fetches its data, together with a description of the questions it answers.
choropleth_map_of_country_indicator_for_specific_year: Plot = {
    "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
    "description": (
        "Displays a map showing the spatial distribution of a climate indicator "
        "(e.g., rainfall, temperature) "
        "across all regions of a country for a specific year. "
        "Can answer questions about the value of an indicator "
        "in a country or region for a given year, "
        "such as 'What will be the total rainfall in China in 2050?' "
        "or 'How is rainfall distributed across China in 2050?'. "
        "Parameters: indicator_column (the climate variable), "
        "year, location (country name)."
    ),
    "params": ["indicator_column", "year", "location"],
    "plot_function": plot_choropleth_map_of_country_indicator_for_specific_year,
    "sql_query": indicator_for_given_year_query,
    "short_name": "Choropleth Map",
}

# Every plot available to the Talk to IPCC workflow.
IPCC_PLOTS = [
    indicator_evolution_at_location_historical_and_projections,
    choropleth_map_of_country_indicator_for_specific_year,
]
|
climateqa/engine/talk_to_data/ipcc/queries.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypedDict, Optional
|
2 |
+
|
3 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_DATASET_URL, MACRO_COUNTRIES
|
4 |
+
|
5 |
+
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
    """Inputs of indicator_per_year_at_location_query (every key is optional)."""

    # Name of the climate indicator column.
    indicator_column: str
    # Latitude of the location.
    latitude: str
    # Longitude of the location.
    longitude: str
    # Country code.
    country_code: str
    # Administrative region (optional).
    admin1: Optional[str]
|
21 |
+
|
22 |
+
def indicator_per_year_at_location_query(
    table: str, params: IndicatorPerYearAtLocationQueryParams
) -> str:
    """
    Builds an SQL query to get the evolution of an indicator per year at a specific location.

    Two layouts are handled:
      - countries listed in MACRO_COUNTRIES are read from "<country_code>_macro.parquet"
        and filtered on the administrative region (admin1);
      - every other country is read from "<country_code>.parquet", filtered on the exact
        latitude / longitude, and averaged per (scenario, year).
    Only years from 1950 onwards are returned.

    Args:
        table (str): SQL table of the indicator.
        params (IndicatorPerYearAtLocationQueryParams): Dictionary with the required params for the query.

    Returns:
        str: The SQL query string, or an empty string if required parameters are missing
        (indicator_column, latitude, longitude, country_code, plus admin1 for a macro country).
    """
    indicator_column = params.get("indicator_column")
    latitude = params.get("latitude")
    longitude = params.get("longitude")
    country_code = params.get("country_code")
    admin1 = params.get("admin1")

    # Only None / "" count as missing: a plain truthiness test would also
    # reject a legitimate numeric 0 latitude or longitude.
    required = (indicator_column, latitude, longitude, country_code)
    if any(value is None or value == "" for value in required):
        return ""

    if country_code in MACRO_COUNTRIES:
        # Without a region the filter would silently become admin1 = 'None'.
        if not admin1:
            return ""
        # Double single quotes so a region name containing an apostrophe stays
        # inside the SQL string literal.
        admin1_literal = str(admin1).replace("'", "''")
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
            SELECT year, scenario, {indicator_column}
            FROM {table_path}
            WHERE admin1 = '{admin1_literal}' AND year >= 1950
            ORDER BY year, scenario
        """
    else:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
        sql_query = f"""
            SELECT year, scenario, AVG({indicator_column}) AS {indicator_column}
            FROM {table_path}
            WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
            GROUP BY scenario, year
            ORDER BY year, scenario
        """
    return sql_query.strip()
|
62 |
+
|
63 |
+
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
    """Inputs of indicator_for_given_year_query (every key is optional)."""

    # The column name for the climate indicator.
    indicator_column: str
    # The year to query.
    year: str
    # The country code.
    country_code: str
|
75 |
+
|
76 |
+
def indicator_for_given_year_query(
    table: str, params: IndicatorForGivenYearQueryParams
) -> str:
    """
    Builds an SQL query to get the values of an indicator with their latitudes, longitudes,
    and scenarios for a given year.

    Two layouts are handled:
      - countries listed in MACRO_COUNTRIES are read from "<country_code>_macro.parquet"
        and joined with "coordinates.parquet" on (admin1, country_code) to get coordinates;
      - every other country is read from "<country_code>.parquet" and averaged per
        (latitude, longitude, scenario).

    Args:
        table (str): SQL table of the indicator.
        params (IndicatorForGivenYearQueryParams): Dictionary with the required params for the query.
            A missing or empty year defaults to 2050.

    Returns:
        str: The SQL query string, or an empty string if indicator_column or country_code
        is missing, or if the year cannot be read as a number.
    """
    indicator_column = params.get("indicator_column")
    country_code = params.get("country_code")

    if not indicator_column or not country_code:
        return ""

    # The year is interpolated into the SQL text, so force it to a plain
    # integer instead of trusting whatever string was extracted upstream.
    try:
        year = int(float(str(params.get("year") or 2050)))
    except (ValueError, OverflowError):
        return ""

    if country_code in MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
            SELECT {indicator_column}, c.latitude, c.longitude, c.admin1, scenario
            FROM {table_path} AS t
            RIGHT JOIN '{IPCC_DATASET_URL}/coordinates.parquet' AS c
            ON c.admin1 = t.admin1 AND c.country_code = t.country_code
            WHERE year = {year}
            ORDER BY latitude, longitude, scenario
        """
    else:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
        sql_query = f"""
            SELECT AVG({indicator_column}) AS {indicator_column}, latitude, longitude, scenario
            FROM {table_path}
            WHERE year = {year}
            GROUP BY latitude, longitude, scenario
            ORDER BY latitude, longitude, scenario
        """
    return sql_query.strip()
|
climateqa/engine/talk_to_data/main.py
CHANGED
@@ -1,10 +1,13 @@
|
|
|
|
|
|
1 |
from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
|
2 |
from climateqa.engine.llm import get_llm
|
|
|
3 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
4 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
5 |
import ast
|
6 |
|
7 |
-
async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tuple:
|
8 |
"""Main function to process a DRIAS query and return results.
|
9 |
|
10 |
This function orchestrates the DRIAS workflow, processing a user query to generate
|
@@ -31,23 +34,80 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
|
|
31 |
sql_queries = []
|
32 |
result_dataframes = []
|
33 |
figures = []
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
if "error" in final_state and final_state["error"] != "":
|
50 |
-
|
|
|
51 |
|
52 |
sql_query = sql_queries[index_state]
|
53 |
dataframe = result_dataframes[index_state]
|
@@ -55,4 +115,4 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
|
|
55 |
|
56 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
57 |
|
58 |
-
return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state,
|
|
|
1 |
+
from operator import index
|
2 |
+
from duckdb import sql
|
3 |
from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
|
4 |
from climateqa.engine.llm import get_llm
|
5 |
+
from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
|
6 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
7 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
8 |
import ast
|
9 |
|
10 |
+
async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
|
11 |
"""Main function to process a DRIAS query and return results.
|
12 |
|
13 |
This function orchestrates the DRIAS workflow, processing a user query to generate
|
|
|
34 |
sql_queries = []
|
35 |
result_dataframes = []
|
36 |
figures = []
|
37 |
+
plot_title_list = []
|
38 |
+
|
39 |
+
|
40 |
+
for output_title, output in final_state['outputs'].items():
|
41 |
+
if output['status'] == 'OK':
|
42 |
+
if output['table'] is not None:
|
43 |
+
plot_title_list.append(output_title)
|
44 |
+
if output['sql_query'] is not None:
|
45 |
+
sql_queries.append(output['sql_query'])
|
46 |
+
|
47 |
+
if output['dataframe'] is not None:
|
48 |
+
result_dataframes.append(output['dataframe'])
|
49 |
+
if output['figure'] is not None:
|
50 |
+
figures.append(output['figure'])
|
51 |
+
|
52 |
+
if "error" in final_state and final_state["error"] != "":
|
53 |
+
# No Sql query, no dataframe, no figure, empty sql queries list, empty result dataframes list, empty figures list, index state = 0, empty table list, error message
|
54 |
+
return None, None, None, [], [], [], 0, [], final_state["error"]
|
55 |
+
|
56 |
+
sql_query = sql_queries[index_state]
|
57 |
+
dataframe = result_dataframes[index_state]
|
58 |
+
figure = figures[index_state](dataframe)
|
59 |
+
|
60 |
+
|
61 |
+
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
62 |
+
|
63 |
+
return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, plot_title_list, ""
|
64 |
+
|
65 |
+
|
66 |
+
async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
|
67 |
+
"""Main function to process an IPCC query and return results.
|
68 |
+
|
69 |
+
This function orchestrates the IPCC workflow, processing a user query to generate
|
70 |
+
SQL queries, dataframes, and visualizations. It handles multiple results and allows
|
71 |
+
pagination through them.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
query (str): The user's question about climate data
|
75 |
+
index_state (int, optional): The index of the result to return. Defaults to 0.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
tuple: A tuple containing:
|
79 |
+
- sql_query (str): The SQL query used
|
80 |
+
- dataframe (pd.DataFrame): The resulting data
|
81 |
+
- figure (Callable): Function to generate the visualization
|
82 |
+
- sql_queries (list): All generated SQL queries
|
83 |
+
- result_dataframes (list): All resulting dataframes
|
84 |
+
- figures (list): All figure generation functions
|
85 |
+
- index_state (int): Current result index
|
86 |
+
- table_list (list): List of table names used
|
87 |
+
- error (str): Error message if any
|
88 |
+
"""
|
89 |
+
final_state = await ipcc_workflow(query)
|
90 |
+
sql_queries = []
|
91 |
+
result_dataframes = []
|
92 |
+
figures = []
|
93 |
+
plot_title_list = []
|
94 |
+
|
95 |
+
|
96 |
+
for output_title, output in final_state['outputs'].items():
|
97 |
+
if output['status'] == 'OK':
|
98 |
+
if output['table'] is not None:
|
99 |
+
plot_title_list.append(output_title)
|
100 |
+
if output['sql_query'] is not None:
|
101 |
+
sql_queries.append(output['sql_query'])
|
102 |
+
|
103 |
+
if output['dataframe'] is not None:
|
104 |
+
result_dataframes.append(output['dataframe'])
|
105 |
+
if output['figure'] is not None:
|
106 |
+
figures.append(output['figure'])
|
107 |
|
108 |
if "error" in final_state and final_state["error"] != "":
|
109 |
+
# No Sql query, no dataframe, no figure, empty sql queries list, empty result dataframes list, empty figures list, index state = 0, empty table list, error message
|
110 |
+
return None, None, None, [], [], [], 0, [], final_state["error"]
|
111 |
|
112 |
sql_query = sql_queries[index_state]
|
113 |
dataframe = result_dataframes[index_state]
|
|
|
115 |
|
116 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
117 |
|
118 |
+
return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, plot_title_list, ""
|
climateqa/engine/talk_to_data/objects/states.py
CHANGED
@@ -1,46 +1,18 @@
|
|
1 |
from typing import Any, Callable, Optional, TypedDict
|
2 |
from plotly.graph_objects import Figure
|
3 |
import pandas as pd
|
|
|
4 |
|
5 |
-
class
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
data processing workflow, including its name, parameters, SQL query, and results.
|
10 |
-
|
11 |
-
Attributes:
|
12 |
-
table_name (str): The name of the table in the database
|
13 |
-
params (dict[str, Any]): Parameters used for querying the table
|
14 |
-
sql_query (str, optional): The SQL query used to fetch data
|
15 |
-
dataframe (pd.DataFrame | None, optional): The resulting data
|
16 |
-
figure (Callable[..., Figure], optional): Function to generate visualization
|
17 |
-
status (str): The current status of the table processing ('OK' or 'ERROR')
|
18 |
-
"""
|
19 |
-
table_name: str
|
20 |
-
params: dict[str, Any]
|
21 |
sql_query: Optional[str]
|
22 |
-
dataframe: Optional[pd.DataFrame
|
23 |
figure: Optional[Callable[..., Figure]]
|
24 |
-
status: str
|
25 |
-
|
26 |
-
class PlotState(TypedDict):
|
27 |
-
"""Represents the state of a plot in the DRIAS workflow.
|
28 |
-
|
29 |
-
This class defines the structure for tracking the state of a plot during the
|
30 |
-
data processing workflow, including its name and associated tables.
|
31 |
-
|
32 |
-
Attributes:
|
33 |
-
plot_name (str): The name of the plot
|
34 |
-
tables (list[str]): List of tables used in the plot
|
35 |
-
table_states (dict[str, TableState]): States of the tables used in the plot
|
36 |
-
"""
|
37 |
-
plot_name: str
|
38 |
-
tables: list[str]
|
39 |
-
table_states: dict[str, TableState]
|
40 |
-
|
41 |
class State(TypedDict):
|
42 |
user_input: str
|
43 |
plots: list[str]
|
44 |
-
|
45 |
error: Optional[str]
|
46 |
|
|
|
1 |
from typing import Any, Callable, Optional, TypedDict
|
2 |
from plotly.graph_objects import Figure
|
3 |
import pandas as pd
|
4 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
5 |
|
6 |
+
class TTDOutput(TypedDict):
    """One Talk-to-Data result: a (plot, table) pair with its query, data and figure."""

    # Processing status; the IPCC workflow sets 'OK', 'ERROR' (no SQL query
    # could be built) or 'NO_DATA' (the query returned nothing).
    status: str
    # Plot descriptor providing the SQL builder and the figure factory.
    plot: Plot
    # Name of the table queried for this output.
    table: str
    # SQL query that was built, if any.
    sql_query: Optional[str]
    # Result of the SQL query, if any.
    dataframe: Optional[pd.DataFrame]
    # Callable that turns the dataframe into a Plotly figure, if any.
    figure: Optional[Callable[..., Figure]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
class State(TypedDict):
    """State carried through a Talk-to-Data workflow run."""

    # The user's question.
    user_input: str
    # Names of the plots selected for the question.
    plots: list[str]
    # Results keyed by output title.
    outputs: dict[str, TTDOutput]
    # Error message; the IPCC workflow starts it as an empty string.
    error: Optional[str]
|
18 |
|
climateqa/engine/talk_to_data/workflow/ipcc.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
import asyncio
|
5 |
+
from climateqa.engine.llm import get_llm
|
6 |
+
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
|
7 |
+
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
|
8 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
9 |
+
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
|
10 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS
|
11 |
+
from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS
|
12 |
+
|
13 |
+
# Two directory levels above the current working directory
# (os.path.split(p)[0] is what os.path.dirname(p) returns).
ROOT_PATH = os.path.split(os.path.split(os.getcwd())[0])[0]
|
14 |
+
|
15 |
+
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Run one (plot, table) pair end to end: build its SQL query, execute it,
    and prepare the figure function.

    Note that `params` is modified in place: the table's indicator column,
    when found, is stored under 'indicator_column'.

    Args:
        output_title (str): Title for the output (used as key in outputs dict).
        table (str): The name of the table to process.
        plot (Plot): The plot object containing SQL query and visualization function.
        params (dict[str, Any]): Parameters used for querying the table.

    Returns:
        tuple: (output_title, results dict, errors dict). The results status is
        'ERROR' when no SQL query could be built, 'NO_DATA' when the query
        returned nothing, and 'OK' otherwise. The errors dict holds the flags
        'have_sql_query' and 'have_dataframe'.
    """
    errors = {'have_sql_query': False, 'have_dataframe': False}
    results: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
    }

    # Resolve which column of the table holds the indicator.
    column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
    if column:
        params['indicator_column'] = column

    # Build the SQL query; an empty query means required parameters are missing.
    query = plot['sql_query'](table, params)
    if not query:
        results['status'] = 'ERROR'
        return output_title, results, errors

    results['sql_query'] = query
    errors['have_sql_query'] = True

    # Run the query and keep the dataframe only when it has rows.
    frame = await execute_sql_query(query)
    has_rows = frame is not None and not frame.empty
    if has_rows:
        results['dataframe'] = frame
        errors['have_dataframe'] = True
    else:
        results['status'] = 'NO_DATA'

    # The figure function is prepared in every case, even without data.
    results['figure'] = plot['plot_function'](params)

    return output_title, results, errors
|
73 |
+
|
74 |
+
async def ipcc_workflow(user_input: str) -> State:
    """
    Run the Talk To IPCC pipeline for one user question.

    The steps are: select the relevant plots, select the relevant tables for
    each plot, extract the query parameters, then build/execute the SQL query
    and the figure for every (plot, table) pair concurrently.

    Args:
        user_input (str): The user's question.

    Returns:
        State: The final state. ``state['outputs']`` maps an output title to
        its table, plot, status, SQL query, dataframe and figure, and
        ``state['error']`` holds a message when a stage produced nothing.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")
    selected_plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
    state['plots'] = selected_plots

    # Nothing can be drawn without at least one plot: stop early.
    if not selected_plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # Each flag becomes True as soon as one output reaches that stage.
    errors = {
        'have_relevant_table': False,
        'have_sql_query': False,
        'have_dataframe': False
    }
    outputs = {}

    # One output entry per (plot, relevant table) pair.
    for plot_name in selected_plots:
        plot = next((candidate for candidate in IPCC_PLOTS if candidate['name'] == plot_name), None)
        if plot is None:
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
        if relevant_tables:
            errors['have_relevant_table'] = True

        for table in relevant_tables:
            readable_table = table.capitalize().replace('_', ' ')
            output_title = f"{plot['short_name']} - {readable_table}"
            outputs[output_title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Extract every plot parameter from the question, one after the other.
    params = {}
    for param_name in IPCC_PLOT_PARAMETERS:
        found = await find_param(state, param_name, mode='IPCC')
        if found:
            params.update(found)

    # Each task receives its own copy of the parameters because
    # process_output adds a table-specific entry to the dict it is given.
    pending = [
        process_output(title, entry['table'], entry['plot'], params.copy())
        for title, entry in outputs.items()
    ]
    completed = await asyncio.gather(*pending)

    # Merge the per-task results back into the outputs and the global flags.
    for title, task_results, task_errors in completed:
        outputs[title].update(
            sql_query=task_results['sql_query'],
            dataframe=task_results['dataframe'],
            figure=task_results['figure'],
            status=task_results['status'],
        )
        for flag in ('have_sql_query', 'have_dataframe'):
            errors[flag] |= task_errors[flag]

    state['outputs'] = outputs

    # Report the earliest stage that produced nothing.
    failure_messages = (
        ('have_relevant_table', "There is no relevant table in our database to answer your question"),
        ('have_sql_query', "There is no relevant sql query on our database that can help to answer your question"),
        ('have_dataframe', "There is no data in our table that can answer to your question"),
    )
    for flag, message in failure_messages:
        if not errors[flag]:
            state['error'] = message
            break

    return state
|
front/tabs/tab_ipcc.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from operator import index
|
2 |
+
from random import choices
|
3 |
+
import gradio as gr
|
4 |
+
from typing import TypedDict, List, Optional
|
5 |
+
import pandas as pd
|
6 |
+
import os
|
7 |
+
from climateqa.engine.talk_to_data.main import ask_ipcc
|
8 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_MODELS, IPCC_UI_TEXT
|
9 |
+
|
10 |
+
class ipccUIElements(TypedDict):
    """Typed mapping of the Gradio components created for the IPCC tab."""

    # Tab container, help accordion and example widgets
    tab: gr.Tab
    details_accordion: gr.Accordion
    examples_hidden: gr.Textbox
    examples: gr.Examples
    image_examples: gr.Row

    # Question input and status message
    ipcc_direct_question: gr.Textbox
    result_text: gr.Textbox

    # Selector for the generated figures
    table_names_display: gr.Radio

    # SQL query panel
    query_accordion: gr.Accordion
    ipcc_sql_query: gr.Textbox

    # Chart panel
    chart_accordion: gr.Accordion
    scenario_selection: gr.Dropdown
    ipcc_display: gr.Plot

    # Data panel
    table_accordion: gr.Accordion
    ipcc_table: gr.DataFrame
|
26 |
+
|
27 |
+
|
28 |
+
async def ask_ipcc_query(query: str, index_state: int, user_id: str):
    """Forward the question, the current index and the user id to ``ask_ipcc`` and return its result unchanged."""
    return await ask_ipcc(query, index_state, user_id)
|
31 |
+
|
32 |
+
|
33 |
+
def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
    """
    Compute the visibility updates that follow a query.

    Args:
        sql_queries_state: SQL queries produced by the query.
        dataframes_state: Dataframes produced by the query.
        plots_state: Plot callables produced by the query.
        table_names: Titles of the generated outputs.

    Returns:
        tuple: Five Gradio updates, in the order in which the event chain
        binds them: result text, query accordion, table accordion, chart
        accordion and table-name selector.
    """
    # The result panels are only usable when every piece is present.
    # table_names is part of the check because its first entry becomes the
    # selected value below; an empty list would raise IndexError.
    has_results = bool(
        sql_queries_state and dataframes_state and plots_state and table_names
    )

    if not has_results:
        # At least one piece is missing: show only the result text.
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Show the result panels and preselect the first output.
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(choices=table_names, value=table_names[0], visible=True),
    )
|
52 |
+
|
53 |
+
|
54 |
+
def show_filter_by_scenario(table_names, index_state, dataframes):
    """
    Show the scenario dropdown only for choropleth-map outputs.

    Args:
        table_names: Titles of the generated outputs.
        index_state: Index of the currently selected output.
        dataframes: Dataframes of the generated outputs (same indexing).

    Returns:
        A Gradio update. For an output whose title starts with
        "Choropleth Map" and whose dataframe has a non-empty ``scenario``
        column, the dropdown is shown with the sorted scenarios as choices
        and the first scenario found in the data as value. In every other
        case the dropdown is hidden.
    """
    hidden = gr.update(visible=False)

    # This handler runs at the end of every event chain, including after a
    # query that produced no output, so the selection may not exist.
    if not isinstance(index_state, int) or index_state < 0:
        return hidden
    if not table_names or not dataframes:
        return hidden
    if index_state >= len(table_names) or index_state >= len(dataframes):
        return hidden

    if not table_names[index_state].startswith("Choropleth Map"):
        return hidden

    df = dataframes[index_state]
    # Without scenario values there is nothing to offer in the dropdown.
    if df is None or df.empty or "scenario" not in df.columns:
        return hidden

    scenarios = df["scenario"].unique()
    return gr.update(visible=True, choices=sorted(scenarios), value=scenarios[0])
|
60 |
+
|
61 |
+
def filter_by_scenario(dataframes, figures, index_state, scenario):
    """
    Restrict the selected dataframe to one scenario and redraw its figure.

    Args:
        dataframes: Dataframes of the generated outputs.
        figures: Plot callables, indexed like ``dataframes``; each takes a
            dataframe and returns a figure.
        index_state: Index of the currently selected output.
        scenario: Scenario value to keep.

    Returns:
        tuple: ``(dataframe, figure)``. The figure is ``None`` when the
        dataframe is empty before or after filtering. A dataframe without a
        ``scenario`` column is plotted unfiltered.
    """
    selected = dataframes[index_state]
    if selected.empty:
        return selected, None

    if "scenario" in selected.columns:
        selected = selected[selected["scenario"] == scenario]
        if selected.empty:
            return selected, None

    return selected, figures[index_state](selected)
|
73 |
+
|
74 |
+
|
75 |
+
def display_table_names(table_names, index_state):
    """Wrap every table name in its own single-element list.

    ``index_state`` is accepted but not used.
    """
    rows = []
    for label in table_names:
        rows.append([label])
    return rows
|
80 |
+
|
81 |
+
def on_table_click(selected_label, table_names, sql_queries, dataframes, plots):
    """
    Load the output that matches the label selected in the selector.

    Args:
        selected_label: Title picked by the user; must be in ``table_names``.
        table_names: Titles of the generated outputs.
        sql_queries: SQL queries, indexed like ``table_names``.
        dataframes: Dataframes, indexed like ``table_names``.
        plots: Plot callables, indexed like ``table_names``.

    Returns:
        tuple: ``(sql_query, dataframe, figure, position)`` for the selection.
    """
    position = table_names.index(selected_label)
    draw = plots[position]
    selected_df = dataframes[position]
    rendered = draw(selected_df)
    return sql_queries[position], selected_df, rendered, position
|
90 |
+
|
91 |
+
|
92 |
+
def create_ipcc_ui() -> ipccUIElements:
    """Build the Gradio components of the IPCC tab and return them by name."""
    example_questions = [
        ["What will the temperature be like in Paris?"],
        ["What will be the total rainfall in the USA in 2030?"],
        ["How will the average temperature evolve in China?"],
        ["What will be the average total precipitation in London ?"]
    ]
    # (image path, caption) of the static example visualizations.
    example_images = (
        ("./front/assets/talk_to_ipcc_france_example.png", "Total Precipitation in 2030 in France"),
        ("./front/assets/talk_to_ipcc_new_york_example.png", "Yearly Evolution of Mean Temperature in New York (Historical + SSP Scenarios)"),
        ("./front/assets/talk_to_ipcc_china_example.png", "Mean Temperature in 2050 in China"),
    )

    with gr.Tab("(Beta) Talk to IPCC", elem_id="tab-vanna", id=7) as tab:
        with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
            gr.Markdown(IPCC_UI_TEXT)

        # Clicking an example writes it into this hidden textbox.
        examples_hidden = gr.Textbox(visible=False, elem_id="ipcc-examples-hidden")
        examples = gr.Examples(
            examples=example_questions,
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )

        with gr.Row():
            ipcc_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        # NOTE(review): the nesting of the image row inside this container is
        # reconstructed - confirm against the intended layout.
        with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
            gr.Markdown("### Examples of possible visualizations")

            with gr.Row():
                for image_path, caption in example_images:
                    gr.Image(image_path, label=caption, elem_classes=["example-img"])

        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )
        with gr.Row():
            table_names_display = gr.Radio(
                choices=[],
                label="Relevant figures created",
                interactive=True,
                elem_id="table-names",
                visible=False
            )

        # The result panels start hidden.
        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            ipcc_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            scenario_selection = gr.Dropdown(
                label="Scenario", choices=IPCC_MODELS, value="ALL", interactive=True, visible=False
            )
            ipcc_display = gr.Plot(elem_id="vanna-plot")

        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            ipcc_table = gr.DataFrame([], elem_id="vanna-table")

    return ipccUIElements(
        tab=tab,
        details_accordion=details_accordion,
        examples_hidden=examples_hidden,
        examples=examples,
        image_examples=image_examples,
        ipcc_direct_question=ipcc_direct_question,
        result_text=result_text,
        table_names_display=table_names_display,
        query_accordion=query_accordion,
        ipcc_sql_query=ipcc_sql_query,
        chart_accordion=chart_accordion,
        scenario_selection=scenario_selection,
        ipcc_display=ipcc_display,
        table_accordion=table_accordion,
        ipcc_table=ipcc_table,
    )
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
def setup_ipcc_events(ui_elements: ipccUIElements, share_client=None, user_id=None) -> None:
    """
    Register the event handlers of the IPCC tab.

    Args:
        ui_elements: Components returned by ``create_ipcc_ui``.
        share_client: Accepted but not used by this function.
        user_id: Value stored in a ``gr.State`` and passed to every query.
    """
    # State holders for the results of the last query.
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    user_id_state = gr.State(user_id)

    def then_refresh_scenarios(event):
        """Chain the scenario-dropdown refresh after ``event``."""
        return event.then(
            show_filter_by_scenario,
            inputs=[table_names_list, index_state, dataframes_state],
            outputs=[ui_elements["scenario_selection"]],
        )

    def then_run_query(event, question_box):
        """Chain after ``event``: hide the example images, run the query, show the results."""
        images_hidden = event.then(
            lambda: gr.update(visible=False),
            inputs=None,
            outputs=ui_elements["image_examples"]
        )
        answered = images_hidden.then(
            ask_ipcc_query,
            inputs=[question_box, index_state, user_id_state],
            outputs=[
                ui_elements["ipcc_sql_query"],
                ui_elements["ipcc_table"],
                ui_elements["ipcc_display"],
                sql_queries_state,
                dataframes_state,
                plots_state,
                index_state,
                table_names_list,
                ui_elements["result_text"],
            ],
        )
        displayed = answered.then(
            show_results,
            inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
            outputs=[
                ui_elements["result_text"],
                ui_elements["query_accordion"],
                ui_elements["table_accordion"],
                ui_elements["chart_accordion"],
                ui_elements["table_names_display"],
            ],
        )
        return then_refresh_scenarios(displayed)

    # Example selected: close the help, copy the text into the question box,
    # then run the query on the example text.
    example_picked = ui_elements["examples_hidden"].change(
        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
        inputs=[ui_elements["examples_hidden"]],
        outputs=[ui_elements["details_accordion"], ui_elements["ipcc_direct_question"]]
    )
    then_run_query(example_picked, ui_elements["examples_hidden"])

    # Question submitted: close the help, then run the query on the typed text.
    question_submitted = ui_elements["ipcc_direct_question"].submit(
        lambda: gr.Accordion(open=False),
        inputs=None,
        outputs=[ui_elements["details_accordion"]]
    )
    then_run_query(question_submitted, ui_elements["ipcc_direct_question"])

    # Scenario changed: refilter the data and redraw the figure.
    ui_elements["scenario_selection"].change(
        filter_by_scenario,
        inputs=[dataframes_state, plots_state, index_state, ui_elements["scenario_selection"]],
        outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
    )

    # Another output selected: load it, then refresh the scenario dropdown.
    output_selected = ui_elements["table_names_display"].change(
        fn=on_table_click,
        inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plots_state],
        outputs=[ui_elements["ipcc_sql_query"], ui_elements["ipcc_table"], ui_elements["ipcc_display"], index_state],
    )
    then_refresh_scenarios(output_selected)
|
283 |
+
|
284 |
+
def create_ipcc_tab(share_client=None, user_id=None) -> ipccUIElements:
    """
    Create the IPCC tab with its components and event handlers.

    Args:
        share_client: Forwarded to ``setup_ipcc_events``.
        user_id: Forwarded to ``setup_ipcc_events``.

    Returns:
        ipccUIElements: The components of the tab. They are returned so that
        a caller binding the result of this function receives the components
        instead of ``None``.
    """
    ui_elements = create_ipcc_ui()
    setup_ipcc_events(ui_elements, share_client=share_client, user_id=user_id)
    return ui_elements
|
288 |
+
|
289 |
+
|