Commit
·
45e1dba
1
Parent(s):
c3398f4
Implemented the Talk to IPCC workflow and updated the Talk to Data state object
Browse files- app.py +3 -2
- climateqa/engine/talk_to_data/main.py +77 -17
- climateqa/engine/talk_to_data/objects/states.py +7 -35
- climateqa/engine/talk_to_data/workflow/ipcc.py +157 -0
- front/tabs/tab_ipcc.py +289 -0
app.py
CHANGED
@@ -16,6 +16,7 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
|
|
16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
from front.tabs.tab_drias import create_drias_tab
|
|
|
19 |
from front.utils import process_figures
|
20 |
from gradio_modal import Modal
|
21 |
|
@@ -532,8 +533,8 @@ def main_ui():
|
|
532 |
with gr.Tabs():
|
533 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
534 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
535 |
-
create_drias_tab(share_client=share_client, user_id=user_id)
|
536 |
-
|
537 |
create_about_tab()
|
538 |
|
539 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
|
|
16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
from front.tabs.tab_drias import create_drias_tab
|
19 |
+
from front.tabs.tab_ipcc import create_ipcc_tab
|
20 |
from front.utils import process_figures
|
21 |
from gradio_modal import Modal
|
22 |
|
|
|
533 |
with gr.Tabs():
|
534 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
535 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
536 |
+
drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
|
537 |
+
ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
|
538 |
create_about_tab()
|
539 |
|
540 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
climateqa/engine/talk_to_data/main.py
CHANGED
@@ -1,10 +1,13 @@
|
|
|
|
|
|
1 |
from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
|
2 |
from climateqa.engine.llm import get_llm
|
|
|
3 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
4 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
5 |
import ast
|
6 |
|
7 |
-
async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tuple:
|
8 |
"""Main function to process a DRIAS query and return results.
|
9 |
|
10 |
This function orchestrates the DRIAS workflow, processing a user query to generate
|
@@ -31,23 +34,80 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
|
|
31 |
sql_queries = []
|
32 |
result_dataframes = []
|
33 |
figures = []
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
if "error" in final_state and final_state["error"] != "":
|
50 |
-
|
|
|
51 |
|
52 |
sql_query = sql_queries[index_state]
|
53 |
dataframe = result_dataframes[index_state]
|
@@ -55,4 +115,4 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
|
|
55 |
|
56 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
57 |
|
58 |
-
return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state,
|
|
|
1 |
+
from operator import index
|
2 |
+
from duckdb import sql
|
3 |
from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
|
4 |
from climateqa.engine.llm import get_llm
|
5 |
+
from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
|
6 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
7 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
8 |
import ast
|
9 |
|
10 |
+
async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
|
11 |
"""Main function to process a DRIAS query and return results.
|
12 |
|
13 |
This function orchestrates the DRIAS workflow, processing a user query to generate
|
|
|
34 |
sql_queries = []
|
35 |
result_dataframes = []
|
36 |
figures = []
|
37 |
+
plot_title_list = []
|
38 |
+
|
39 |
+
|
40 |
+
for output_title, output in final_state['outputs'].items():
|
41 |
+
if output['status'] == 'OK':
|
42 |
+
if output['table'] is not None:
|
43 |
+
plot_title_list.append(output_title)
|
44 |
+
if output['sql_query'] is not None:
|
45 |
+
sql_queries.append(output['sql_query'])
|
46 |
+
|
47 |
+
if output['dataframe'] is not None:
|
48 |
+
result_dataframes.append(output['dataframe'])
|
49 |
+
if output['figure'] is not None:
|
50 |
+
figures.append(output['figure'])
|
51 |
+
|
52 |
+
if "error" in final_state and final_state["error"] != "":
|
53 |
+
# No Sql query, no dataframe, no figure, empty sql queries list, empty result dataframes list, empty figures list, index state = 0, empty table list, error message
|
54 |
+
return None, None, None, [], [], [], 0, [], final_state["error"]
|
55 |
+
|
56 |
+
sql_query = sql_queries[index_state]
|
57 |
+
dataframe = result_dataframes[index_state]
|
58 |
+
figure = figures[index_state](dataframe)
|
59 |
+
|
60 |
+
|
61 |
+
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
62 |
+
|
63 |
+
return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, plot_title_list, ""
|
64 |
+
|
65 |
+
|
66 |
+
async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
|
67 |
+
"""Main function to process a DRIAS query and return results.
|
68 |
+
|
69 |
+
This function orchestrates the DRIAS workflow, processing a user query to generate
|
70 |
+
SQL queries, dataframes, and visualizations. It handles multiple results and allows
|
71 |
+
pagination through them.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
query (str): The user's question about climate data
|
75 |
+
index_state (int, optional): The index of the result to return. Defaults to 0.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
tuple: A tuple containing:
|
79 |
+
- sql_query (str): The SQL query used
|
80 |
+
- dataframe (pd.DataFrame): The resulting data
|
81 |
+
- figure (Callable): Function to generate the visualization
|
82 |
+
- sql_queries (list): All generated SQL queries
|
83 |
+
- result_dataframes (list): All resulting dataframes
|
84 |
+
- figures (list): All figure generation functions
|
85 |
+
- index_state (int): Current result index
|
86 |
+
- table_list (list): List of table names used
|
87 |
+
- error (str): Error message if any
|
88 |
+
"""
|
89 |
+
final_state = await ipcc_workflow(query)
|
90 |
+
sql_queries = []
|
91 |
+
result_dataframes = []
|
92 |
+
figures = []
|
93 |
+
plot_title_list = []
|
94 |
+
|
95 |
+
|
96 |
+
for output_title, output in final_state['outputs'].items():
|
97 |
+
if output['status'] == 'OK':
|
98 |
+
if output['table'] is not None:
|
99 |
+
plot_title_list.append(output_title)
|
100 |
+
if output['sql_query'] is not None:
|
101 |
+
sql_queries.append(output['sql_query'])
|
102 |
+
|
103 |
+
if output['dataframe'] is not None:
|
104 |
+
result_dataframes.append(output['dataframe'])
|
105 |
+
if output['figure'] is not None:
|
106 |
+
figures.append(output['figure'])
|
107 |
|
108 |
if "error" in final_state and final_state["error"] != "":
|
109 |
+
# No Sql query, no dataframe, no figure, empty sql queries list, empty result dataframes list, empty figures list, index state = 0, empty table list, error message
|
110 |
+
return None, None, None, [], [], [], 0, [], final_state["error"]
|
111 |
|
112 |
sql_query = sql_queries[index_state]
|
113 |
dataframe = result_dataframes[index_state]
|
|
|
115 |
|
116 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
117 |
|
118 |
+
return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, plot_title_list, ""
|
climateqa/engine/talk_to_data/objects/states.py
CHANGED
@@ -1,46 +1,18 @@
|
|
1 |
from typing import Any, Callable, Optional, TypedDict
|
2 |
from plotly.graph_objects import Figure
|
3 |
import pandas as pd
|
|
|
4 |
|
5 |
-
class
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
data processing workflow, including its name, parameters, SQL query, and results.
|
10 |
-
|
11 |
-
Attributes:
|
12 |
-
table_name (str): The name of the table in the database
|
13 |
-
params (dict[str, Any]): Parameters used for querying the table
|
14 |
-
sql_query (str, optional): The SQL query used to fetch data
|
15 |
-
dataframe (pd.DataFrame | None, optional): The resulting data
|
16 |
-
figure (Callable[..., Figure], optional): Function to generate visualization
|
17 |
-
status (str): The current status of the table processing ('OK' or 'ERROR')
|
18 |
-
"""
|
19 |
-
table_name: str
|
20 |
-
params: dict[str, Any]
|
21 |
sql_query: Optional[str]
|
22 |
-
dataframe: Optional[pd.DataFrame
|
23 |
figure: Optional[Callable[..., Figure]]
|
24 |
-
status: str
|
25 |
-
|
26 |
-
class PlotState(TypedDict):
|
27 |
-
"""Represents the state of a plot in the DRIAS workflow.
|
28 |
-
|
29 |
-
This class defines the structure for tracking the state of a plot during the
|
30 |
-
data processing workflow, including its name and associated tables.
|
31 |
-
|
32 |
-
Attributes:
|
33 |
-
plot_name (str): The name of the plot
|
34 |
-
tables (list[str]): List of tables used in the plot
|
35 |
-
table_states (dict[str, TableState]): States of the tables used in the plot
|
36 |
-
"""
|
37 |
-
plot_name: str
|
38 |
-
tables: list[str]
|
39 |
-
table_states: dict[str, TableState]
|
40 |
-
|
41 |
class State(TypedDict):
|
42 |
user_input: str
|
43 |
plots: list[str]
|
44 |
-
|
45 |
error: Optional[str]
|
46 |
|
|
|
1 |
from typing import Any, Callable, Optional, TypedDict
|
2 |
from plotly.graph_objects import Figure
|
3 |
import pandas as pd
|
4 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
5 |
|
6 |
+
class TTDOutput(TypedDict):
    """Result of processing one (plot, table) pair in a Talk to Data workflow.

    One entry is stored per output title in ``State['outputs']``.
    """
    # Processing status; the IPCC workflow writes 'OK', 'ERROR' or 'NO_DATA'.
    status: str
    # Plot definition that produced this output.
    plot: Plot
    # Name of the table the output was computed from.
    table: str
    # SQL query built for this output, or None when no query could be built.
    sql_query: Optional[str]
    # Query result, or None when the query returned nothing.
    dataframe: Optional[pd.DataFrame]
    # Callable producing the Plotly figure for this output, or None if not built.
    figure: Optional[Callable[..., Figure]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
class State(TypedDict):
    """State carried through a Talk to Data workflow, from question to results."""
    # The question typed by the user.
    user_input: str
    # Names of the plots selected as relevant for the question.
    plots: list[str]
    # Per-output results, keyed by output title.
    outputs: dict[str, TTDOutput]
    # Error message for the user; the IPCC workflow initialises it to ''.
    error: Optional[str]
|
18 |
|
climateqa/engine/talk_to_data/workflow/ipcc.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
import asyncio
|
5 |
+
from climateqa.engine.llm import get_llm
|
6 |
+
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
|
7 |
+
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
|
8 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
9 |
+
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
|
10 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS
|
11 |
+
from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS
|
12 |
+
|
13 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
14 |
+
|
15 |
+
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Build the SQL query for one (plot, table) pair, run it, and prepare the figure.

    Args:
        output_title (str): Title of the output; returned unchanged so the caller
            can match the result back to its entry.
        table (str): Name of the table to query.
        plot (Plot): Plot definition providing the ``sql_query`` builder and the
            ``plot_function`` factory.
        params (dict[str, Any]): Query parameters. The indicator column, when one
            is found for ``table``, is written into this dict.

    Returns:
        tuple: ``(output_title, output, flags)`` where ``output`` is the filled
        TTDOutput and ``flags`` records whether a SQL query and a non-empty
        dataframe were obtained.
    """
    flags = {
        'have_sql_query': False,
        'have_dataframe': False
    }
    output: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None
    }

    # The indicator column depends on the table, so it is resolved per output.
    column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
    if column:
        params['indicator_column'] = column

    query = plot['sql_query'](table, params)
    if not query:
        # Without a query there is nothing to execute or draw.
        output['status'] = 'ERROR'
        return output_title, output, flags

    output['sql_query'] = query
    flags['have_sql_query'] = True

    frame = await execute_sql_query(query)
    has_rows = frame is not None and not frame.empty
    if has_rows:
        output['dataframe'] = frame
        flags['have_dataframe'] = True
    else:
        output['status'] = 'NO_DATA'

    # The figure factory is built regardless of whether any data came back.
    output['figure'] = plot['plot_function'](params)

    return output_title, output, flags
|
73 |
+
|
74 |
+
async def ipcc_workflow(user_input: str) -> State:
    """
    Run the full Talk To IPCC pipeline for one user question.

    The steps are: select relevant plots, select relevant tables per plot,
    extract the query parameters, then process every (plot, table) output
    concurrently.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state holding the selected plots, the per-output results
        and an error message when a stage produced nothing.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")
    selected_plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
    state['plots'] = selected_plots

    if not selected_plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # Each flag turns True as soon as one output clears that stage.
    flags = {
        'have_relevant_table': False,
        'have_sql_query': False,
        'have_dataframe': False
    }
    outputs = {}

    # Register one pending output per (plot, relevant table) pair.
    for plot_name in selected_plots:
        plot = next(
            (candidate for candidate in IPCC_PLOTS if candidate['name'] == plot_name),
            None,
        )
        if plot is None:
            continue

        tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
        if tables:
            flags['have_relevant_table'] = True

        for table in tables:
            readable_table = ' '.join(table.capitalize().split('_'))
            title = f"{plot['short_name']} - {readable_table}"
            outputs[title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Extract every parameter once; they are shared by all outputs.
    params = {}
    for param_name in IPCC_PLOT_PARAMETERS:
        found = await find_param(state, param_name, mode='IPCC')
        if found:
            params.update(found)

    # Each task gets its own copy because process_output writes into params.
    gathered = await asyncio.gather(*[
        process_output(title, pending['table'], pending['plot'], params.copy())
        for title, pending in outputs.items()
    ])

    # Merge task results back into the registered outputs.
    for title, task_output, task_flags in gathered:
        entry = outputs[title]
        for key in ('sql_query', 'dataframe', 'figure', 'status'):
            entry[key] = task_output[key]
        flags['have_sql_query'] |= task_flags['have_sql_query']
        flags['have_dataframe'] |= task_flags['have_dataframe']

    state['outputs'] = outputs

    # Report the earliest stage that no output managed to clear.
    stage_messages = (
        ('have_relevant_table', "There is no relevant table in our database to answer your question"),
        ('have_sql_query', "There is no relevant sql query on our database that can help to answer your question"),
        ('have_dataframe', "There is no data in our table that can answer to your question"),
    )
    for flag_name, message in stage_messages:
        if not flags[flag_name]:
            state['error'] = message
            break

    return state
|
front/tabs/tab_ipcc.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from operator import index
|
2 |
+
from random import choices
|
3 |
+
import gradio as gr
|
4 |
+
from typing import TypedDict, List, Optional
|
5 |
+
import pandas as pd
|
6 |
+
import os
|
7 |
+
from climateqa.engine.talk_to_data.main import ask_ipcc
|
8 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_MODELS, IPCC_UI_TEXT
|
9 |
+
|
10 |
+
class ipccUIElements(TypedDict):
    """Gradio components of the Talk to IPCC tab, keyed by role."""
    # Tab container and the "How to use?" accordion.
    tab: gr.Tab
    details_accordion: gr.Accordion
    # Hidden textbox the example questions are written into, and the examples widget.
    examples_hidden: gr.Textbox
    examples: gr.Examples
    # Row showing example visualizations.
    image_examples: gr.Row
    # Free-text question input.
    ipcc_direct_question: gr.Textbox
    # Non-interactive textbox for the result/error message.
    result_text: gr.Textbox
    # Radio listing the figures created for the current question.
    table_names_display: gr.Radio
    # SQL query section.
    query_accordion: gr.Accordion
    ipcc_sql_query: gr.Textbox
    # Chart section with its scenario dropdown.
    chart_accordion: gr.Accordion
    scenario_selection: gr.Dropdown
    ipcc_display: gr.Plot
    # Data table section.
    table_accordion: gr.Accordion
    ipcc_table: gr.DataFrame
|
26 |
+
|
27 |
+
|
28 |
+
async def ask_ipcc_query(query: str, index_state: int, user_id: str):
    """Forward a question to ``ask_ipcc`` and return its result unchanged."""
    return await ask_ipcc(query, index_state, user_id)
|
31 |
+
|
32 |
+
|
33 |
+
def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
    """Toggle the result widgets depending on whether the query produced anything.

    Returns five updates, in order: result text, query accordion, table
    accordion, chart accordion, table-name radio.
    """
    has_results = bool(sql_queries_state and dataframes_state and plots_state)

    if has_results:
        # Hide the message, reveal the three sections and fill the radio.
        sections = [gr.update(visible=True) for _ in range(3)]
        radio = gr.update(choices=table_names, value=table_names[0], visible=True)
        return (gr.update(visible=False), *sections, radio)

    # Nothing to display: keep only the result text visible.
    hidden = [gr.update(visible=False) for _ in range(4)]
    return (gr.update(visible=True), *hidden)
|
52 |
+
|
53 |
+
|
54 |
+
def show_filter_by_scenario(table_names, index_state, dataframes):
    """Show the scenario dropdown only for choropleth-map outputs.

    Args:
        table_names: Titles of the outputs created for the current question.
        index_state: Index of the currently selected output.
        dataframes: Dataframes of the outputs, aligned with ``table_names``.

    Returns:
        A ``gr.update`` that reveals the dropdown, filled with the scenarios
        found in the selected dataframe, or one that hides it.
    """
    hidden = gr.update(visible=False)

    # This handler also runs after a failed query, where every list is empty;
    # indexing would then raise instead of simply hiding the dropdown.
    if not table_names or index_state >= len(table_names) or index_state >= len(dataframes):
        return hidden

    if not table_names[index_state].startswith("Choropleth Map"):
        return hidden

    df = dataframes[index_state]
    # Same guard as filter_by_scenario: not every dataframe has a scenario column.
    if "scenario" not in df.columns:
        return hidden

    scenarios = df["scenario"].unique()
    if len(scenarios) == 0:
        return hidden

    # Choices are sorted for display; the default stays the first scenario
    # encountered in the data.
    return gr.update(visible=True, choices=sorted(scenarios), value=scenarios[0])
|
60 |
+
|
61 |
+
def filter_by_scenario(dataframes, figures, index_state, scenario):
    """Restrict the selected dataframe to one scenario and redraw its figure.

    Returns ``(dataframe, figure)``. The figure is ``None`` when there is no
    row to plot. A dataframe without a ``scenario`` column is plotted as is.
    """
    selected = dataframes[index_state]

    # Only filter when there is something to filter on.
    if not selected.empty and "scenario" in selected.columns:
        selected = selected[selected["scenario"] == scenario]

    if selected.empty:
        return selected, None

    return selected, figures[index_state](selected)
|
73 |
+
|
74 |
+
|
75 |
+
def display_table_names(table_names, index_state):
    """Wrap each table name in its own single-item list.

    ``index_state`` is accepted but not used.
    """
    rows = []
    for name in table_names:
        rows.append([name])
    return rows
|
80 |
+
|
81 |
+
def on_table_click(selected_label, table_names, sql_queries, dataframes, plots):
    """Return the query, dataframe, figure and index of the selected output.

    The output is located by the position of ``selected_label`` in
    ``table_names``; the figure is rebuilt from that output's dataframe.
    """
    position = table_names.index(selected_label)
    frame = dataframes[position]
    figure = plots[position](frame)
    return sql_queries[position], frame, figure, position
|
90 |
+
|
91 |
+
|
92 |
+
def create_ipcc_ui() -> ipccUIElements:

    """Create and return all UI elements for the ipcc tab.

    Builds the components inside a ``gr.Tab`` and returns them as an
    ``ipccUIElements`` mapping. The result sections (radio, query, chart and
    data accordions) are created hidden.
    """
    # NOTE(review): the nesting of the `with` blocks below was reconstructed from
    # a view without indentation — confirm it against the intended layout.
    with gr.Tab("(Beta) Talk to IPCC", elem_id="tab-vanna", id=7) as tab:
        with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
            gr.Markdown(IPCC_UI_TEXT)

        # Add examples for common questions
        # The examples write into a hidden textbox rather than the visible one.
        examples_hidden = gr.Textbox(visible=False, elem_id="ipcc-examples-hidden")
        examples = gr.Examples(
            examples=[
                ["What will the temperature be like in Paris?"],
                ["What will be the total rainfall in the USA in 2030?"],
                ["How will the average temperature evolve in China?"],
                ["What will be the average total precipitation in London ?"]
            ],
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )

        with gr.Row():
            ipcc_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        # Static gallery of example visualizations.
        with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
            gr.Markdown("### Examples of possible visualizations")

            with gr.Row():
                gr.Image("./front/assets/talk_to_ipcc_france_example.png", label="Total Precipitation in 2030 in France", elem_classes=["example-img"])
                gr.Image("./front/assets/talk_to_ipcc_new_york_example.png", label="Yearly Evolution of Mean Temperature in New York (Historical + SSP Scenarios)", elem_classes=["example-img"])
                gr.Image("./front/assets/talk_to_ipcc_china_example.png", label="Mean Temperature in 2050 in China", elem_classes=["example-img"])

        # Message area for the result/error text.
        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )
        with gr.Row():
            # Starts empty and hidden.
            table_names_display = gr.Radio(
                choices=[],
                label="Relevant figures created",
                interactive=True,
                elem_id="table-names",
                visible=False
            )

        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            ipcc_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            # Scenario dropdown, created hidden.
            scenario_selection = gr.Dropdown(
                label="Scenario", choices=IPCC_MODELS, value="ALL", interactive=True, visible=False
            )
            ipcc_display = gr.Plot(elem_id="vanna-plot")

        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            ipcc_table = gr.DataFrame([], elem_id="vanna-table")


    return ipccUIElements(
        tab=tab,
        details_accordion=details_accordion,
        examples_hidden=examples_hidden,
        examples=examples,
        image_examples=image_examples,
        ipcc_direct_question=ipcc_direct_question,
        result_text=result_text,
        table_names_display=table_names_display,
        query_accordion=query_accordion,
        ipcc_sql_query=ipcc_sql_query,
        chart_accordion=chart_accordion,
        scenario_selection=scenario_selection,
        ipcc_display=ipcc_display,
        table_accordion=table_accordion,
        ipcc_table=ipcc_table,
    )
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
def setup_ipcc_events(ui_elements: ipccUIElements, share_client=None, user_id=None) -> None:
    """Wire the event handlers of the ipcc tab.

    Picking an example and submitting a direct question run the same pipeline:
    hide the example images, run the query, toggle the result sections, then
    refresh the scenario dropdown. Changing the scenario or the selected
    figure refreshes the displayed table and chart.

    ``share_client`` is accepted but not used here.
    """
    # State shared by the handlers below.
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    user_id = gr.State(user_id)

    def _then_refresh_scenario_filter(step):
        # Appends the scenario-dropdown refresh to an event chain.
        return step.then(
            show_filter_by_scenario,
            inputs=[table_names_list, index_state, dataframes_state],
            outputs=[ui_elements["scenario_selection"]],
        )

    def _then_run_query(step, question_box):
        # Appends the hide-examples / query / show-results steps to an event chain.
        answered = step.then(
            lambda : gr.update(visible=False),
            inputs=None,
            outputs=ui_elements["image_examples"]
        ).then(
            ask_ipcc_query,
            inputs=[question_box, index_state, user_id],
            outputs=[
                ui_elements["ipcc_sql_query"],
                ui_elements["ipcc_table"],
                ui_elements["ipcc_display"],
                sql_queries_state,
                dataframes_state,
                plots_state,
                index_state,
                table_names_list,
                ui_elements["result_text"],
            ],
        ).then(
            show_results,
            inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
            outputs=[
                ui_elements["result_text"],
                ui_elements["query_accordion"],
                ui_elements["table_accordion"],
                ui_elements["chart_accordion"],
                ui_elements["table_names_display"],
            ],
        )
        return _then_refresh_scenario_filter(answered)

    # Example picked: close the help accordion and copy the text into the question box.
    example_picked = ui_elements["examples_hidden"].change(
        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
        inputs=[ui_elements["examples_hidden"]],
        outputs=[ui_elements["details_accordion"], ui_elements["ipcc_direct_question"]]
    )
    _then_run_query(example_picked, ui_elements["examples_hidden"])

    # Direct question submitted: close the help accordion, then run the query.
    question_submitted = ui_elements["ipcc_direct_question"].submit(
        lambda: gr.Accordion(open=False),
        inputs=None,
        outputs=[ui_elements["details_accordion"]]
    )
    _then_run_query(question_submitted, ui_elements["ipcc_direct_question"])

    # Scenario changed: re-filter the current dataframe and redraw.
    ui_elements["scenario_selection"].change(
        filter_by_scenario,
        inputs=[dataframes_state, plots_state, index_state, ui_elements["scenario_selection"]],
        outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
    )

    # Another figure selected: show its query, data and chart.
    figure_picked = ui_elements["table_names_display"].change(
        fn=on_table_click,
        inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plots_state],
        outputs=[ui_elements["ipcc_sql_query"], ui_elements["ipcc_table"], ui_elements["ipcc_display"], index_state],
    )
    _then_refresh_scenario_filter(figure_picked)
|
283 |
+
|
284 |
+
def create_ipcc_tab(share_client=None, user_id=None):
    """Create the ipcc tab with all its components and event handlers.

    Args:
        share_client: Forwarded to ``setup_ipcc_events``.
        user_id: Forwarded to ``setup_ipcc_events``.

    Returns:
        ipccUIElements: The components created for the tab. The caller binds
        the result of this function, so it must not be ``None``.
    """
    ui_elements = create_ipcc_ui()
    setup_ipcc_events(ui_elements, share_client=share_client, user_id=user_id)
    return ui_elements
|
288 |
+
|
289 |
+
|