Merged in feature/talk_to_data (pull request #19)
Browse files
- app.py +3 -2
- climateqa/engine/chains/retrieve_documents.py +6 -4
- climateqa/engine/talk_to_data/config.py +8 -96
- climateqa/engine/talk_to_data/drias/config.py +124 -0
- climateqa/engine/talk_to_data/drias/plot_informations.py +88 -0
- climateqa/engine/talk_to_data/{plot.py → drias/plots.py} +72 -56
- climateqa/engine/talk_to_data/{sql_query.py → drias/queries.py} +5 -36
- climateqa/engine/talk_to_data/{utils.py → input_processing.py} +144 -168
- climateqa/engine/talk_to_data/ipcc/config.py +98 -0
- climateqa/engine/talk_to_data/ipcc/plot_informations.py +50 -0
- climateqa/engine/talk_to_data/ipcc/plots.py +189 -0
- climateqa/engine/talk_to_data/ipcc/queries.py +143 -0
- climateqa/engine/talk_to_data/main.py +77 -71
- climateqa/engine/talk_to_data/objects/llm_outputs.py +13 -0
- climateqa/engine/talk_to_data/objects/location.py +12 -0
- climateqa/engine/talk_to_data/objects/plot.py +23 -0
- climateqa/engine/talk_to_data/objects/states.py +19 -0
- climateqa/engine/talk_to_data/prompt.py +44 -0
- climateqa/engine/talk_to_data/query.py +57 -0
- climateqa/engine/talk_to_data/talk_to_drias.py +0 -317
- climateqa/engine/talk_to_data/ui_config.py +27 -0
- climateqa/engine/talk_to_data/{myVanna.py → vanna/myVanna.py} +0 -0
- climateqa/engine/talk_to_data/{vanna_class.py → vanna/vanna_class.py} +0 -0
- climateqa/engine/talk_to_data/workflow/drias.py +163 -0
- climateqa/engine/talk_to_data/workflow/ipcc.py +161 -0
- front/tabs/tab_drias.py +60 -149
- front/tabs/tab_ipcc.py +300 -0
- requirements.txt +2 -1
- style.css +39 -7
app.py
CHANGED
@@ -16,6 +16,7 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
|
|
16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
from front.tabs.tab_drias import create_drias_tab
|
|
|
19 |
from front.utils import process_figures
|
20 |
from gradio_modal import Modal
|
21 |
|
@@ -532,8 +533,8 @@ def main_ui():
|
|
532 |
with gr.Tabs():
|
533 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
534 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
535 |
-
create_drias_tab(share_client=share_client, user_id=user_id)
|
536 |
-
|
537 |
create_about_tab()
|
538 |
|
539 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
|
|
16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
17 |
from front.tabs import MainTabPanel, ConfigPanel
|
18 |
from front.tabs.tab_drias import create_drias_tab
|
19 |
+
from front.tabs.tab_ipcc import create_ipcc_tab
|
20 |
from front.utils import process_figures
|
21 |
from gradio_modal import Modal
|
22 |
|
|
|
533 |
with gr.Tabs():
|
534 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
535 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
536 |
+
drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
|
537 |
+
ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
|
538 |
create_about_tab()
|
539 |
|
540 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
climateqa/engine/chains/retrieve_documents.py
CHANGED
@@ -21,7 +21,7 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
21 |
from langchain_core.output_parsers import StrOutputParser
|
22 |
from ..vectorstore import get_pinecone_vectorstore
|
23 |
from ..embeddings import get_embeddings_function
|
24 |
-
|
25 |
|
26 |
import asyncio
|
27 |
|
@@ -477,8 +477,10 @@ async def retrieve_documents(
|
|
477 |
docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
|
478 |
else:
|
479 |
# Add a default reranking score
|
480 |
-
for
|
481 |
-
|
|
|
|
|
482 |
|
483 |
# Keep the right number of documents
|
484 |
docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
|
@@ -580,7 +582,7 @@ async def get_relevant_toc_level_for_query(
|
|
580 |
response = chain.invoke({"query": query, "doc_list": doc_list})
|
581 |
|
582 |
try:
|
583 |
-
relevant_tocs =
|
584 |
except Exception as e:
|
585 |
print(f" Failed to parse the result because of : {e}")
|
586 |
|
|
|
21 |
from langchain_core.output_parsers import StrOutputParser
|
22 |
from ..vectorstore import get_pinecone_vectorstore
|
23 |
from ..embeddings import get_embeddings_function
|
24 |
+
import ast
|
25 |
|
26 |
import asyncio
|
27 |
|
|
|
477 |
docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
|
478 |
else:
|
479 |
# Add a default reranking score
|
480 |
+
for key in docs_question_dict.keys():
|
481 |
+
if isinstance(docs_question_dict[key], list) and len(docs_question_dict[key]) > 0:
|
482 |
+
for doc in docs_question_dict[key]:
|
483 |
+
doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
|
484 |
|
485 |
# Keep the right number of documents
|
486 |
docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
|
|
|
582 |
response = chain.invoke({"query": query, "doc_list": doc_list})
|
583 |
|
584 |
try:
|
585 |
+
relevant_tocs = ast.literal_eval(response)
|
586 |
except Exception as e:
|
587 |
print(f" Failed to parse the result because of : {e}")
|
588 |
|
climateqa/engine/talk_to_data/config.py
CHANGED
@@ -1,99 +1,11 @@
|
|
1 |
-
|
2 |
-
"total_winter_precipitation",
|
3 |
-
"total_summer_precipiation",
|
4 |
-
"total_annual_precipitation",
|
5 |
-
"total_remarkable_daily_precipitation",
|
6 |
-
"frequency_of_remarkable_daily_precipitation",
|
7 |
-
"extreme_precipitation_intensity",
|
8 |
-
"mean_winter_temperature",
|
9 |
-
"mean_summer_temperature",
|
10 |
-
"mean_annual_temperature",
|
11 |
-
"number_of_tropical_nights",
|
12 |
-
"maximum_summer_temperature",
|
13 |
-
"number_of_days_with_tx_above_30",
|
14 |
-
"number_of_days_with_tx_above_35",
|
15 |
-
"number_of_days_with_a_dry_ground",
|
16 |
-
]
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
"total_summer_precipiation": "total_summer_precipitation",
|
21 |
-
"total_annual_precipitation": "total_annual_precipitation",
|
22 |
-
"total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
|
23 |
-
"frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
|
24 |
-
"extreme_precipitation_intensity": "extreme_precipitation_intensity",
|
25 |
-
"mean_winter_temperature": "mean_winter_temperature",
|
26 |
-
"mean_summer_temperature": "mean_summer_temperature",
|
27 |
-
"mean_annual_temperature": "mean_annual_temperature",
|
28 |
-
"number_of_tropical_nights": "number_tropical_nights",
|
29 |
-
"maximum_summer_temperature": "maximum_summer_temperature",
|
30 |
-
"number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
|
31 |
-
"number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
|
32 |
-
"number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
|
33 |
-
}
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
'RegCM4-6_MPI-ESM-LR',
|
38 |
-
'RACMO22E_EC-EARTH',
|
39 |
-
'RegCM4-6_HadGEM2-ES',
|
40 |
-
'HadREM3-GA7_EC-EARTH',
|
41 |
-
'HadREM3-GA7_CNRM-CM5',
|
42 |
-
'REMO2015_NorESM1-M',
|
43 |
-
'SMHI-RCA4_EC-EARTH',
|
44 |
-
'WRF381P_NorESM1-M',
|
45 |
-
'ALADIN63_CNRM-CM5',
|
46 |
-
'CCLM4-8-17_MPI-ESM-LR',
|
47 |
-
'HIRHAM5_IPSL-CM5A-MR',
|
48 |
-
'HadREM3-GA7_HadGEM2-ES',
|
49 |
-
'SMHI-RCA4_IPSL-CM5A-MR',
|
50 |
-
'HIRHAM5_NorESM1-M',
|
51 |
-
'REMO2009_MPI-ESM-LR',
|
52 |
-
'CCLM4-8-17_HadGEM2-ES'
|
53 |
-
]
|
54 |
-
# Mapping between indicator columns and their units
|
55 |
-
INDICATOR_TO_UNIT = {
|
56 |
-
"total_winter_precipitation": "mm",
|
57 |
-
"total_summer_precipitation": "mm",
|
58 |
-
"total_annual_precipitation": "mm",
|
59 |
-
"total_remarkable_daily_precipitation": "mm",
|
60 |
-
"frequency_of_remarkable_daily_precipitation": "days",
|
61 |
-
"extreme_precipitation_intensity": "mm",
|
62 |
-
"mean_winter_temperature": "°C",
|
63 |
-
"mean_summer_temperature": "°C",
|
64 |
-
"mean_annual_temperature": "°C",
|
65 |
-
"number_tropical_nights": "days",
|
66 |
-
"maximum_summer_temperature": "°C",
|
67 |
-
"number_of_days_with_tx_above_30": "days",
|
68 |
-
"number_of_days_with_tx_above_35": "days",
|
69 |
-
"number_of_days_with_dry_ground": "days"
|
70 |
-
}
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
❓ **How to use?**
|
77 |
-
You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
|
78 |
-
You can specify **location** and/or **year**.
|
79 |
-
You can choose from a list of climate models. By default, we take the **average of each model**.
|
80 |
-
|
81 |
-
For example, you can ask:
|
82 |
-
- What will the temperature be like in Paris?
|
83 |
-
- What will be the total rainfall in France in 2030?
|
84 |
-
- How frequent will extreme events be in Lyon?
|
85 |
-
|
86 |
-
**Example of indicators in the data**:
|
87 |
-
- Mean temperature (annual, winter, summer)
|
88 |
-
- Total precipitation (annual, winter, summer)
|
89 |
-
- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
|
90 |
-
|
91 |
-
⚠️ **Limitations**:
|
92 |
-
- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
|
93 |
-
- You can only ask about **locations in France**.
|
94 |
-
- If you specify a year, there may be **no data for that year for some models**.
|
95 |
-
- You **cannot compare two models**.
|
96 |
-
|
97 |
-
🛈 **Information**
|
98 |
-
Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
|
99 |
-
"""
|
|
|
1 |
+
# Path configuration for the climateqa project.

# Root location of the IPCC dataset.
IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"

# Root location of the DRIAS dataset.
DRIAS_DATASET_URL = "hf://datasets/timeki/drias_db"

# Individual parquet tables, addressed relative to their dataset root.
DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH = DRIAS_DATASET_URL + "/mean_annual_temperature.parquet"
IPCC_COORDINATES_PATH = IPCC_DATASET_URL + "/coordinates.parquet"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/drias/config.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
|
3 |
+
|
4 |
+
|
5 |
+
# Names of the DRIAS tables that can be queried.
DRIAS_TABLES = [
    "total_winter_precipitation",
    "total_summer_precipitation",
    "total_annual_precipitation",
    "total_remarkable_daily_precipitation",
    "frequency_of_remarkable_daily_precipitation",
    "extreme_precipitation_intensity",
    "mean_winter_temperature",
    "mean_summer_temperature",
    "mean_annual_temperature",
    "number_of_tropical_nights",
    "maximum_summer_temperature",
    "number_of_days_with_tx_above_30",
    "number_of_days_with_tx_above_35",
    "number_of_days_with_a_dry_ground",
]

# Tables whose indicator column is not named exactly like the table.
_DRIAS_COLUMN_NAME_OVERRIDES = {
    "number_of_tropical_nights": "number_tropical_nights",
    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground",
}

# Table name -> indicator column name, in the same order as DRIAS_TABLES.
DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
    table: _DRIAS_COLUMN_NAME_OVERRIDES.get(table, table)
    for table in DRIAS_TABLES
}

# Selectable climate models; "ALL" is listed first.
DRIAS_MODELS = [
    "ALL",
    "RegCM4-6_MPI-ESM-LR",
    "RACMO22E_EC-EARTH",
    "RegCM4-6_HadGEM2-ES",
    "HadREM3-GA7_EC-EARTH",
    "HadREM3-GA7_CNRM-CM5",
    "REMO2015_NorESM1-M",
    "SMHI-RCA4_EC-EARTH",
    "WRF381P_NorESM1-M",
    "ALADIN63_CNRM-CM5",
    "CCLM4-8-17_MPI-ESM-LR",
    "HIRHAM5_IPSL-CM5A-MR",
    "HadREM3-GA7_HadGEM2-ES",
    "SMHI-RCA4_IPSL-CM5A-MR",
    "HIRHAM5_NorESM1-M",
    "REMO2009_MPI-ESM-LR",
    "CCLM4-8-17_HadGEM2-ES",
]

# Mapping between indicator columns and their units
DRIAS_INDICATOR_TO_UNIT = dict(
    total_winter_precipitation="mm",
    total_summer_precipitation="mm",
    total_annual_precipitation="mm",
    total_remarkable_daily_precipitation="mm",
    frequency_of_remarkable_daily_precipitation="days",
    extreme_precipitation_intensity="mm",
    mean_winter_temperature="°C",
    mean_summer_temperature="°C",
    mean_annual_temperature="°C",
    number_tropical_nights="days",
    maximum_summer_temperature="°C",
    number_of_days_with_tx_above_30="days",
    number_of_days_with_tx_above_35="days",
    number_of_days_with_dry_ground="days",
)

DRIAS_PLOT_PARAMETERS = ["year", "location"]

# Indicator columns drawn with the precipitation colorscale.
_PRECIPITATION_INDICATORS = (
    "total_winter_precipitation",
    "total_summer_precipitation",
    "total_annual_precipitation",
    "total_remarkable_daily_precipitation",
    "frequency_of_remarkable_daily_precipitation",
    "extreme_precipitation_intensity",
)

# Indicator columns drawn with the temperature colorscale.
_TEMPERATURE_INDICATORS = (
    "mean_winter_temperature",
    "mean_summer_temperature",
    "mean_annual_temperature",
    "number_tropical_nights",
    "maximum_summer_temperature",
    "number_of_days_with_tx_above_30",
    "number_of_days_with_tx_above_35",
    "number_of_days_with_dry_ground",
)

# Indicator column -> colorscale; precipitation entries come first, then
# temperature entries, which keeps the key order of the explicit mapping.
DRIAS_INDICATOR_TO_COLORSCALE = {
    **dict.fromkeys(_PRECIPITATION_INDICATORS, PRECIPITATION_COLORSCALE),
    **dict.fromkeys(_TEMPERATURE_INDICATORS, TEMPERATURE_COLORSCALE),
}

# Markdown introduction text for the Talk to Drias interface.
DRIAS_UI_TEXT = """
Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.

You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
You can specify **location** and/or **year**.
You can choose from a list of climate models. By default, we take the **average of each model**.

For example, you can ask:
- What will the temperature be like in Paris?
- What will be the total rainfall in France in 2030?
- How frequent will extreme events be in Lyon?

**Example of indicators in the data**:
- Mean temperature (annual, winter, summer)
- Total precipitation (annual, winter, summer)
- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C

⚠️ **Limitations**:
- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
- You can only ask about **locations in France**.
- If you specify a year, there may be **no data for that year for some models**.
- You **cannot compare two models**.

🛈 **Information**
Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
"""
|
climateqa/engine/talk_to_data/drias/plot_informations.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
|
2 |
+
|
3 |
+
def indicator_evolution_informations(
    indicator: str,
    params: dict[str, str]
) -> str:
    """Build the Markdown note that accompanies the indicator-evolution plot.

    Args:
        indicator: Key looked up in ``DRIAS_INDICATOR_TO_UNIT``; it is also
            interpolated verbatim into the returned text.
        params: Plot parameters. The ``"location"`` entry is required.

    Returns:
        A Markdown string describing the plot and where its data come from.

    Raises:
        KeyError: If ``indicator`` has no entry in ``DRIAS_INDICATOR_TO_UNIT``.
        ValueError: If ``params`` has no ``"location"`` key.
    """
    # The unit lookup runs first, so an unknown indicator raises KeyError
    # before the location check is reached.
    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
    if "location" in params:
        location = params["location"]
    else:
        raise ValueError('"location" must be provided in params')
    description = f"""
This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.

It combines both historical observations and future projections according to the climate scenario RCP8.5.

The x-axis represents the years, and the y-axis shows the value of the indicator ({unit}).

A 10-year rolling average curve is displayed to give a better idea of the overall trend.

**Data source:**
- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
- For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
- The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
- The indicator values shown are those for the selected climate model.
- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
"""
    return description
|
27 |
+
|
28 |
+
def indicator_number_of_days_per_year_informations(
    indicator: str,
    params: dict[str, str]
) -> str:
    """Build the Markdown note that accompanies the yearly-frequency bar chart.

    Args:
        indicator: Key looked up in ``DRIAS_INDICATOR_TO_UNIT``; it is also
            interpolated verbatim into the returned text.
        params: Plot parameters. The ``"location"`` entry is required.

    Returns:
        A Markdown string describing the plot and where its data come from.

    Raises:
        KeyError: If ``indicator`` has no entry in ``DRIAS_INDICATOR_TO_UNIT``.
        ValueError: If ``params`` has no ``"location"`` key.
    """
    # The unit lookup runs first, so an unknown indicator raises KeyError
    # before the location check is reached.
    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
    if "location" in params:
        location = params["location"]
    else:
        raise ValueError('"location" must be provided in params')
    description = f"""
This plot displays a bar chart showing the yearly frequency of the climate indicator **{indicator}** in **{location}**.

The x-axis represents the years, and the y-axis shows the frequency of {indicator} ({unit}) per year.

**Data source:**
- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
- For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
- The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
- The indicator values shown are those for the selected climate model.
- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
"""
    return description
|
48 |
+
|
49 |
+
def distribution_of_indicator_for_given_year_informations(
    indicator: str,
    params: dict[str, str]
) -> str:
    """Build the Markdown note that accompanies the distribution histogram.

    Args:
        indicator: Key looked up in ``DRIAS_INDICATOR_TO_UNIT``; it is also
            interpolated verbatim into the returned text.
        params: Plot parameters. ``"year"`` is optional; when it is absent
            or ``None`` the text refers to the year 2030.

    Returns:
        A Markdown string describing the plot and where its data come from.

    Raises:
        KeyError: If ``indicator`` has no entry in ``DRIAS_INDICATOR_TO_UNIT``.
    """
    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
    # A missing "year" key is handled like an explicit None, so both cases
    # fall back to the same default instead of raising KeyError.
    year = params.get("year")
    if year is None:
        year = 2030
    return f"""
This plot shows a histogram of the distribution of the climate indicator **{indicator}** across all locations for the year **{year}**.

It allows you to visualize how the values of {indicator} ({unit}) are spread for a given year.

**Data source:**
- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
- For each grid point in the dataset and climate model, the value of {indicator} for the year {year} is extracted.
- The indicator values shown are those for the selected climate model.
- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
"""
|
68 |
+
|
69 |
+
def map_of_france_of_indicator_for_given_year_informations(
    indicator: str,
    params: dict[str, str]
) -> str:
    """Build the Markdown note that accompanies the map-of-France plot.

    Args:
        indicator: Key looked up in ``DRIAS_INDICATOR_TO_UNIT``; it is also
            interpolated verbatim into the returned text.
        params: Plot parameters. ``"year"`` is optional; when it is absent
            or ``None`` the text refers to the year 2030.

    Returns:
        A Markdown string describing the plot and where its data come from.

    Raises:
        KeyError: If ``indicator`` has no entry in ``DRIAS_INDICATOR_TO_UNIT``.
    """
    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
    # A missing "year" key is handled like an explicit None, so both cases
    # fall back to the same default instead of raising KeyError.
    year = params.get("year")
    if year is None:
        year = 2030
    return f"""
This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of France for the year **{year}**.

Each region is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within France for the selected year and climate model.

**Data source:**
- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
- For each region of France, the value of {indicator} in {year} and for the selected climate model is extracted and mapped to its geographic coordinates.
- The regions correspond to 8 km squares centered on the grid points of the DRIAS dataset.
- The indicator values shown are those for the selected climate model.
- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
"""
|
climateqa/engine/talk_to_data/{plot.py → drias/plots.py}
RENAMED
@@ -1,38 +1,39 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
3 |
import pandas as pd
|
4 |
from plotly.graph_objects import Figure
|
5 |
import plotly.graph_objects as go
|
6 |
-
|
7 |
-
|
8 |
-
from climateqa.engine.talk_to_data.
|
9 |
indicator_for_given_year_query,
|
10 |
indicator_per_year_at_location_query,
|
11 |
)
|
12 |
-
from climateqa.engine.talk_to_data.config import
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
|
37 |
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
38 |
"""Generates a function to plot indicator evolution over time at a location.
|
@@ -61,7 +62,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
|
61 |
indicator = params["indicator_column"]
|
62 |
location = params["location"]
|
63 |
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
64 |
-
unit =
|
65 |
|
66 |
def plot_data(df: pd.DataFrame) -> Figure:
|
67 |
"""Generates the actual plot from the data.
|
@@ -145,10 +146,11 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
|
145 |
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
146 |
)
|
147 |
fig.update_layout(
|
148 |
-
title=f"
|
149 |
xaxis_title="Year",
|
150 |
yaxis_title=f"{indicator_label} ({unit})",
|
151 |
template="plotly_white",
|
|
|
152 |
)
|
153 |
return fig
|
154 |
|
@@ -161,6 +163,8 @@ indicator_evolution_at_location: Plot = {
|
|
161 |
"params": ["indicator_column", "location", "model"],
|
162 |
"plot_function": plot_indicator_evolution_at_location,
|
163 |
"sql_query": indicator_per_year_at_location_query,
|
|
|
|
|
164 |
}
|
165 |
|
166 |
|
@@ -184,7 +188,7 @@ def plot_indicator_number_of_days_per_year_at_location(
|
|
184 |
indicator = params["indicator_column"]
|
185 |
location = params["location"]
|
186 |
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
187 |
-
unit =
|
188 |
|
189 |
def plot_data(df: pd.DataFrame) -> Figure:
|
190 |
"""Generate the figure thanks to the dataframe
|
@@ -229,6 +233,7 @@ def plot_indicator_number_of_days_per_year_at_location(
|
|
229 |
yaxis_title=f"{indicator_label} ({unit})",
|
230 |
yaxis=dict(range=[0, max(indicators)]),
|
231 |
bargap=0.5,
|
|
|
232 |
template="plotly_white",
|
233 |
)
|
234 |
|
@@ -243,6 +248,8 @@ indicator_number_of_days_per_year_at_location: Plot = {
|
|
243 |
"params": ["indicator_column", "location", "model"],
|
244 |
"plot_function": plot_indicator_number_of_days_per_year_at_location,
|
245 |
"sql_query": indicator_per_year_at_location_query,
|
|
|
|
|
246 |
}
|
247 |
|
248 |
|
@@ -265,8 +272,10 @@ def plot_distribution_of_indicator_for_given_year(
|
|
265 |
"""
|
266 |
indicator = params["indicator_column"]
|
267 |
year = params["year"]
|
|
|
|
|
268 |
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
269 |
-
unit =
|
270 |
|
271 |
def plot_data(df: pd.DataFrame) -> Figure:
|
272 |
"""Generate the figure thanks to the dataframe
|
@@ -311,6 +320,7 @@ def plot_distribution_of_indicator_for_given_year(
|
|
311 |
yaxis_title="Frequency (%)",
|
312 |
plot_bgcolor="rgba(0, 0, 0, 0)",
|
313 |
showlegend=False,
|
|
|
314 |
)
|
315 |
|
316 |
return fig
|
@@ -324,6 +334,8 @@ distribution_of_indicator_for_given_year: Plot = {
|
|
324 |
"params": ["indicator_column", "model", "year"],
|
325 |
"plot_function": plot_distribution_of_indicator_for_given_year,
|
326 |
"sql_query": indicator_for_given_year_query,
|
|
|
|
|
327 |
}
|
328 |
|
329 |
|
@@ -346,8 +358,10 @@ def plot_map_of_france_of_indicator_for_given_year(
|
|
346 |
"""
|
347 |
indicator = params["indicator_column"]
|
348 |
year = params["year"]
|
|
|
|
|
349 |
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
350 |
-
unit =
|
351 |
|
352 |
def plot_data(df: pd.DataFrame) -> Figure:
|
353 |
fig = go.Figure()
|
@@ -371,27 +385,28 @@ def plot_map_of_france_of_indicator_for_given_year(
|
|
371 |
model_label = f"Model : {df['model'].unique()[0]}"
|
372 |
|
373 |
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
)
|
391 |
|
392 |
fig.update_layout(
|
393 |
mapbox_style="open-street-map", # Use OpenStreetMap
|
394 |
-
mapbox_zoom=
|
|
|
395 |
mapbox_center={"lat": 46.6, "lon": 2.0},
|
396 |
coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
|
397 |
title=f"{indicator_label} in {year} in France ({model_label}) " # Title
|
@@ -403,16 +418,17 @@ def plot_map_of_france_of_indicator_for_given_year(
|
|
403 |
|
404 |
map_of_france_of_indicator_for_given_year: Plot = {
|
405 |
"name": "Map of France of an indicator for a given year",
|
406 |
-
"description": "Heatmap on the map of France of the values of an
|
407 |
"params": ["indicator_column", "year", "model"],
|
408 |
"plot_function": plot_map_of_france_of_indicator_for_given_year,
|
409 |
"sql_query": indicator_for_given_year_query,
|
|
|
|
|
410 |
}
|
411 |
|
412 |
-
|
413 |
-
PLOTS = [
|
414 |
indicator_evolution_at_location,
|
415 |
indicator_number_of_days_per_year_at_location,
|
416 |
distribution_of_indicator_for_given_year,
|
417 |
map_of_france_of_indicator_for_given_year,
|
418 |
-
]
|
|
|
1 |
+
import os
|
2 |
+
import geojson
|
3 |
+
from math import cos, radians
|
4 |
+
from typing import Callable
|
5 |
import pandas as pd
|
6 |
from plotly.graph_objects import Figure
|
7 |
import plotly.graph_objects as go
|
8 |
+
from climateqa.engine.talk_to_data.drias.plot_informations import distribution_of_indicator_for_given_year_informations, indicator_evolution_informations, indicator_number_of_days_per_year_informations, map_of_france_of_indicator_for_given_year_informations
|
9 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
10 |
+
from climateqa.engine.talk_to_data.drias.queries import (
|
11 |
indicator_for_given_year_query,
|
12 |
indicator_per_year_at_location_query,
|
13 |
)
|
14 |
+
from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_COLORSCALE, DRIAS_INDICATOR_TO_UNIT
|
15 |
+
|
16 |
+
def generate_geojson_polygons(
    latitudes: list[float],
    longitudes: list[float],
    indicators: list[float],
    side_km: float = 8,
) -> geojson.FeatureCollection:
    """Build one square GeoJSON polygon per (latitude, longitude, value) triple.

    Each square is centered on its point and is ``side_km`` kilometres wide.
    Kilometres are converted to degrees with the approximation used
    throughout this function: 111 km per degree of latitude, and
    ``111 * cos(latitude)`` km per degree of longitude.

    Args:
        latitudes: Latitude of each cell center, in degrees.
        longitudes: Longitude of each cell center, in degrees.
        indicators: Value attached to each cell; stored under the
            ``"value"`` property of the feature.
        side_km: Side length of each square in kilometres. Defaults to 8.

    Returns:
        A ``geojson.FeatureCollection`` whose features have the string ids
        ``"0"``, ``"1"``, ... in input order. ``zip`` stops at the shortest
        of the three input sequences.
    """
    km_per_degree = 111
    # The latitude extent does not depend on the point, so compute it once.
    delta_lat = side_km / km_per_degree
    half_lat = delta_lat / 2

    features = []
    for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
        # A degree of longitude shrinks with cos(latitude), so the width in
        # degrees has to be recomputed for every point.
        delta_lon = side_km / (km_per_degree * cos(radians(lat)))
        half_lon = delta_lon / 2

        west, east = lon - half_lon, lon + half_lon
        south, north = lat - half_lat, lat + half_lat
        # Positions are [lon, lat]; the ring ends on its starting corner.
        ring = [
            [west, south],
            [east, south],
            [east, north],
            [west, north],
            [west, south],
        ]
        features.append(
            geojson.Feature(
                geometry=geojson.Polygon([ring]),
                properties={"value": val},
                id=str(idx),
            )
        )

    return geojson.FeatureCollection(features)
|
|
|
37 |
|
38 |
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
39 |
"""Generates a function to plot indicator evolution over time at a location.
|
|
|
62 |
indicator = params["indicator_column"]
|
63 |
location = params["location"]
|
64 |
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
65 |
+
unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
|
66 |
|
67 |
def plot_data(df: pd.DataFrame) -> Figure:
|
68 |
"""Generates the actual plot from the data.
|
|
|
146 |
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
147 |
)
|
148 |
fig.update_layout(
|
149 |
+
title=f"Evolution of {indicator_label} in {location} ({model_label})",
|
150 |
xaxis_title="Year",
|
151 |
yaxis_title=f"{indicator_label} ({unit})",
|
152 |
template="plotly_white",
|
153 |
+
height=900,
|
154 |
)
|
155 |
return fig
|
156 |
|
|
|
163 |
"params": ["indicator_column", "location", "model"],
|
164 |
"plot_function": plot_indicator_evolution_at_location,
|
165 |
"sql_query": indicator_per_year_at_location_query,
|
166 |
+
"plot_information": indicator_evolution_informations,
|
167 |
+
'short_name': 'Evolution'
|
168 |
}
|
169 |
|
170 |
|
|
|
188 |
indicator = params["indicator_column"]
|
189 |
location = params["location"]
|
190 |
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
191 |
+
unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
|
192 |
|
193 |
def plot_data(df: pd.DataFrame) -> Figure:
|
194 |
"""Generate the figure thanks to the dataframe
|
|
|
233 |
yaxis_title=f"{indicator_label} ({unit})",
|
234 |
yaxis=dict(range=[0, max(indicators)]),
|
235 |
bargap=0.5,
|
236 |
+
height=900,
|
237 |
template="plotly_white",
|
238 |
)
|
239 |
|
|
|
248 |
"params": ["indicator_column", "location", "model"],
|
249 |
"plot_function": plot_indicator_number_of_days_per_year_at_location,
|
250 |
"sql_query": indicator_per_year_at_location_query,
|
251 |
+
"plot_information": indicator_number_of_days_per_year_informations,
|
252 |
+
"short_name": "Yearly Frequency",
|
253 |
}
|
254 |
|
255 |
|
|
|
272 |
"""
|
273 |
indicator = params["indicator_column"]
|
274 |
year = params["year"]
|
275 |
+
if year is None:
|
276 |
+
year = 2030
|
277 |
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
278 |
+
unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
|
279 |
|
280 |
def plot_data(df: pd.DataFrame) -> Figure:
|
281 |
"""Generate the figure thanks to the dataframe
|
|
|
320 |
yaxis_title="Frequency (%)",
|
321 |
plot_bgcolor="rgba(0, 0, 0, 0)",
|
322 |
showlegend=False,
|
323 |
+
height=900,
|
324 |
)
|
325 |
|
326 |
return fig
|
|
|
334 |
"params": ["indicator_column", "model", "year"],
|
335 |
"plot_function": plot_distribution_of_indicator_for_given_year,
|
336 |
"sql_query": indicator_for_given_year_query,
|
337 |
+
"plot_information": distribution_of_indicator_for_given_year_informations,
|
338 |
+
'short_name': 'Distribution'
|
339 |
}
|
340 |
|
341 |
|
|
|
358 |
"""
|
359 |
indicator = params["indicator_column"]
|
360 |
year = params["year"]
|
361 |
+
if year is None:
|
362 |
+
year = 2030
|
363 |
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
364 |
+
unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
|
365 |
|
366 |
def plot_data(df: pd.DataFrame) -> Figure:
|
367 |
fig = go.Figure()
|
|
|
385 |
model_label = f"Model : {df['model'].unique()[0]}"
|
386 |
|
387 |
|
388 |
+
|
389 |
+
geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
|
390 |
+
|
391 |
+
fig = go.Figure(go.Choroplethmapbox(
|
392 |
+
geojson=geojson_data,
|
393 |
+
locations=[str(i) for i in range(len(indicators))],
|
394 |
+
featureidkey="id",
|
395 |
+
z=indicators,
|
396 |
+
colorscale=DRIAS_INDICATOR_TO_COLORSCALE[indicator],
|
397 |
+
zmin=min(indicators),
|
398 |
+
zmax=max(indicators),
|
399 |
+
marker_opacity=0.7,
|
400 |
+
marker_line_width=0,
|
401 |
+
colorbar_title=f"{indicator_label} ({unit})",
|
402 |
+
text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
|
403 |
+
hoverinfo="text"
|
404 |
+
))
|
405 |
|
406 |
fig.update_layout(
|
407 |
mapbox_style="open-street-map", # Use OpenStreetMap
|
408 |
+
mapbox_zoom=5,
|
409 |
+
height=900,
|
410 |
mapbox_center={"lat": 46.6, "lon": 2.0},
|
411 |
coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
|
412 |
title=f"{indicator_label} in {year} in France ({model_label}) " # Title
|
|
|
418 |
|
419 |
# Descriptor of the "map of France" plot: display metadata, the parameters the
# plot expects, and the callables used to build its SQL query, its figure and
# its explanatory text.
map_of_france_of_indicator_for_given_year: Plot = {
    "name": "Map of France of an indicator for a given year",
    "description": "Heatmap on the map of France of the values of an indicator for a given year",
    "params": ["indicator_column", "year", "model"],
    "plot_function": plot_map_of_france_of_indicator_for_given_year,
    "sql_query": indicator_for_given_year_query,
    "plot_information": map_of_france_of_indicator_for_given_year_informations,
    "short_name": "Map of France",
}

# DRIAS plot descriptors exposed by this module, in declaration order.
DRIAS_PLOTS = [
    indicator_evolution_at_location,
    indicator_number_of_days_per_year_at_location,
    distribution_of_indicator_for_given_year,
    map_of_france_of_indicator_for_given_year,
]
|
climateqa/engine/talk_to_data/{sql_query.py → drias/queries.py}
RENAMED
@@ -1,37 +1,5 @@
|
|
1 |
-
import asyncio
|
2 |
-
from concurrent.futures import ThreadPoolExecutor
|
3 |
from typing import TypedDict
|
4 |
-
import
|
5 |
-
import pandas as pd
|
6 |
-
|
7 |
-
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
|
8 |
-
"""Executes a SQL query on the DRIAS database and returns the results.
|
9 |
-
|
10 |
-
This function connects to the DuckDB database containing DRIAS climate data
|
11 |
-
and executes the provided SQL query. It handles the database connection and
|
12 |
-
returns the results as a pandas DataFrame.
|
13 |
-
|
14 |
-
Args:
|
15 |
-
sql_query (str): The SQL query to execute
|
16 |
-
|
17 |
-
Returns:
|
18 |
-
pd.DataFrame: A DataFrame containing the query results
|
19 |
-
|
20 |
-
Raises:
|
21 |
-
duckdb.Error: If there is an error executing the SQL query
|
22 |
-
"""
|
23 |
-
def _execute_query():
|
24 |
-
# Execute the query
|
25 |
-
con = duckdb.connect()
|
26 |
-
results = con.sql(sql_query).fetchdf()
|
27 |
-
# return fetched data
|
28 |
-
return results
|
29 |
-
|
30 |
-
# Run the query in a thread pool to avoid blocking
|
31 |
-
loop = asyncio.get_event_loop()
|
32 |
-
with ThreadPoolExecutor() as executor:
|
33 |
-
return await loop.run_in_executor(executor, _execute_query)
|
34 |
-
|
35 |
|
36 |
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
|
37 |
"""Parameters for querying an indicator's values over time at a location.
|
@@ -50,7 +18,6 @@ class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
|
|
50 |
longitude: str
|
51 |
model: str
|
52 |
|
53 |
-
|
54 |
def indicator_per_year_at_location_query(
|
55 |
table: str, params: IndicatorPerYearAtLocationQueryParams
|
56 |
) -> str:
|
@@ -70,7 +37,7 @@ def indicator_per_year_at_location_query(
|
|
70 |
if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
|
71 |
return ""
|
72 |
|
73 |
-
table = f"'
|
74 |
|
75 |
sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
|
76 |
|
@@ -105,10 +72,12 @@ def indicator_for_given_year_query(
|
|
105 |
"""
|
106 |
indicator_column = params.get("indicator_column")
|
107 |
year = params.get('year')
|
|
|
|
|
108 |
if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
|
109 |
return ""
|
110 |
|
111 |
-
table = f"'
|
112 |
|
113 |
sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
|
114 |
return sql_query
|
|
|
|
|
|
|
1 |
from typing import TypedDict
|
2 |
+
from climateqa.engine.talk_to_data.config import DRIAS_DATASET_URL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
|
5 |
"""Parameters for querying an indicator's values over time at a location.
|
|
|
18 |
longitude: str
|
19 |
model: str
|
20 |
|
|
|
21 |
def indicator_per_year_at_location_query(
|
22 |
table: str, params: IndicatorPerYearAtLocationQueryParams
|
23 |
) -> str:
|
|
|
37 |
if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
|
38 |
return ""
|
39 |
|
40 |
+
table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
|
41 |
|
42 |
sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
|
43 |
|
|
|
72 |
"""
|
73 |
indicator_column = params.get("indicator_column")
|
74 |
year = params.get('year')
|
75 |
+
if year is None:
|
76 |
+
year = 2050
|
77 |
if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
|
78 |
return ""
|
79 |
|
80 |
+
table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
|
81 |
|
82 |
sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
|
83 |
return sql_query
|
climateqa/engine/talk_to_data/{utils.py → input_processing.py}
RENAMED
@@ -1,15 +1,17 @@
|
|
1 |
-
import
|
2 |
-
from typing import Annotated, TypedDict
|
3 |
-
import duckdb
|
4 |
-
from geopy.geocoders import Nominatim
|
5 |
import ast
|
6 |
-
from climateqa.engine.llm import get_llm
|
7 |
-
from climateqa.engine.talk_to_data.config import DRIAS_TABLES
|
8 |
-
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
9 |
from langchain_core.prompts import ChatPromptTemplate
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
"""
|
14 |
Detects locations in a sentence using OpenAI's API via LangChain.
|
15 |
"""
|
@@ -29,63 +31,7 @@ async def detect_location_with_openai(sentence):
|
|
29 |
else:
|
30 |
return ""
|
31 |
|
32 |
-
|
33 |
-
"""Represents the output of a function that returns an array.
|
34 |
-
|
35 |
-
This class is used to type-hint functions that return arrays,
|
36 |
-
ensuring consistent return types across the codebase.
|
37 |
-
|
38 |
-
Attributes:
|
39 |
-
array (str): A syntactically valid Python array string
|
40 |
-
"""
|
41 |
-
array: Annotated[str, "Syntactically valid python array."]
|
42 |
-
|
43 |
-
async def detect_year_with_openai(sentence: str) -> str:
|
44 |
-
"""
|
45 |
-
Detects years in a sentence using OpenAI's API via LangChain.
|
46 |
-
"""
|
47 |
-
llm = get_llm()
|
48 |
-
|
49 |
-
prompt = """
|
50 |
-
Extract all years mentioned in the following sentence.
|
51 |
-
Return the result as a Python list. If no year are mentioned, return an empty list.
|
52 |
-
|
53 |
-
Sentence: "{sentence}"
|
54 |
-
"""
|
55 |
-
|
56 |
-
prompt = ChatPromptTemplate.from_template(prompt)
|
57 |
-
structured_llm = llm.with_structured_output(ArrayOutput)
|
58 |
-
chain = prompt | structured_llm
|
59 |
-
response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
|
60 |
-
years_list = eval(response['array'])
|
61 |
-
if len(years_list) > 0:
|
62 |
-
return years_list[0]
|
63 |
-
else:
|
64 |
-
return ""
|
65 |
-
|
66 |
-
|
67 |
-
def detectTable(sql_query: str) -> list[str]:
|
68 |
-
"""Extracts table names from a SQL query.
|
69 |
-
|
70 |
-
This function uses regular expressions to find all table names
|
71 |
-
referenced in a SQL query's FROM clause.
|
72 |
-
|
73 |
-
Args:
|
74 |
-
sql_query (str): The SQL query to analyze
|
75 |
-
|
76 |
-
Returns:
|
77 |
-
list[str]: A list of table names found in the query
|
78 |
-
|
79 |
-
Example:
|
80 |
-
>>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
|
81 |
-
['temperature_data']
|
82 |
-
"""
|
83 |
-
pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
|
84 |
-
matches = re.findall(pattern, sql_query)
|
85 |
-
return matches
|
86 |
-
|
87 |
-
|
88 |
-
def loc2coords(location: str) -> tuple[float, float]:
|
89 |
"""Converts a location name to geographic coordinates.
|
90 |
|
91 |
This function uses the Nominatim geocoding service to convert
|
@@ -104,49 +50,77 @@ def loc2coords(location: str) -> tuple[float, float]:
|
|
104 |
coords = geolocator.geocode(location)
|
105 |
return (coords.latitude, coords.longitude)
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
"""Converts geographic coordinates to a location name.
|
110 |
|
111 |
This function uses the Nominatim reverse geocoding service to convert
|
112 |
-
latitude and longitude coordinates to a
|
113 |
|
114 |
Args:
|
115 |
coords (tuple[float, float]): A tuple containing (latitude, longitude)
|
116 |
|
117 |
Returns:
|
118 |
-
str:
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
'Paris, France'
|
123 |
"""
|
124 |
-
geolocator = Nominatim(user_agent="
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
except Exception as e:
|
129 |
-
print(f"Error: {e}")
|
130 |
-
return "Unknown Location"
|
131 |
-
|
132 |
|
133 |
-
def
|
134 |
long = round(location[1], 3)
|
135 |
lat = round(location[0], 3)
|
|
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if len(results) == 0:
|
144 |
-
return "", ""
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
-
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
|
150 |
"""Identifies relevant tables for a plot based on user input.
|
151 |
|
152 |
This function uses an LLM to analyze the user's question and the plot
|
@@ -170,7 +144,6 @@ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[st
|
|
170 |
['mean_annual_temperature', 'mean_summer_temperature']
|
171 |
"""
|
172 |
# Get all table names
|
173 |
-
table_names_list = DRIAS_TABLES
|
174 |
|
175 |
prompt = (
|
176 |
f"You are helping to build a plot following this description : {plot['description']}."
|
@@ -187,95 +160,98 @@ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[st
|
|
187 |
)
|
188 |
return table_names
|
189 |
|
190 |
-
|
191 |
-
def replace_coordonates(coords, query, coords_tables):
|
192 |
-
n = query.count(str(coords[0]))
|
193 |
-
|
194 |
-
for i in range(n):
|
195 |
-
query = query.replace(str(coords[0]), str(coords_tables[i][0]), 1)
|
196 |
-
query = query.replace(str(coords[1]), str(coords_tables[i][1]), 1)
|
197 |
-
return query
|
198 |
-
|
199 |
-
|
200 |
-
async def detect_relevant_plots(user_question: str, llm):
|
201 |
plots_description = ""
|
202 |
-
for plot in
|
203 |
plots_description += "Name: " + plot["name"]
|
204 |
plots_description += " - Description: " + plot["description"] + "\n"
|
205 |
|
206 |
prompt = (
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
|
|
211 |
f"### Descriptions of the plots : {plots_description}"
|
212 |
-
f"### User question : {user_question}"
|
213 |
-
f"###
|
214 |
)
|
215 |
-
# prompt = (
|
216 |
-
# f"You are helping to answer a question with insightful visualizations. "
|
217 |
-
# f"Given a list of plots with their name and description: "
|
218 |
-
# f"{plots_description} "
|
219 |
-
# f"The user question is: {user_question}. "
|
220 |
-
# f"Choose the most relevant plots to answer the question. "
|
221 |
-
# f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
|
222 |
-
# f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
|
223 |
-
# )
|
224 |
|
225 |
plot_names = ast.literal_eval(
|
226 |
(await llm.ainvoke(prompt)).content.strip("```python\n").strip()
|
227 |
)
|
228 |
return plot_names
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
# query: Annotated[str, ..., "Syntactically valid SQL query."]
|
236 |
-
|
237 |
-
|
238 |
-
# class PlotlyCodeOutput(TypedDict):
|
239 |
-
# """Generated Plotly code"""
|
240 |
-
|
241 |
-
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
242 |
-
# def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
|
243 |
-
# """Generate SQL query to fetch information."""
|
244 |
-
# prompt_params = {
|
245 |
-
# "dialect": db.dialect,
|
246 |
-
# "table_info": db.get_table_info(),
|
247 |
-
# "input": user_input,
|
248 |
-
# "relevant_tables": relevant_tables,
|
249 |
-
# "model": "ALADIN63_CNRM-CM5",
|
250 |
-
# }
|
251 |
-
|
252 |
-
# prompt = ChatPromptTemplate.from_template(query_prompt_template)
|
253 |
-
# structured_llm = llm.with_structured_output(QueryOutput)
|
254 |
-
# chain = prompt | structured_llm
|
255 |
-
# result = chain.invoke(prompt_params)
|
256 |
-
|
257 |
-
# return result["query"]
|
258 |
-
|
259 |
-
|
260 |
-
# def fetch_data_from_sql_query(db: str, sql_query: str):
|
261 |
-
# conn = sqlite3.connect(db)
|
262 |
-
# cursor = conn.cursor()
|
263 |
-
# cursor.execute(sql_query)
|
264 |
-
# column_names = [desc[0] for desc in cursor.description]
|
265 |
-
# values = cursor.fetchall()
|
266 |
-
# return {"column_names": column_names, "data": values}
|
267 |
-
|
268 |
|
269 |
-
|
270 |
-
|
|
|
|
|
271 |
|
272 |
-
|
273 |
-
|
274 |
|
275 |
-
|
|
|
|
|
|
|
276 |
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Literal, Optional, cast
|
|
|
|
|
|
|
2 |
import ast
|
|
|
|
|
|
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
+
from geopy.geocoders import Nominatim
|
5 |
+
from climateqa.engine.llm import get_llm
|
6 |
+
import duckdb
|
7 |
+
import os
|
8 |
+
from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
|
9 |
+
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
|
10 |
+
from climateqa.engine.talk_to_data.objects.location import Location
|
11 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
12 |
+
from climateqa.engine.talk_to_data.objects.states import State
|
13 |
+
|
14 |
+
async def detect_location_with_openai(sentence: str) -> str:
|
15 |
"""
|
16 |
Detects locations in a sentence using OpenAI's API via LangChain.
|
17 |
"""
|
|
|
31 |
else:
|
32 |
return ""
|
33 |
|
34 |
+
def loc_to_coords(location: str) -> tuple[float, float]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
"""Converts a location name to geographic coordinates.
|
36 |
|
37 |
This function uses the Nominatim geocoding service to convert
|
|
|
50 |
coords = geolocator.geocode(location)
|
51 |
return (coords.latitude, coords.longitude)
|
52 |
|
53 |
+
def coords_to_country(coords: tuple[float, float]) -> tuple[str,str]:
    """Reverse-geocode a coordinate pair into the country it belongs to.

    The lookup is delegated to the Nominatim reverse geocoding service, and
    the country fields are read from the ``address`` entry of its raw answer.

    Args:
        coords (tuple[float, float]): The point to look up, as (latitude, longitude)

    Returns:
        tuple[str,str]: A tuple (country_code, country_name); the country code
            is upper-cased.

    Raises:
        AttributeError: If the coordinates cannot be found (there is then no
            geocoding result to read the address from)
    """
    reverse_geocoder = Nominatim(user_agent="latlong_to_country")
    address = reverse_geocoder.reverse(coords).raw['address']
    country_code = address['country_code'].upper()
    country_name = address['country']
    return country_code, country_name
|
|
|
|
|
|
|
|
|
72 |
|
73 |
+
def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Find the dataset grid point closest to a geographic location.

    The coordinates are rounded to 3 decimals, then the coordinate table of
    the selected dataset is searched inside a square window centred on the
    location (+/- 0.3 degree for DRIAS, +/- 0.5 degree otherwise). Among the
    candidates, the point with the smallest squared distance in degrees is
    returned, so the result is the actual nearest neighbour rather than an
    arbitrary row of the window.

    Args:
        location (tuple): The point to match, as (latitude, longitude).
        mode (Literal['DRIAS', 'IPCC']): Dataset to search. 'DRIAS' reads the
            DRIAS mean annual temperature table; any other value reads the
            IPCC coordinates table.

    Returns:
        tuple[str, str, Optional[str]]: (latitude, longitude, admin1) of the
            matched grid point, with the values taken as-is from the query
            result. ``admin1`` is None when the table has no such column
            (DRIAS). Returns ("", "", "") when no grid point lies inside the
            search window.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)

    if mode == 'DRIAS':
        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
        columns = "latitude, longitude"
        half_width = 0.3
    else:
        table_path = f"'{IPCC_COORDINATES_PATH}'"
        columns = "latitude, longitude, admin1"
        half_width = 0.5

    # The bounding box keeps the scan cheap; ORDER BY + LIMIT 1 then picks the
    # closest candidate. Distance is measured in plain degrees, which is
    # accurate enough inside such a small window. The reference values are
    # parenthesised so a negative coordinate can never form a "--" SQL comment.
    query = (
        f"SELECT {columns} FROM {table_path} "
        f"WHERE latitude BETWEEN {lat - half_width} AND {lat + half_width} "
        f"AND longitude BETWEEN {long - half_width} AND {long + half_width} "
        f"ORDER BY POW(latitude - ({lat}), 2) + POW(longitude - ({long}), 2) "
        "LIMIT 1"
    )

    conn = duckdb.connect()
    try:
        results = conn.sql(query).fetchdf()
    finally:
        # Always release the connection, even when the query fails.
        conn.close()

    if len(results) == 0:
        return "", "", ""

    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
|
98 |
|
99 |
+
async def detect_year_with_openai(sentence: str) -> str:
    """Ask the LLM which years a sentence mentions and return the first one.

    The model is asked for a Python list through a structured output
    (``ArrayOutput``); its ``array`` field is parsed with ``ast.literal_eval``.

    Args:
        sentence (str): The text to analyse.

    Returns:
        str: The first element of the list produced by the model, or an empty
            string when that list is empty.
    """
    llm = get_llm()

    template = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    extraction_chain = ChatPromptTemplate.from_template(template) | llm.with_structured_output(ArrayOutput)
    response: ArrayOutput = await extraction_chain.ainvoke({"sentence": sentence})
    detected_years = ast.literal_eval(response['array'])
    return detected_years[0] if len(detected_years) > 0 else ""
|
121 |
+
|
122 |
|
123 |
+
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
|
124 |
"""Identifies relevant tables for a plot based on user input.
|
125 |
|
126 |
This function uses an LLM to analyze the user's question and the plot
|
|
|
144 |
['mean_annual_temperature', 'mean_summer_temperature']
|
145 |
"""
|
146 |
# Get all table names
|
|
|
147 |
|
148 |
prompt = (
|
149 |
f"You are helping to build a plot following this description : {plot['description']}."
|
|
|
160 |
)
|
161 |
return table_names
|
162 |
|
163 |
+
async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Ask the LLM which plots are relevant to answer a user question.

    The name and description of every candidate plot are listed in the prompt
    and the model is asked to answer with a Python list of plot names, sorted
    from the most to the least relevant.

    Args:
        user_question (str): The question asked by the user.
        llm: Chat model used for the selection; it must expose an async
            ``ainvoke(prompt)`` whose result has a ``content`` string.
        plot_list (list[Plot]): Candidate plots; only their ``name`` and
            ``description`` entries are read.

    Returns:
        list[str]: The plot names parsed from the model answer.

    Raises:
        ValueError | SyntaxError: If the answer cannot be parsed as a Python
            literal (raised by ``ast.literal_eval``).
    """
    plots_description = "".join(
        f"Name: {plot['name']} - Description: {plot['description']}\n"
        for plot in plot_list
    )

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        "### Names of the plots : "
    )

    answer = (await llm.ainvoke(prompt)).content

    # Keep only the outermost [...] span so that code fences or any sentence
    # the model adds around the list do not break the parsing.
    list_start, list_end = answer.find("["), answer.rfind("]")
    if list_start != -1 and list_end > list_start:
        literal = answer[list_start:list_end + 1]
    else:
        # No list delimiters found: fall back to the historical clean-up.
        # str.strip treats its argument as a set of characters, not a prefix.
        literal = answer.strip("```python\n").strip()

    plot_names = ast.literal_eval(literal)
    return plot_names
|
185 |
|
186 |
+
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Detect the location mentioned in the user input and resolve it.

    The location name is detected with the LLM. When one is found, it is
    geocoded, its country is looked up, and it is matched to a grid point of
    the dataset selected by ``mode``.

    Args:
        user_input (str): The user's query text.
        mode (Literal['DRIAS', 'IPCC']): Dataset in which the grid point is
            searched. Defaults to 'DRIAS'.

    Returns:
        Location: The detected location name together with the latitude,
            longitude and admin1 of the matched grid point and the country
            code and name. Every field except ``location`` is None when no
            location is detected.
    """
    print("---- Find location in user input ----")
    detected = await detect_location_with_openai(user_input)

    if not detected:
        # Nothing to resolve: only the (empty) detected name is reported.
        return cast(Location, {
            'location': detected,
            'longitude': None,
            'latitude': None,
            'country_code': None,
            'country_name': None,
            'admin1': None,
        })

    coords = loc_to_coords(detected)
    country_code, country_name = coords_to_country(coords)
    grid_latitude, grid_longitude, admin1 = nearest_neighbour_sql(coords, mode)
    return cast(Location, {
        'location': detected,
        'longitude': grid_longitude,
        'latitude': grid_latitude,
        'country_code': country_code,
        'country_name': country_name,
        'admin1': admin1,
    })
|
211 |
+
|
212 |
+
async def find_year(user_input: str) -> str| None:
    """Extract the year mentioned in the user input, if any.

    The detection is delegated to the LLM; its "no year" answer (an empty
    string) is converted to None so callers can test for a missing year.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The detected year, or None when no year is found
    """
    print("---- Find year ---")
    detected_year = await detect_year_with_openai(user_input)
    return None if detected_year == "" else detected_year
|
229 |
|
230 |
+
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Select the plots relevant to the user input stored in the workflow state.

    Args:
        state (State): Workflow state; only its ``user_input`` entry is read.
        llm: Chat model forwarded to the plot detection.
        plots (list[Plot]): Candidate plots.

    Returns:
        list[str]: Names of the plots selected by the LLM.
    """
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm, plots)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
+
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Select the tables relevant to build one plot for the user input.

    Args:
        state (State): Workflow state; only its ``user_input`` entry is read.
        plot (Plot): The plot the tables are selected for.
        llm: Chat model forwarded to the table detection.
        tables (list[str]): Names of the candidate tables.

    Returns:
        list[str]: Names of the tables selected by the LLM.
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm, tables)
|
239 |
|
240 |
+
async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
|
241 |
+
"""Perform the good method to retrieve the desired parameter
|
242 |
|
243 |
+
Args:
|
244 |
+
state (State): state of the workflow
|
245 |
+
param_name (str): name of the desired parameter
|
246 |
+
table (str): name of the table
|
247 |
|
248 |
+
Returns:
|
249 |
+
dict[str, Any] | None:
|
250 |
+
"""
|
251 |
+
if param_name == 'location':
|
252 |
+
location = await find_location(state['user_input'], mode)
|
253 |
+
return location
|
254 |
+
if param_name == 'year':
|
255 |
+
year = await find_year(state['user_input'])
|
256 |
+
return {'year': year}
|
257 |
+
return None
|
climateqa/engine/talk_to_data/ipcc/config.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
|
2 |
+
from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
|
3 |
+
|
4 |
+
|
5 |
+
# IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
|
6 |
+
# Tables available in the IPCC dataset.
IPCC_TABLES = ["mean_temperature", "total_precipitation"]

# Indicator column of each table; every table stores it under its own name.
IPCC_INDICATOR_COLUMNS_PER_TABLE = {table: table for table in IPCC_TABLES}

# Unit displayed next to each indicator.
IPCC_INDICATOR_TO_UNIT = {
    "mean_temperature": "°C",
    "total_precipitation": "mm/day",
}

# Scenario identifiers: the historical run followed by the SSP projections.
IPCC_SCENARIO = ["historical", "ssp126", "ssp245", "ssp370", "ssp585"]

IPCC_MODELS = []

# Parameters that can be extracted from the user question for IPCC plots.
IPCC_PLOT_PARAMETERS = ["year", "location"]

# ISO country codes flagged as "macro" countries.
# NOTE(review): presumably large territories handled at a coarser resolution
# by the queries - confirm against their usage.
MACRO_COUNTRIES = [
    "JP", "IN", "MH", "PT", "ID", "SJ", "MX", "CN", "GL", "PN",
    "AR", "AQ", "PF", "BR", "SH", "GS", "ZA", "NZ", "TF",
]

# ISO country codes flagged as "huge macro" countries (see the note above).
HUGE_MACRO_COUNTRIES = ["CL", "CA", "AU", "US", "RU"]

# Colorscale used to draw each indicator.
IPCC_INDICATOR_TO_COLORSCALE = {
    "mean_temperature": TEMPERATURE_COLORSCALE,
    "total_precipitation": PRECIPITATION_COLORSCALE,
}
|
68 |
+
|
69 |
+
IPCC_UI_TEXT = """
|
70 |
+
Hi, I'm **Talk to IPCC**, designed to answer your questions using [**IPCC - ATLAS**](https://interactive-atlas.ipcc.ch/regional-information#eyJ0eXBlIjoiQVRMQVMiLCJjb21tb25zIjp7ImxhdCI6OTc3MiwibG5nIjo0MDA2OTIsInpvb20iOjQsInByb2oiOiJFUFNHOjU0MDMwIiwibW9kZSI6ImNvbXBsZXRlX2F0bGFzIn0sInByaW1hcnkiOnsic2NlbmFyaW8iOiJzc3A1ODUiLCJwZXJpb2QiOiIyIiwic2Vhc29uIjoieWVhciIsImRhdGFzZXQiOiJDTUlQNiIsInZhcmlhYmxlIjoidGFzIiwidmFsdWVUeXBlIjoiQU5PTUFMWSIsImhhdGNoaW5nIjoiU0lNUExFIiwicmVnaW9uU2V0IjoiYXI2IiwiYmFzZWxpbmUiOiJwcmVJbmR1c3RyaWFsIiwicmVnaW9uc1NlbGVjdGVkIjpbXX0sInBsb3QiOnsiYWN0aXZlVGFiIjoicGx1bWUiLCJtYXNrIjoibm9uZSIsInNjYXR0ZXJZTWFnIjpudWxsLCJzY2F0dGVyWVZhciI6bnVsbCwic2hvd2luZyI6ZmFsc2V9fQ==) data.
|
71 |
+
I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
|
72 |
+
|
73 |
+
You can ask me anything about these climate indicators: **temperature** or **precipitation**.
|
74 |
+
You can specify **location** and/or **year**.
|
75 |
+
By default, we take the **median of each climate model**.
|
76 |
+
|
77 |
+
Current available charts :
|
78 |
+
- Yearly evolution of an indicator at a specific location (historical + SSP Projections)
|
79 |
+
- Yearly spatial distribution of an indicator in a specific country
|
80 |
+
|
81 |
+
Current available indicators :
|
82 |
+
- Mean temperature
|
83 |
+
- Total precipitation
|
84 |
+
|
85 |
+
For example, you can ask:
|
86 |
+
- What will the temperature be like in Paris?
|
87 |
+
- What will be the total rainfall in the USA in 2030?
|
88 |
+
- How will the average temperature evolve in China ?
|
89 |
+
|
90 |
+
⚠️ **Limitations**:
|
91 |
+
- You can't ask anything that isn't related to **IPCC - ATLAS** data.
|
92 |
+
- You can not ask about **several locations at the same time**.
|
93 |
+
- If you specify a year **before 1850 or over 2100**, there will be **no data**.
|
94 |
+
- You **cannot compare two models**.
|
95 |
+
|
96 |
+
🛈 **Information**
|
97 |
+
Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
|
98 |
+
"""
|
climateqa/engine/talk_to_data/ipcc/plot_informations.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_UNIT
|
2 |
+
|
3 |
+
def indicator_evolution_informations(indicator: str, params: dict[str,str]) -> str:
    """Build the markdown text explaining the indicator evolution plot.

    Args:
        indicator (str): Name of the plotted indicator; it must be a key of
            ``IPCC_INDICATOR_TO_UNIT``.
        params (dict[str,str]): Plot parameters; ``location`` is required.

    Returns:
        str: Markdown description of the plot and of its data source.

    Raises:
        ValueError: If ``params`` has no ``location`` entry.
        KeyError: If ``indicator`` has no known unit.
    """
    if "location" not in params:
        raise ValueError('"location" must be provided in params')

    location, unit = params["location"], IPCC_INDICATOR_TO_UNIT[indicator]
    return f"""
This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.

It combines both historical (from 1950 to 2015) observations and future (from 2016 to 2100) projections for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).

The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of the {indicator} ({unit}).

Each line corresponds to a different scenario, allowing you to compare how {indicator} might change under various future conditions.

**Data source:**
- The data comes from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
- The underlying data is retrieved by aggregating yearly values of {indicator} for the selected location, across all available scenarios. This means the system collects, for each year, the value of {indicator} in {location}, both for the historical period and for each scenario, to build the time series.
- The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
"""
|
26 |
+
|
27 |
+
def choropleth_map_informations(
    indicator: str,
    params: dict[str, str],
) -> str:
    """Build the markdown description shown next to the choropleth map.

    Args:
        indicator: Key of the climate indicator; it is also used to look up
            the display unit in ``IPCC_INDICATOR_TO_UNIT``.
        params: Plot parameters. ``"location"`` and ``"country_name"`` are
            required; ``"year"`` is optional and defaults to 2050 when it is
            absent or ``None``.

    Returns:
        A markdown string describing the map and its data source.

    Raises:
        ValueError: If ``"location"`` is missing from ``params``.
        KeyError: If ``"country_name"`` is missing from ``params`` or
            ``indicator`` has no entry in ``IPCC_INDICATOR_TO_UNIT``.
    """
    # Validate the params first, like indicator_evolution_informations does,
    # so a missing location is reported before any unit lookup.
    if "location" not in params:
        raise ValueError('"location" must be provided in params')
    location = params["location"]
    country_name = params["country_name"]

    # `year` is optional: use .get() so that an absent key falls back to the
    # same 2050 default as an explicit None (the plot function does the same).
    year = params.get("year")
    if year is None:
        year = 2050

    unit = IPCC_INDICATOR_TO_UNIT[indicator]

    return f"""
This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of **{location}** country ({country_name}) for the year **{year}** and the chosen scenario.

Each grid point is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within the country for the selected year and scenario.

**Data source:**
- The data come from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
- For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
- The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
- The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
"""
|
climateqa/engine/talk_to_data/ipcc/plots.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
from plotly.graph_objects import Figure
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
import pandas as pd
|
5 |
+
import geojson
|
6 |
+
|
7 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
|
8 |
+
from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
|
9 |
+
from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
|
10 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
11 |
+
|
12 |
+
def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
    """Turn grid points into a GeoJSON collection of 1-degree square cells.

    Each (latitude, longitude) pair becomes a closed square polygon that
    extends 0.5 degree on every side of the point. The matching indicator
    value is stored in the feature's ``value`` property, and the feature id
    is the point's position in the input lists, as a string.

    Args:
        latitudes: Latitude of each grid point.
        longitudes: Longitude of each grid point.
        indicators: Indicator value of each grid point.

    Returns:
        A ``geojson.FeatureCollection`` with one feature per zipped triple.
    """
    cells = []
    for cell_id, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
        west, east = lon - 0.5, lon + 0.5
        south, north = lat - 0.5, lat + 0.5
        # The ring is closed by repeating its first corner.
        ring = [
            [west, south],
            [east, south],
            [east, north],
            [west, north],
            [west, south],
        ]
        cells.append(
            geojson.Feature(
                geometry=geojson.Polygon([ring]),
                properties={"value": val},
                id=str(cell_id),
            )
        )

    return geojson.FeatureCollection(cells)
|
30 |
+
|
31 |
+
def plot_indicator_evolution_at_location_historical_and_projections(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """Create the plotting function for the "indicator evolution" line chart.

    The returned function draws one line per entry of ``IPCC_SCENARIO``.
    Every non-historical line is prefixed with the last historical point so
    that projections visually continue the historical curve.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the indicator column to plot.
            - location (str): Location name, used in the figure title.

    Returns:
        Callable[[pd.DataFrame], Figure]: Function taking a DataFrame with the
        ``year``, ``scenario`` and indicator columns and returning the figure.
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join(map(str.capitalize, indicator.split("_")))
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        ordered = df.sort_values(by='year')
        rows = list(zip(
            ordered['year'].astype(int).tolist(),
            ordered[indicator].astype(float).tolist(),
            ordered['scenario'].astype(str).tolist(),
        ))

        # Last historical point (rows are sorted by year), used as the
        # junction between the historical line and each projection.
        historical_points = [(yr, val) for yr, val, scen in rows if scen == 'historical']
        if historical_points:
            anchor_year, anchor_value = historical_points[-1]
        else:
            anchor_year, anchor_value = None, None

        fig = go.Figure()
        for scenario in IPCC_SCENARIO:
            xs = [yr for yr, _, scen in rows if scen == scenario]
            ys = [val for _, val, scen in rows if scen == scenario]
            if scenario != 'historical' and anchor_value is not None:
                xs.insert(0, anchor_year)
                ys.insert(0, anchor_value)
            trace = go.Scatter(x=xs, y=ys, mode='lines', name=scenario)
            fig.add_trace(trace)

        fig.update_layout(
            title=f'Yearly Evolution of {indicator_label} in {location} (Historical + SSP Scenarios)',
            xaxis_title='Year',
            yaxis_title=f'{indicator_label} ({unit})',
            legend_title='Scenario',
            height=800,
        )
        return fig

    return plot_data
|
88 |
+
|
89 |
+
# Plot configuration for the yearly-evolution line chart. It bundles the
# human-readable metadata (name, description, short_name), the list of
# parameter names, and the three callables that build the SQL query, the
# figure and the explanatory text. Listed in IPCC_PLOTS below.
indicator_evolution_at_location_historical_and_projections: Plot = {
    "name": "Indicator Evolution at Location (Historical + Projections)",
    "description": (
        "Shows how a climate indicator (e.g., rainfall, temperature) changes over time at a specific location, "
        "including historical data and future projections. "
        "Useful for questions about the value or trend of an indicator at a location for any year, "
        "such as 'What will be the total rainfall in China in 2050?' or 'How does rainfall evolve in China over time?'. "
        "Parameters: indicator_column (the climate variable), location (e.g., country, city)."
    ),
    "params": ["indicator_column", "location"],
    "plot_function": plot_indicator_evolution_at_location_historical_and_projections,
    "sql_query": indicator_per_year_at_location_query,
    "plot_information": indicator_evolution_informations,
    "short_name": "Evolution"
}
|
104 |
+
|
105 |
+
def plot_choropleth_map_of_country_indicator_for_specific_year(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """Create the plotting function for the country-wide choropleth map.

    The returned function draws every row of the DataFrame as a 1-degree
    square cell (see ``generate_geojson_polygons``) colored by the indicator
    value, on top of an OpenStreetMap background.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the climate indicator column to plot.
            - year (str or int, optional): Year shown in the title (default: 2050).
            - country_name (str): Name of the country, shown in the title.
            - location (str): Location name, shown in the title.

    Returns:
        Callable[[pd.DataFrame], Figure]: Function taking a DataFrame with the
        ``latitude``, ``longitude`` and indicator columns and returning the
        Plotly figure. An empty DataFrame yields an empty map instead of
        raising.

    Raises:
        KeyError: If ``indicator_column``, ``country_name`` or ``location`` is
            missing from ``params``.
    """
    indicator = params["indicator_column"]
    year = params.get('year')
    if year is None:
        year = 2050
    country_name = params['country_name']
    location = params['location']
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        indicators = df[indicator].astype(float).tolist()
        latitudes = df["latitude"].astype(float).tolist()
        longitudes = df["longitude"].astype(float).tolist()

        geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)

        # min()/max() raise ValueError on an empty sequence; with no data we
        # leave the color range unset (None lets Plotly pick it).
        zmin = min(indicators) if indicators else None
        zmax = max(indicators) if indicators else None

        fig = go.Figure(go.Choroplethmapbox(
            geojson=geojson_data,
            # Feature ids are the row positions, see generate_geojson_polygons.
            locations=[str(i) for i in range(len(indicators))],
            featureidkey="id",
            z=indicators,
            colorscale=IPCC_INDICATOR_TO_COLORSCALE[indicator],
            zmin=zmin,
            zmax=zmax,
            marker_opacity=0.7,
            marker_line_width=0,
            colorbar_title=f"{indicator_label} ({unit})",
            text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Hover text showing the indicator value
            hoverinfo="text"
        ))

        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox_zoom=2,
            height=800,
            # Center the view on the middle row of the data (0, 0 when empty).
            mapbox_center={
                "lat": latitudes[len(latitudes)//2] if latitudes else 0,
                "lon": longitudes[len(longitudes)//2] if longitudes else 0
            },
            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
            title=f"{indicator_label} in {year} in {location} ({country_name})"
        )
        return fig

    return plot_data
|
169 |
+
|
170 |
+
# Plot configuration for the single-year choropleth map. It bundles the
# human-readable metadata (name, description, short_name), the list of
# parameter names, and the three callables that build the SQL query, the
# figure and the explanatory text. Listed in IPCC_PLOTS below.
# NOTE(review): the plot function and plot_information also read
# params["country_name"], which is not listed in "params" — presumably it is
# derived from the location upstream; confirm.
choropleth_map_of_country_indicator_for_specific_year: Plot = {
    "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
    "description": (
        "Displays a map showing the spatial distribution of a climate indicator (e.g., rainfall, temperature) "
        "across all regions of a country for a specific year. "
        "Can answer questions about the value of an indicator in a country or region for a given year, "
        "such as 'What will be the total rainfall in China in 2050?' or 'How is rainfall distributed across China in 2050?'. "
        "Parameters: indicator_column (the climate variable), year, location (country name)."
    ),
    "params": ["indicator_column", "year", "location"],
    "plot_function": plot_choropleth_map_of_country_indicator_for_specific_year,
    "sql_query": indicator_for_given_year_query,
    "plot_information": choropleth_map_informations,
    "short_name": "Map",
}
|
185 |
+
|
186 |
+
# All plot configurations defined in this module for the IPCC dataset.
IPCC_PLOTS = [
    indicator_evolution_at_location_historical_and_projections,
    choropleth_map_of_country_indicator_for_specific_year
]
|
climateqa/engine/talk_to_data/ipcc/queries.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypedDict, Optional
|
2 |
+
|
3 |
+
from climateqa.engine.talk_to_data.ipcc.config import HUGE_MACRO_COUNTRIES, MACRO_COUNTRIES
|
4 |
+
from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
|
5 |
+
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
    """
    Parameters for querying the evolution of an indicator per year at a specific location.

    Declared with ``total=False``, so every key is optional for the type
    checker; ``indicator_per_year_at_location_query`` returns an empty string
    when a key it needs is missing.

    Attributes:
        indicator_column (str): Name of the climate indicator column.
        latitude (str): Latitude of the location.
        longitude (str): Longitude of the location.
        country_code (str): Country code; selects which parquet file is queried.
        admin1 (str): Administrative region (optional). Not currently used to
            build the SQL text.
    """
    indicator_column: str
    latitude: str
    longitude: str
    country_code: str
    admin1: Optional[str]
|
21 |
+
|
22 |
+
def indicator_per_year_at_location_query(
    table: str, params: IndicatorPerYearAtLocationQueryParams
) -> str:
    """
    Builds an SQL query to get the evolution of an indicator per year at a specific location.

    Three query shapes exist, selected by the country code:
      - ``MACRO_COUNTRIES``: reads ``<country>_macro.parquet`` and averages the
        indicator per (scenario, year).
      - ``HUGE_MACRO_COUNTRIES``: reads ``<country>_macro.parquet`` and selects
        the rows as they are, without aggregation.
      - any other country: reads ``<country>.parquet``, takes the median per
        (scenario, year, month) and then averages the monthly medians per
        (scenario, year).

    Args:
        table (str): SQL table of the indicator (lower-cased to build the file path).
        params (IndicatorPerYearAtLocationQueryParams): Dictionary with the required params for the query.

    Returns:
        str: The SQL query string, or an empty string if required parameters are missing.
    """
    indicator_column = params.get("indicator_column")
    latitude = params.get("latitude")
    longitude = params.get("longitude")
    country_code = params.get("country_code")

    # Only None / "" mean "missing": a numeric 0 is a valid coordinate and
    # must not be rejected by a plain truthiness test.
    required = (indicator_column, latitude, longitude, country_code)
    if any(value is None or value == "" for value in required):
        return ""

    if country_code in MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
        SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
        FROM {table_path}
        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
        GROUP BY scenario, year
        ORDER BY year, scenario
        """
    elif country_code in HUGE_MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        # The original emitted "ORDER year, scenario", which is a SQL syntax
        # error; the keyword is ORDER BY.
        sql_query = f"""
        SELECT year, scenario, {indicator_column},
        FROM {table_path}
        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
        ORDER BY year, scenario
        """
    else:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
        sql_query = f"""
        WITH medians_per_month AS (
            SELECT year, scenario, month, MEDIAN({indicator_column}) AS median_value
            FROM {table_path}
            WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
            GROUP BY scenario, year, month
        )
        SELECT year, scenario, AVG(median_value) AS {indicator_column}
        FROM medians_per_month
        GROUP BY scenario, year
        ORDER BY year, scenario
        """
    return sql_query.strip()
|
76 |
+
|
77 |
+
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
    """
    Parameters for querying an indicator's values across locations for a specific year.

    Declared with ``total=False``, so every key is optional for the type
    checker; ``indicator_for_given_year_query`` falls back to 2050 when
    ``year`` is missing and returns an empty string when another key it needs
    is missing.

    Attributes:
        indicator_column (str): The column name for the climate indicator.
        year (str): The year to query.
        country_code (str): The country code; selects which parquet file is queried.
    """
    indicator_column: str
    year: str
    country_code: str
|
89 |
+
|
90 |
+
def indicator_for_given_year_query(
    table: str, params: IndicatorForGivenYearQueryParams
) -> str:
    """
    Builds the SQL query returning, for one year, the indicator value of every
    grid point (latitude, longitude) and scenario of a country.

    The query shape depends on the country code: countries in
    ``MACRO_COUNTRIES`` are averaged from their ``_macro`` parquet file,
    countries in ``HUGE_MACRO_COUNTRIES`` are read from their ``_macro`` file
    without aggregation, and all others use the average of monthly medians
    from the regular parquet file.

    Args:
        table (str): SQL table of the indicator (lower-cased to build the file path).
        params (IndicatorForGivenYearQueryParams): Query parameters; ``year``
            defaults to 2050 when missing or falsy.

    Returns:
        str: The stripped SQL query, or an empty string if required parameters are missing.
    """
    indicator_column = params.get("indicator_column")
    country_code = params.get("country_code")
    year = params.get("year") or 2050

    if not (indicator_column and year and country_code):
        return ""

    if country_code in MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        return f"""
        SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
        FROM {table_path}
        WHERE year = {year}
        GROUP BY latitude, longitude, scenario
        ORDER BY latitude, longitude, scenario
        """.strip()

    if country_code in HUGE_MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        return f"""
        SELECT latitude, longitude, scenario, {indicator_column},
        FROM {table_path}
        WHERE year = {year}
        ORDER BY latitude, longitude, scenario
        """.strip()

    table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
    return f"""
        WITH medians_per_month AS (
            SELECT latitude, longitude, scenario, month, MEDIAN({indicator_column}) AS median_value
            FROM {table_path}
            WHERE year = {year}
            GROUP BY latitude, longitude, scenario, month
        )
        SELECT latitude, longitude, scenario, AVG(median_value) AS {indicator_column}
        FROM medians_per_month
        GROUP BY latitude, longitude, scenario
        ORDER BY latitude, longitude, scenario
        """.strip()
|
climateqa/engine/talk_to_data/main.py
CHANGED
@@ -1,44 +1,70 @@
|
|
1 |
-
from climateqa.engine.talk_to_data.
|
2 |
-
from climateqa.engine.
|
3 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
4 |
-
import ast
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
|
9 |
-
"""Adds table names to the SQL query result rows using LLM.
|
10 |
|
11 |
-
This function
|
12 |
-
|
|
|
13 |
|
14 |
Args:
|
15 |
-
|
16 |
-
|
17 |
|
18 |
Returns:
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
"""
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
|
|
|
|
26 |
|
27 |
-
This function analyzes a SQL query to identify which columns are being selected
|
28 |
-
in the result set.
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
Returns:
|
35 |
-
list[str]: A list of column names being selected in the query
|
36 |
-
"""
|
37 |
-
columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
|
38 |
-
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
39 |
-
return columns_list
|
40 |
|
41 |
-
|
|
|
|
|
42 |
"""Main function to process a DRIAS query and return results.
|
43 |
|
44 |
This function orchestrates the DRIAS workflow, processing a user query to generate
|
@@ -61,58 +87,38 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
|
|
61 |
- table_list (list): List of table names used
|
62 |
- error (str): Error message if any
|
63 |
"""
|
64 |
-
final_state = await
|
65 |
sql_queries = []
|
66 |
result_dataframes = []
|
67 |
figures = []
|
68 |
-
|
|
|
69 |
|
70 |
-
for
|
71 |
-
|
72 |
-
if
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
82 |
|
83 |
if "error" in final_state and final_state["error"] != "":
|
84 |
-
|
|
|
85 |
|
86 |
sql_query = sql_queries[index_state]
|
87 |
dataframe = result_dataframes[index_state]
|
88 |
figure = figures[index_state](dataframe)
|
|
|
89 |
|
90 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
91 |
|
92 |
-
return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state,
|
93 |
-
|
94 |
-
# def ask_vanna(vn,db_vanna_path, query):
|
95 |
-
|
96 |
-
# try :
|
97 |
-
# location = detect_location_with_openai(query)
|
98 |
-
# if location:
|
99 |
-
|
100 |
-
# coords = loc2coords(location)
|
101 |
-
# user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
|
102 |
-
|
103 |
-
# relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
|
104 |
-
# coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
|
105 |
-
# user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
|
106 |
-
|
107 |
-
# sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
|
108 |
-
|
109 |
-
# return sql_query, result_dataframe, figure
|
110 |
-
# else :
|
111 |
-
# empty_df = pd.DataFrame()
|
112 |
-
# empty_fig = None
|
113 |
-
# return "", empty_df, empty_fig
|
114 |
-
# except Exception as e:
|
115 |
-
# print(f"Error: {e}")
|
116 |
-
# empty_df = pd.DataFrame()
|
117 |
-
# empty_fig = None
|
118 |
-
# return "", empty_df, empty_fig
|
|
|
1 |
+
from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
|
2 |
+
from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
|
3 |
from climateqa.logging import log_drias_interaction_to_huggingface
|
|
|
4 |
|
5 |
+
async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
    """Main function to process a DRIAS query and return results.

    This function runs the DRIAS workflow on the user query, gathers the SQL
    queries, dataframes, figure builders and plot descriptions of every output
    whose status is ``'OK'``, and returns the one selected by ``index_state``
    together with the full lists (used for pagination).

    Args:
        query (str): The user's question about climate data
        index_state (int, optional): The index of the result to return. Defaults to 0.
        user_id (str | None, optional): Identifier forwarded to the interaction logger.

    Returns:
        tuple: An 11-element tuple, with the same shape on success and on error:
            - sql_query (str | None): The SQL query used
            - dataframe (pd.DataFrame | None): The resulting data
            - figure (Figure | None): The figure built for the selected result
            - plot_information (str | None): Description of the selected plot
            - sql_queries (list): All generated SQL queries
            - result_dataframes (list): All resulting dataframes
            - figures (list): All figure generation functions
            - plot_informations (list): All plot descriptions
            - index_state (int): Current result index
            - plot_title_list (list): Titles of the outputs that have a table
            - error (str): Error message if any, "" otherwise
    """
    final_state = await drias_workflow(query)

    def _failure(message) -> tuple:
        # Same arity as the success return (11 values): no selected result,
        # four empty lists, index reset to 0, no titles, then the message.
        return None, None, None, None, [], [], [], [], 0, [], message

    # Check the workflow error first: nothing below is needed in that case.
    if "error" in final_state and final_state["error"] != "":
        return _failure(final_state["error"])

    sql_queries = []
    result_dataframes = []
    figures = []
    plot_title_list = []
    plot_informations = []

    for output_title, output in final_state['outputs'].items():
        if output['status'] != 'OK':
            continue

        if output['table'] is not None:
            plot_title_list.append(output_title)

        if output['plot_information'] is not None:
            plot_informations.append(output['plot_information'])

        if output['sql_query'] is not None:
            sql_queries.append(output['sql_query'])

        if output['dataframe'] is not None:
            result_dataframes.append(output['dataframe'])

        if output['figure'] is not None:
            figures.append(output['figure'])

    try:
        sql_query = sql_queries[index_state]
        dataframe = result_dataframes[index_state]
        figure_builder = figures[index_state]
        plot_information = plot_informations[index_state]
    except IndexError:
        # No usable output (or index past the end): report it through the
        # error slot instead of crashing the caller.
        return _failure(f"No result available at index {index_state}.")

    figure = figure_builder(dataframe)

    log_drias_interaction_to_huggingface(query, sql_query, user_id)

    return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
|
66 |
+
|
67 |
+
async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
    """Main function to process an IPCC query and return results.

    This function runs the IPCC workflow on the user query, gathers the SQL
    queries, dataframes, figure builders and plot descriptions of every output
    whose status is ``'OK'``, and returns the one selected by ``index_state``
    together with the full lists (used for pagination).

    Args:
        query (str): The user's question about climate data
        index_state (int, optional): The index of the result to return. Defaults to 0.
        user_id (str | None, optional): Identifier forwarded to the interaction logger.

    Returns:
        tuple: An 11-element tuple, with the same shape on success and on error:
            - sql_query (str | None): The SQL query used
            - dataframe (pd.DataFrame | None): The resulting data
            - figure (Figure | None): The figure built for the selected result
            - plot_information (str | None): Description of the selected plot
            - sql_queries (list): All generated SQL queries
            - result_dataframes (list): All resulting dataframes
            - figures (list): All figure generation functions
            - plot_informations (list): All plot descriptions
            - index_state (int): Current result index
            - plot_title_list (list): Titles of the outputs that have a table
            - error (str): Error message if any, "" otherwise
    """
    final_state = await ipcc_workflow(query)

    def _failure(message) -> tuple:
        # Same arity as the success return (11 values): no selected result,
        # four empty lists, index reset to 0, no titles, then the message.
        return None, None, None, None, [], [], [], [], 0, [], message

    # Check the workflow error first: nothing below is needed in that case.
    if "error" in final_state and final_state["error"] != "":
        return _failure(final_state["error"])

    sql_queries = []
    result_dataframes = []
    figures = []
    plot_title_list = []
    plot_informations = []

    for output_title, output in final_state['outputs'].items():
        if output['status'] != 'OK':
            continue

        if output['table'] is not None:
            plot_title_list.append(output_title)

        if output['plot_information'] is not None:
            plot_informations.append(output['plot_information'])

        if output['sql_query'] is not None:
            sql_queries.append(output['sql_query'])

        if output['dataframe'] is not None:
            result_dataframes.append(output['dataframe'])

        if output['figure'] is not None:
            figures.append(output['figure'])

    try:
        sql_query = sql_queries[index_state]
        dataframe = result_dataframes[index_state]
        figure_builder = figures[index_state]
        plot_information = plot_informations[index_state]
    except IndexError:
        # No usable output (or index past the end): report it through the
        # error slot instead of crashing the caller.
        return _failure(f"No result available at index {index_state}.")

    figure = figure_builder(dataframe)

    # NOTE(review): IPCC interactions go through the DRIAS-named logger, as in
    # the original code — confirm this is intended.
    log_drias_interaction_to_huggingface(query, sql_query, user_id)

    return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/objects/llm_outputs.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, TypedDict
|
2 |
+
|
3 |
+
|
4 |
+
class ArrayOutput(TypedDict):
    """Represents the output of a function that returns an array.

    This class is used to type-hint functions that return arrays,
    ensuring consistent return types across the codebase.
    NOTE(review): given the module name (llm_outputs), this is presumably
    used as a structured-output schema for LLM calls — confirm.

    Attributes:
        array (str): A syntactically valid Python array, encoded as a string
            (the value is text, not a list object).
    """
    array: Annotated[str, "Syntactically valid python array."]
|
climateqa/engine/talk_to_data/objects/location.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from token import OP
|
2 |
+
from typing import Optional, TypedDict
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
class Location(TypedDict):
    """Description of a place referenced by a user query.

    Only ``location`` is always a string; every other field may be ``None``.

    Attributes:
        location (str): Name of the location.
        latitude (Optional[str]): Latitude, stored as a string.
        longitude (Optional[str]): Longitude, stored as a string.
        country_code (Optional[str]): Code of the country.
        country_name (Optional[str]): Name of the country.
        admin1 (Optional[str]): First-level administrative region
            (presumably as returned by the geocoder — confirm).
    """
    location: str
    latitude: Optional[str]
    longitude: Optional[str]
    country_code: Optional[str]
    country_name: Optional[str]
    admin1: Optional[str]
|
climateqa/engine/talk_to_data/objects/plot.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, TypedDict, Optional
|
2 |
+
from plotly.graph_objects import Figure
|
3 |
+
|
4 |
+
class Plot(TypedDict):
    """Represents a plot configuration of the talk-to-data system.

    This class defines the structure for configuring different types of plots
    that can be generated from climate data.

    Attributes:
        name (str): The name of the plot type
        description (str): A description of what the plot shows
        params (list[str]): List of required parameters for the plot
        plot_function (Callable[..., Callable[..., Figure]]): Factory that
            returns the function building the figure
        sql_query (Callable[..., str]): Function to generate the SQL query for the plot
        plot_information (Callable[..., str]): Function returning the text
            that explains the plot
        short_name (str): Short label for the plot
    """
    name: str
    description: str
    params: list[str]
    plot_function: Callable[..., Callable[..., Figure]]
    sql_query: Callable[..., str]
    plot_information: Callable[..., str]
    short_name: str
|
climateqa/engine/talk_to_data/objects/states.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Callable, Optional, TypedDict
|
2 |
+
from plotly.graph_objects import Figure
|
3 |
+
import pandas as pd
|
4 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
5 |
+
|
6 |
+
class TTDOutput(TypedDict):
    """One result of a talk-to-data workflow (one plot on one table).

    Attributes:
        status (str): Processing status of this output.
        plot (Plot): Plot configuration used for this output.
        table (str): Name of the table the output was built from.
        sql_query (Optional[str]): SQL query used, if any.
        dataframe (Optional[pd.DataFrame]): Query result, if any.
        figure (Optional[Callable[..., Figure]]): Function building the
            figure, if any.
        plot_information (Optional[str]): Text explaining the plot, if any.
    """
    status: str
    plot: Plot
    table: str
    sql_query: Optional[str]
    dataframe: Optional[pd.DataFrame]
    figure: Optional[Callable[..., Figure]]
    plot_information: Optional[str]
|
14 |
+
class State(TypedDict):
    """State carried through a talk-to-data workflow.

    Attributes:
        user_input (str): The user's question.
        plots (list[str]): Plots selected for the question
            (presumably plot names — confirm).
        outputs (dict[str, TTDOutput]): Workflow outputs, keyed by a string
            identifying each output.
        error (Optional[str]): Error message, if any.
    """
    user_input: str
    plots: list[str]
    outputs: dict[str, TTDOutput]
    error: Optional[str]
|
19 |
+
|
climateqa/engine/talk_to_data/prompt.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Prompt used to ask an LLM for a SQL query. Placeholders to fill in:
# {relevant_tables}, {dialect}, {model}, {table_info} and {input}.
# The instruction list is numbered consecutively (the original skipped 8).
query_prompt_template = """You are an expert SQL query generator. Given an input question, database schema, SQL dialect and relevant tables to answer the question, generate an optimized and syntactically correct SQL query which can provide useful insights to the question.

### Instructions:
1. **Use only relevant tables**: The following tables are relevant to answering the question: {relevant_tables}. Do not use any other tables.
2. **Relevant columns only**: Never select `*`. Only include necessary columns based on the input question.
3. **Schema Awareness**:
    - Use only columns present in the given schema.
    - **If a column name appears in multiple tables, always use the format `table_name.column_name` to avoid ambiguity.**
    - Select only the column which are insightful for the question.
4. **Dialect Compliance**: Follow `{dialect}` syntax rules.
5. **Ordering**: Order the results by a relevant column if applicable (e.g., timestamp for recent records).
6. **Valid query**: Make sure the query is syntactically and functionally correct.
7. **Conditions** : For the common columns, the same condition should be applied to all the tables (e.g. latitude, longitude, model, year...)
8. **Join tables** : If you need to join table, you should join them with year feature.
9. **Model** : For each table, you need to add a condition on the model to be equal to {model}

### Provided Database Schema:
{table_info}

### Relevant Tables:
{relevant_tables}

**Question:** {input}

**SQL Query:**"""
|
26 |
+
|
27 |
+
plot_prompt_template = """You are a data visualization expert. Given an input question and an SQL Query, generate an insightful plot according to the question.
|
28 |
+
|
29 |
+
### Instructions
|
30 |
+
1. **Use only the column names provided**. The data will be provided as a Pandas DataFrame `df` with the columns present in the SELECT.
|
31 |
+
2. Generate the Python Plotly code to chart the results using `df` and the column names.
|
32 |
+
3. Make as complete a graph as possible to answer the question, and make it as easy to understand as possible.
|
33 |
+
4. **Response with only Python code**. Do not answer with any explanations -- just the code.
|
34 |
+
5. **Specific cases** :
|
35 |
+
- For a question about the evolution of something, it is also relevant to plot the data with also the sliding average for a period of 20 years for example.
|
36 |
+
|
37 |
+
### SQL Query:
|
38 |
+
{sql_query}
|
39 |
+
|
40 |
+
**Question:** {input}
|
41 |
+
|
42 |
+
**Python code:**
|
43 |
+
"""
|
44 |
+
|
climateqa/engine/talk_to_data/query.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from concurrent.futures import ThreadPoolExecutor
|
3 |
+
import duckdb
|
4 |
+
import pandas as pd
|
5 |
+
import os
|
6 |
+
|
7 |
+
def find_indicator_column(table: str, indicator_columns_per_table: dict[str,str]) -> str:
    """Look up the indicator column associated with a table.

    Args:
        table (str): Name of the table in the database
        indicator_columns_per_table (dict[str, str]): Mapping from table name
            to the name of that table's indicator column

    Returns:
        str: Name of the indicator column for the specified table

    Raises:
        KeyError: If the table name is not a key of the mapping
    """
    # Trace message is emitted before the lookup, so it appears even when the
    # lookup below raises.
    print(f"---- Find indicator column in table {table} ----")
    column_name = indicator_columns_per_table[table]
    return column_name
|
24 |
+
|
25 |
+
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
    """Execute a SQL query with DuckDB and return the result as a DataFrame.

    A new DuckDB connection is opened for every call (no database path is
    passed to ``duckdb.connect``) and is always closed afterwards. When the
    ``HF_TOKEN`` environment variable is set, it is registered as a DuckDB
    secret of type ``huggingface`` before the query runs. The blocking DuckDB
    work is executed in a worker thread so the event loop is not blocked.

    Args:
        sql_query (str): The SQL query to execute

    Returns:
        pd.DataFrame: A DataFrame containing the query results

    Raises:
        duckdb.Error: If there is an error executing the SQL query
    """
    def _execute_query() -> pd.DataFrame:
        con = duckdb.connect()
        try:
            hf_token = os.getenv("HF_TOKEN")
            # Only register the secret when a token is actually configured;
            # otherwise the literal string 'None' would be stored as the token.
            if hf_token:
                # Double any single quote so the token cannot break out of the
                # SQL string literal.
                escaped_token = hf_token.replace("'", "''")
                con.execute(f"""CREATE SECRET hf_token (
                    TYPE huggingface,
                    TOKEN '{escaped_token}'
                );""")
            return con.execute(sql_query).fetchdf()
        finally:
            # Release the connection even when the query fails.
            con.close()

    # Run the query in a thread pool to avoid blocking the event loop.
    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine.
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        return await loop.run_in_executor(executor, _execute_query)
|
57 |
+
|
climateqa/engine/talk_to_data/talk_to_drias.py
DELETED
@@ -1,317 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
from typing import Any, Callable, TypedDict, Optional
|
4 |
-
from numpy import sort
|
5 |
-
import pandas as pd
|
6 |
-
import asyncio
|
7 |
-
from plotly.graph_objects import Figure
|
8 |
-
from climateqa.engine.llm import get_llm
|
9 |
-
from climateqa.engine.talk_to_data import sql_query
|
10 |
-
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
|
11 |
-
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
12 |
-
from climateqa.engine.talk_to_data.sql_query import execute_sql_query
|
13 |
-
from climateqa.engine.talk_to_data.utils import (
|
14 |
-
detect_relevant_plots,
|
15 |
-
detect_year_with_openai,
|
16 |
-
loc2coords,
|
17 |
-
detect_location_with_openai,
|
18 |
-
nearestNeighbourSQL,
|
19 |
-
detect_relevant_tables,
|
20 |
-
)
|
21 |
-
|
22 |
-
|
23 |
-
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
24 |
-
|
25 |
-
class TableState(TypedDict):
|
26 |
-
"""Represents the state of a table in the DRIAS workflow.
|
27 |
-
|
28 |
-
This class defines the structure for tracking the state of a table during the
|
29 |
-
data processing workflow, including its name, parameters, SQL query, and results.
|
30 |
-
|
31 |
-
Attributes:
|
32 |
-
table_name (str): The name of the table in the database
|
33 |
-
params (dict[str, Any]): Parameters used for querying the table
|
34 |
-
sql_query (str, optional): The SQL query used to fetch data
|
35 |
-
dataframe (pd.DataFrame | None, optional): The resulting data
|
36 |
-
figure (Callable[..., Figure], optional): Function to generate visualization
|
37 |
-
status (str): The current status of the table processing ('OK' or 'ERROR')
|
38 |
-
"""
|
39 |
-
table_name: str
|
40 |
-
params: dict[str, Any]
|
41 |
-
sql_query: Optional[str]
|
42 |
-
dataframe: Optional[pd.DataFrame | None]
|
43 |
-
figure: Optional[Callable[..., Figure]]
|
44 |
-
status: str
|
45 |
-
|
46 |
-
class PlotState(TypedDict):
|
47 |
-
"""Represents the state of a plot in the DRIAS workflow.
|
48 |
-
|
49 |
-
This class defines the structure for tracking the state of a plot during the
|
50 |
-
data processing workflow, including its name and associated tables.
|
51 |
-
|
52 |
-
Attributes:
|
53 |
-
plot_name (str): The name of the plot
|
54 |
-
tables (list[str]): List of tables used in the plot
|
55 |
-
table_states (dict[str, TableState]): States of the tables used in the plot
|
56 |
-
"""
|
57 |
-
plot_name: str
|
58 |
-
tables: list[str]
|
59 |
-
table_states: dict[str, TableState]
|
60 |
-
|
61 |
-
class State(TypedDict):
|
62 |
-
user_input: str
|
63 |
-
plots: list[str]
|
64 |
-
plot_states: dict[str, PlotState]
|
65 |
-
error: Optional[str]
|
66 |
-
|
67 |
-
async def find_relevant_plots(state: State, llm) -> list[str]:
|
68 |
-
print("---- Find relevant plots ----")
|
69 |
-
relevant_plots = await detect_relevant_plots(state['user_input'], llm)
|
70 |
-
return relevant_plots
|
71 |
-
|
72 |
-
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
|
73 |
-
print(f"---- Find relevant tables for {plot['name']} ----")
|
74 |
-
relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
|
75 |
-
return relevant_tables
|
76 |
-
|
77 |
-
async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
|
78 |
-
"""Perform the good method to retrieve the desired parameter
|
79 |
-
|
80 |
-
Args:
|
81 |
-
state (State): state of the workflow
|
82 |
-
param_name (str): name of the desired parameter
|
83 |
-
table (str): name of the table
|
84 |
-
|
85 |
-
Returns:
|
86 |
-
dict[str, Any] | None:
|
87 |
-
"""
|
88 |
-
if param_name == 'location':
|
89 |
-
location = await find_location(state['user_input'], table)
|
90 |
-
return location
|
91 |
-
if param_name == 'year':
|
92 |
-
year = await find_year(state['user_input'])
|
93 |
-
return {'year': year}
|
94 |
-
return None
|
95 |
-
|
96 |
-
class Location(TypedDict):
|
97 |
-
location: str
|
98 |
-
latitude: Optional[str]
|
99 |
-
longitude: Optional[str]
|
100 |
-
|
101 |
-
async def find_location(user_input: str, table: str) -> Location:
|
102 |
-
print(f"---- Find location in table {table} ----")
|
103 |
-
location = await detect_location_with_openai(user_input)
|
104 |
-
output: Location = {'location' : location}
|
105 |
-
if location:
|
106 |
-
coords = loc2coords(location)
|
107 |
-
neighbour = nearestNeighbourSQL(coords, table)
|
108 |
-
output.update({
|
109 |
-
"latitude": neighbour[0],
|
110 |
-
"longitude": neighbour[1],
|
111 |
-
})
|
112 |
-
return output
|
113 |
-
|
114 |
-
async def find_year(user_input: str) -> str:
|
115 |
-
"""Extracts year information from user input using LLM.
|
116 |
-
|
117 |
-
This function uses an LLM to identify and extract year information from the
|
118 |
-
user's query, which is used to filter data in subsequent queries.
|
119 |
-
|
120 |
-
Args:
|
121 |
-
user_input (str): The user's query text
|
122 |
-
|
123 |
-
Returns:
|
124 |
-
str: The extracted year, or empty string if no year found
|
125 |
-
"""
|
126 |
-
print(f"---- Find year ---")
|
127 |
-
year = await detect_year_with_openai(user_input)
|
128 |
-
return year
|
129 |
-
|
130 |
-
def find_indicator_column(table: str) -> str:
|
131 |
-
"""Retrieves the name of the indicator column within a table.
|
132 |
-
|
133 |
-
This function maps table names to their corresponding indicator columns
|
134 |
-
using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.
|
135 |
-
|
136 |
-
Args:
|
137 |
-
table (str): Name of the table in the database
|
138 |
-
|
139 |
-
Returns:
|
140 |
-
str: Name of the indicator column for the specified table
|
141 |
-
|
142 |
-
Raises:
|
143 |
-
KeyError: If the table name is not found in the mapping
|
144 |
-
"""
|
145 |
-
print(f"---- Find indicator column in table {table} ----")
|
146 |
-
return INDICATOR_COLUMNS_PER_TABLE[table]
|
147 |
-
|
148 |
-
|
149 |
-
async def process_table(
|
150 |
-
table: str,
|
151 |
-
params: dict[str, Any],
|
152 |
-
plot: Plot,
|
153 |
-
) -> TableState:
|
154 |
-
"""Processes a table to extract relevant data and generate visualizations.
|
155 |
-
|
156 |
-
This function retrieves the SQL query for the specified table, executes it,
|
157 |
-
and generates a visualization based on the results.
|
158 |
-
|
159 |
-
Args:
|
160 |
-
table (str): The name of the table to process
|
161 |
-
params (dict[str, Any]): Parameters used for querying the table
|
162 |
-
plot (Plot): The plot object containing SQL query and visualization function
|
163 |
-
|
164 |
-
Returns:
|
165 |
-
TableState: The state of the processed table
|
166 |
-
"""
|
167 |
-
table_state: TableState = {
|
168 |
-
'table_name': table,
|
169 |
-
'params': params.copy(),
|
170 |
-
'status': 'OK',
|
171 |
-
'dataframe': None,
|
172 |
-
'sql_query': None,
|
173 |
-
'figure': None
|
174 |
-
}
|
175 |
-
|
176 |
-
table_state['params']['indicator_column'] = find_indicator_column(table)
|
177 |
-
sql_query = plot['sql_query'](table, table_state['params'])
|
178 |
-
|
179 |
-
if sql_query == "":
|
180 |
-
table_state['status'] = 'ERROR'
|
181 |
-
return table_state
|
182 |
-
table_state['sql_query'] = sql_query
|
183 |
-
df = await execute_sql_query(sql_query)
|
184 |
-
|
185 |
-
table_state['dataframe'] = df
|
186 |
-
table_state['figure'] = plot['plot_function'](table_state['params'])
|
187 |
-
|
188 |
-
return table_state
|
189 |
-
|
190 |
-
async def drias_workflow(user_input: str) -> State:
|
191 |
-
"""Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
|
192 |
-
|
193 |
-
Args:
|
194 |
-
user_input (str): initial user input
|
195 |
-
|
196 |
-
Returns:
|
197 |
-
State: Final state with all the results
|
198 |
-
"""
|
199 |
-
state: State = {
|
200 |
-
'user_input': user_input,
|
201 |
-
'plots': [],
|
202 |
-
'plot_states': {},
|
203 |
-
'error': ''
|
204 |
-
}
|
205 |
-
|
206 |
-
llm = get_llm(provider="openai")
|
207 |
-
|
208 |
-
plots = await find_relevant_plots(state, llm)
|
209 |
-
|
210 |
-
state['plots'] = plots
|
211 |
-
|
212 |
-
if len(state['plots']) < 1:
|
213 |
-
state['error'] = 'There is no plot to answer to the question'
|
214 |
-
return state
|
215 |
-
|
216 |
-
have_relevant_table = False
|
217 |
-
have_sql_query = False
|
218 |
-
have_dataframe = False
|
219 |
-
|
220 |
-
for plot_name in state['plots']:
|
221 |
-
|
222 |
-
plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
|
223 |
-
if plot is None:
|
224 |
-
continue
|
225 |
-
|
226 |
-
plot_state: PlotState = {
|
227 |
-
'plot_name': plot_name,
|
228 |
-
'tables': [],
|
229 |
-
'table_states': {}
|
230 |
-
}
|
231 |
-
|
232 |
-
plot_state['plot_name'] = plot_name
|
233 |
-
|
234 |
-
relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
|
235 |
-
|
236 |
-
if len(relevant_tables) > 0 :
|
237 |
-
have_relevant_table = True
|
238 |
-
|
239 |
-
plot_state['tables'] = relevant_tables
|
240 |
-
|
241 |
-
params = {}
|
242 |
-
for param_name in plot['params']:
|
243 |
-
param = await find_param(state, param_name, relevant_tables[0])
|
244 |
-
if param:
|
245 |
-
params.update(param)
|
246 |
-
|
247 |
-
tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
|
248 |
-
results = await asyncio.gather(*tasks)
|
249 |
-
|
250 |
-
# Store results back in plot_state
|
251 |
-
have_dataframe = False
|
252 |
-
have_sql_query = False
|
253 |
-
for table_state in results:
|
254 |
-
if table_state['sql_query']:
|
255 |
-
have_sql_query = True
|
256 |
-
if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
|
257 |
-
have_dataframe = True
|
258 |
-
plot_state['table_states'][table_state['table_name']] = table_state
|
259 |
-
|
260 |
-
state['plot_states'][plot_name] = plot_state
|
261 |
-
|
262 |
-
if not have_relevant_table:
|
263 |
-
state['error'] = "There is no relevant table in our database to answer your question"
|
264 |
-
elif not have_sql_query:
|
265 |
-
state['error'] = "There is no relevant sql query on our database that can help to answer your question"
|
266 |
-
elif not have_dataframe:
|
267 |
-
state['error'] = "There is no data in our table that can answer to your question"
|
268 |
-
|
269 |
-
return state
|
270 |
-
|
271 |
-
# def make_write_query_node():
|
272 |
-
|
273 |
-
# def write_query(state):
|
274 |
-
# print("---- Write query ----")
|
275 |
-
# for table in state["tables"]:
|
276 |
-
# sql_query = QUERIES[state[table]['query_type']](
|
277 |
-
# table=table,
|
278 |
-
# indicator_column=state[table]["columns"],
|
279 |
-
# longitude=state[table]["longitude"],
|
280 |
-
# latitude=state[table]["latitude"],
|
281 |
-
# )
|
282 |
-
# state[table].update({"sql_query": sql_query})
|
283 |
-
|
284 |
-
# return state
|
285 |
-
|
286 |
-
# return write_query
|
287 |
-
|
288 |
-
# def make_fetch_data_node(db_path):
|
289 |
-
|
290 |
-
# def fetch_data(state):
|
291 |
-
# print("---- Fetch data ----")
|
292 |
-
# for table in state["tables"]:
|
293 |
-
# results = execute_sql_query(db_path, state[table]['sql_query'])
|
294 |
-
# state[table].update(results)
|
295 |
-
|
296 |
-
# return state
|
297 |
-
|
298 |
-
# return fetch_data
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
## V2
|
303 |
-
|
304 |
-
|
305 |
-
# def make_fetch_data_node(db_path: str, llm):
|
306 |
-
# def fetch_data(state):
|
307 |
-
# print("---- Fetch data ----")
|
308 |
-
# db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
|
309 |
-
# output = {}
|
310 |
-
# sql_query = write_sql_query(state["query"], db, state["tables"], llm)
|
311 |
-
# # TO DO : Add query checker
|
312 |
-
# print(f"SQL query : {sql_query}")
|
313 |
-
# output["sql_query"] = sql_query
|
314 |
-
# output.update(fetch_data_from_sql_query(db_path, sql_query))
|
315 |
-
# return output
|
316 |
-
|
317 |
-
# return fetch_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/ui_config.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TEMPERATURE_COLORSCALE = [
|
2 |
+
[0.0, "rgb(5, 48, 97)"],
|
3 |
+
[0.10, "rgb(33, 102, 172)"],
|
4 |
+
[0.20, "rgb(67, 147, 195)"],
|
5 |
+
[0.30, "rgb(146, 197, 222)"],
|
6 |
+
[0.40, "rgb(209, 229, 240)"],
|
7 |
+
[0.50, "rgb(247, 247, 247)"],
|
8 |
+
[0.60, "rgb(253, 219, 199)"],
|
9 |
+
[0.75, "rgb(244, 165, 130)"],
|
10 |
+
[0.85, "rgb(214, 96, 77)"],
|
11 |
+
[0.90, "rgb(178, 24, 43)"],
|
12 |
+
[1.0, "rgb(103, 0, 31)"]
|
13 |
+
]
|
14 |
+
|
15 |
+
PRECIPITATION_COLORSCALE = [
|
16 |
+
[0.0, "rgb(84, 48, 5)"],
|
17 |
+
[0.10, "rgb(140, 81, 10)"],
|
18 |
+
[0.20, "rgb(191, 129, 45)"],
|
19 |
+
[0.30, "rgb(223, 194, 125)"],
|
20 |
+
[0.40, "rgb(246, 232, 195)"],
|
21 |
+
[0.50, "rgb(245, 245, 245)"],
|
22 |
+
[0.60, "rgb(199, 234, 229)"],
|
23 |
+
[0.75, "rgb(128, 205, 193)"],
|
24 |
+
[0.85, "rgb(53, 151, 143)"],
|
25 |
+
[0.90, "rgb(1, 102, 94)"],
|
26 |
+
[1.0, "rgb(0, 60, 48)"]
|
27 |
+
]
|
climateqa/engine/talk_to_data/{myVanna.py → vanna/myVanna.py}
RENAMED
File without changes
|
climateqa/engine/talk_to_data/{vanna_class.py → vanna/vanna_class.py}
RENAMED
File without changes
|
climateqa/engine/talk_to_data/workflow/drias.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
import asyncio
|
5 |
+
from climateqa.engine.llm import get_llm
|
6 |
+
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
|
7 |
+
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
|
8 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
9 |
+
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
|
10 |
+
from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE, DRIAS_PLOT_PARAMETERS
|
11 |
+
from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
|
12 |
+
|
13 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
14 |
+
|
15 |
+
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Build, run and render one (plot, table) output of the DRIAS workflow.

    The steps are, in order: resolve the table's indicator column, build the
    SQL query through ``plot['sql_query']``, compute the plot information,
    execute the query, and finally call ``plot['plot_function']``.

    Args:
        output_title (str): Title for the output (returned unchanged so the caller can key on it).
        table (str): The name of the table to process.
        plot (Plot): The plot object providing the 'sql_query', 'plot_information' and 'plot_function' callables.
        params (dict[str, Any]): Parameters used for querying the table. This dict is
            updated in place with 'indicator_column' when one is found.

    Returns:
        tuple: (output_title, results dict, errors dict). 'status' in the results is
        'OK', 'ERROR' (no SQL query could be built) or 'NO_DATA' (query returned no rows).
    """
    outcome: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None
    }
    flags = {
        'have_sql_query': False,
        'have_dataframe': False
    }

    # Resolve the indicator column and expose it to the query/plot builders
    column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
    if column:
        params['indicator_column'] = column

    # Without a query there is nothing to execute or plot
    query = plot['sql_query'](table, params)
    if not query:
        outcome['status'] = 'ERROR'
        return output_title, outcome, flags

    outcome['plot_information'] = plot['plot_information'](table, params)
    outcome['sql_query'] = query
    flags['have_sql_query'] = True

    # Run the query; the dataframe is only kept when it has at least one row
    frame = await execute_sql_query(query)
    has_rows = frame is not None and len(frame) > 0
    if has_rows:
        outcome['dataframe'] = frame
    else:
        outcome['status'] = 'NO_DATA'
    flags['have_dataframe'] = has_rows

    # The figure entry is produced regardless of whether rows were returned
    outcome['figure'] = plot['plot_function'](params)

    return output_title, outcome, flags
|
76 |
+
|
77 |
+
async def drias_workflow(user_input: str) -> State:
    """
    Orchestrates the DRIAS workflow: from user input to SQL queries, dataframes, and figures.

    Steps:
        1. Select the relevant plots (at most two are kept).
        2. For each plot, select the relevant tables and register one output
           per (plot, table) pair.
        3. Extract the query parameters listed in DRIAS_PLOT_PARAMETERS.
        4. Run ``process_output`` concurrently for every registered output.
        5. Merge the results into the state and set an error message when
           nothing usable was produced.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state with all results and error messages if any.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")
    plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)

    if not plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    plots = plots[:2]  # limit to 2 types of plots
    state['plots'] = plots

    errors = {
        'have_relevant_table': False,
        'have_sql_query': False,
        'have_dataframe': False
    }
    outputs = {}

    # Find relevant tables for each plot and prepare outputs
    for plot_name in plots:
        plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None)
        if plot is None:
            # The selected name does not match any known plot: ignore it
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
        if relevant_tables:
            errors['have_relevant_table'] = True

        for table in relevant_tables:
            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
            outputs[output_title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Nothing was registered, so no task would consume the parameters:
    # return now instead of awaiting the parameter extraction for nothing.
    if not errors['have_relevant_table']:
        state['error'] = "There is no relevant table in our database to answer your question"
        return state

    # Gather all required parameters
    params = {}
    for param_name in DRIAS_PLOT_PARAMETERS:
        param = await find_param(state, param_name, mode='DRIAS')
        if param:
            params.update(param)

    # Process all outputs in parallel; each task gets its own copy of the
    # parameters because process_output mutates the dict it receives.
    tasks = [
        process_output(output_title, output['table'], output['plot'], params.copy())
        for output_title, output in outputs.items()
    ]
    results = await asyncio.gather(*tasks)

    # Update outputs with results and error flags
    for output_title, task_results, task_errors in results:
        outputs[output_title]['sql_query'] = task_results['sql_query']
        outputs[output_title]['dataframe'] = task_results['dataframe']
        outputs[output_title]['figure'] = task_results['figure']
        outputs[output_title]['plot_information'] = task_results['plot_information']
        outputs[output_title]['status'] = task_results['status']
        errors['have_sql_query'] |= task_errors['have_sql_query']
        errors['have_dataframe'] |= task_errors['have_dataframe']

    state['outputs'] = outputs

    # Set error messages if needed
    if not errors['have_sql_query']:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not errors['have_dataframe']:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
|
climateqa/engine/talk_to_data/workflow/ipcc.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
import asyncio
|
5 |
+
from climateqa.engine.llm import get_llm
|
6 |
+
from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
|
7 |
+
from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
|
8 |
+
from climateqa.engine.talk_to_data.objects.plot import Plot
|
9 |
+
from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
|
10 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS
|
11 |
+
from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS
|
12 |
+
|
13 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
14 |
+
|
15 |
+
async def process_output(
    output_title: str,
    table: str,
    plot: Plot,
    params: dict[str, Any]
) -> tuple[str, TTDOutput, dict[str, bool]]:
    """
    Build, run and render one (plot, table) output of the IPCC workflow.

    The steps are, in order: resolve the table's indicator column, build the
    SQL query through ``plot['sql_query']``, compute the plot information,
    execute the query, and finally call ``plot['plot_function']``.

    Args:
        output_title (str): Title for the output (returned unchanged so the caller can key on it).
        table (str): The name of the table to process.
        plot (Plot): The plot object providing the 'sql_query', 'plot_information' and 'plot_function' callables.
        params (dict[str, Any]): Parameters used for querying the table. This dict is
            updated in place with 'indicator_column' when one is found.

    Returns:
        tuple: (output_title, results dict, errors dict). 'status' in the results is
        'OK', 'ERROR' (no SQL query could be built) or 'NO_DATA' (empty query result).
    """
    outcome: TTDOutput = {
        'status': 'OK',
        'plot': plot,
        'table': table,
        'sql_query': None,
        'dataframe': None,
        'figure': None,
        'plot_information': None,
    }
    flags = {
        'have_sql_query': False,
        'have_dataframe': False
    }

    # Resolve the indicator column and expose it to the query/plot builders
    column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
    if column:
        params['indicator_column'] = column

    # Without a query there is nothing to execute or plot
    query = plot['sql_query'](table, params)
    if not query:
        outcome['status'] = 'ERROR'
        return output_title, outcome, flags

    outcome['plot_information'] = plot['plot_information'](table, params)
    outcome['sql_query'] = query
    flags['have_sql_query'] = True

    # Run the query; the dataframe is only kept when it is not empty
    frame = await execute_sql_query(query)
    has_data = frame is not None and not frame.empty
    if has_data:
        outcome['dataframe'] = frame
    else:
        outcome['status'] = 'NO_DATA'
    flags['have_dataframe'] = has_data

    # The figure entry is produced regardless of whether data was returned
    outcome['figure'] = plot['plot_function'](params)

    return output_title, outcome, flags
|
76 |
+
|
77 |
+
async def ipcc_workflow(user_input: str) -> State:
    """
    Performs the complete workflow of Talk To IPCC: from user input to SQL queries, dataframes, and figures.

    Steps:
        1. Select the relevant plots.
        2. For each plot, select the relevant tables and register one output
           per (plot, table) pair.
        3. Extract the query parameters listed in IPCC_PLOT_PARAMETERS.
        4. Run ``process_output`` concurrently for every registered output.
        5. Merge the results into the state and set an error message when
           nothing usable was produced.

    Args:
        user_input (str): The user's question.

    Returns:
        State: Final state with all the results and error messages if any.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'outputs': {},
        'error': ''
    }

    llm = get_llm(provider="openai")
    plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
    state['plots'] = plots

    if not plots:
        state['error'] = 'There is no plot to answer to the question'
        return state

    errors = {
        'have_relevant_table': False,
        'have_sql_query': False,
        'have_dataframe': False
    }
    outputs = {}

    # Find relevant tables for each plot and prepare outputs
    for plot_name in plots:
        plot = next((p for p in IPCC_PLOTS if p['name'] == plot_name), None)
        if plot is None:
            # The selected name does not match any known plot: ignore it
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
        if relevant_tables:
            errors['have_relevant_table'] = True

        for table in relevant_tables:
            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
            outputs[output_title] = {
                'table': table,
                'plot': plot,
                'status': 'OK'
            }

    # Nothing was registered, so no task would consume the parameters:
    # return now instead of awaiting the parameter extraction for nothing.
    if not errors['have_relevant_table']:
        state['error'] = "There is no relevant table in our database to answer your question"
        return state

    # Gather all required parameters
    params = {}
    for param_name in IPCC_PLOT_PARAMETERS:
        param = await find_param(state, param_name, mode='IPCC')
        if param:
            params.update(param)

    # Process all outputs in parallel; each task gets its own copy of the
    # parameters because process_output mutates the dict it receives.
    tasks = [
        process_output(output_title, output['table'], output['plot'], params.copy())
        for output_title, output in outputs.items()
    ]
    results = await asyncio.gather(*tasks)

    # Update outputs with results and error flags
    for output_title, task_results, task_errors in results:
        outputs[output_title]['sql_query'] = task_results['sql_query']
        outputs[output_title]['dataframe'] = task_results['dataframe']
        outputs[output_title]['figure'] = task_results['figure']
        outputs[output_title]['plot_information'] = task_results['plot_information']
        outputs[output_title]['status'] = task_results['status']
        errors['have_sql_query'] |= task_errors['have_sql_query']
        errors['have_dataframe'] |= task_errors['have_dataframe']

    state['outputs'] = outputs

    # Set error messages if needed
    if not errors['have_sql_query']:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not errors['have_dataframe']:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
|
front/tabs/tab_drias.py
CHANGED
@@ -4,26 +4,25 @@ import os
|
|
4 |
import pandas as pd
|
5 |
|
6 |
from climateqa.engine.talk_to_data.main import ask_drias
|
7 |
-
from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
|
8 |
|
9 |
class DriasUIElements(TypedDict):
|
10 |
tab: gr.Tab
|
11 |
details_accordion: gr.Accordion
|
12 |
examples_hidden: gr.Textbox
|
13 |
examples: gr.Examples
|
|
|
14 |
drias_direct_question: gr.Textbox
|
15 |
result_text: gr.Textbox
|
16 |
-
table_names_display: gr.
|
17 |
query_accordion: gr.Accordion
|
18 |
drias_sql_query: gr.Textbox
|
19 |
chart_accordion: gr.Accordion
|
|
|
20 |
model_selection: gr.Dropdown
|
21 |
drias_display: gr.Plot
|
22 |
table_accordion: gr.Accordion
|
23 |
drias_table: gr.DataFrame
|
24 |
-
pagination_display: gr.Markdown
|
25 |
-
prev_button: gr.Button
|
26 |
-
next_button: gr.Button
|
27 |
|
28 |
|
29 |
async def ask_drias_query(query: str, index_state: int, user_id: str):
|
@@ -31,7 +30,7 @@ async def ask_drias_query(query: str, index_state: int, user_id: str):
|
|
31 |
return result
|
32 |
|
33 |
|
34 |
-
def show_results(sql_queries_state, dataframes_state, plots_state):
|
35 |
if not sql_queries_state or not dataframes_state or not plots_state:
|
36 |
# If all results are empty, show "No result"
|
37 |
return (
|
@@ -40,9 +39,6 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
|
|
40 |
gr.update(visible=False),
|
41 |
gr.update(visible=False),
|
42 |
gr.update(visible=False),
|
43 |
-
gr.update(visible=False),
|
44 |
-
gr.update(visible=False),
|
45 |
-
gr.update(visible=False),
|
46 |
)
|
47 |
else:
|
48 |
# Show the appropriate components with their data
|
@@ -51,10 +47,7 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
|
|
51 |
gr.update(visible=True),
|
52 |
gr.update(visible=True),
|
53 |
gr.update(visible=True),
|
54 |
-
gr.update(visible=True),
|
55 |
-
gr.update(visible=True),
|
56 |
-
gr.update(visible=True),
|
57 |
-
gr.update(visible=True),
|
58 |
)
|
59 |
|
60 |
|
@@ -72,44 +65,14 @@ def filter_by_model(dataframes, figures, index_state, model_selection):
|
|
72 |
return df, figure
|
73 |
|
74 |
|
75 |
-
def
|
76 |
-
|
77 |
-
return pagination
|
78 |
-
|
79 |
-
|
80 |
-
def show_previous(index, sql_queries, dataframes, plots):
|
81 |
-
if index > 0:
|
82 |
-
index -= 1
|
83 |
-
return (
|
84 |
-
sql_queries[index],
|
85 |
-
dataframes[index],
|
86 |
-
plots[index](dataframes[index]),
|
87 |
-
index,
|
88 |
-
)
|
89 |
-
|
90 |
-
|
91 |
-
def show_next(index, sql_queries, dataframes, plots):
|
92 |
-
if index < len(sql_queries) - 1:
|
93 |
-
index += 1
|
94 |
-
return (
|
95 |
-
sql_queries[index],
|
96 |
-
dataframes[index],
|
97 |
-
plots[index](dataframes[index]),
|
98 |
-
index,
|
99 |
-
)
|
100 |
-
|
101 |
-
|
102 |
-
def display_table_names(table_names):
|
103 |
-
return [table_names]
|
104 |
-
|
105 |
-
|
106 |
-
def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
|
107 |
-
index = evt.index[1]
|
108 |
figure = plots[index](dataframes[index])
|
109 |
return (
|
110 |
sql_queries[index],
|
111 |
dataframes[index],
|
112 |
figure,
|
|
|
113 |
index,
|
114 |
)
|
115 |
|
@@ -117,7 +80,7 @@ def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plo
|
|
117 |
def create_drias_ui() -> DriasUIElements:
|
118 |
"""Create and return all UI elements for the DRIAS tab."""
|
119 |
with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
|
120 |
-
with gr.Accordion(label="
|
121 |
gr.Markdown(DRIAS_UI_TEXT)
|
122 |
|
123 |
# Add examples for common questions
|
@@ -141,24 +104,43 @@ def create_drias_ui() -> DriasUIElements:
|
|
141 |
elem_id="direct-question",
|
142 |
interactive=True,
|
143 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
result_text = gr.Textbox(
|
146 |
label="", elem_id="no-result-label", interactive=False, visible=True
|
147 |
)
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
)
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
with gr.Accordion(label="Chart", visible=False) as chart_accordion:
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
162 |
drias_display = gr.Plot(elem_id="vanna-plot")
|
163 |
|
164 |
with gr.Accordion(
|
@@ -166,32 +148,23 @@ def create_drias_ui() -> DriasUIElements:
|
|
166 |
) as table_accordion:
|
167 |
drias_table = gr.DataFrame([], elem_id="vanna-table")
|
168 |
|
169 |
-
pagination_display = gr.Markdown(
|
170 |
-
value="", visible=False, elem_id="pagination-display"
|
171 |
-
)
|
172 |
-
|
173 |
-
with gr.Row():
|
174 |
-
prev_button = gr.Button("Previous", visible=False)
|
175 |
-
next_button = gr.Button("Next", visible=False)
|
176 |
-
|
177 |
return DriasUIElements(
|
178 |
tab=tab,
|
179 |
details_accordion=details_accordion,
|
180 |
examples_hidden=examples_hidden,
|
181 |
examples=examples,
|
|
|
182 |
drias_direct_question=drias_direct_question,
|
183 |
result_text=result_text,
|
184 |
table_names_display=table_names_display,
|
185 |
query_accordion=query_accordion,
|
186 |
drias_sql_query=drias_sql_query,
|
187 |
chart_accordion=chart_accordion,
|
|
|
188 |
model_selection=model_selection,
|
189 |
drias_display=drias_display,
|
190 |
table_accordion=table_accordion,
|
191 |
drias_table=drias_table,
|
192 |
-
pagination_display=pagination_display,
|
193 |
-
prev_button=prev_button,
|
194 |
-
next_button=next_button
|
195 |
)
|
196 |
|
197 |
|
@@ -202,94 +175,56 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
|
|
202 |
sql_queries_state = gr.State([])
|
203 |
dataframes_state = gr.State([])
|
204 |
plots_state = gr.State([])
|
|
|
205 |
index_state = gr.State(0)
|
206 |
table_names_list = gr.State([])
|
207 |
user_id = gr.State(user_id)
|
208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
# Handle example selection
|
210 |
ui_elements["examples_hidden"].change(
|
211 |
lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
|
212 |
inputs=[ui_elements["examples_hidden"]],
|
213 |
outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
|
214 |
).then(
|
215 |
-
|
216 |
-
inputs=[ui_elements["examples_hidden"], index_state, user_id],
|
217 |
-
outputs=[
|
218 |
-
ui_elements["drias_sql_query"],
|
219 |
-
ui_elements["drias_table"],
|
220 |
-
ui_elements["drias_display"],
|
221 |
-
sql_queries_state,
|
222 |
-
dataframes_state,
|
223 |
-
plots_state,
|
224 |
-
index_state,
|
225 |
-
table_names_list,
|
226 |
-
ui_elements["result_text"],
|
227 |
-
],
|
228 |
-
).then(
|
229 |
-
show_results,
|
230 |
-
inputs=[sql_queries_state, dataframes_state, plots_state],
|
231 |
-
outputs=[
|
232 |
-
ui_elements["result_text"],
|
233 |
-
ui_elements["query_accordion"],
|
234 |
-
ui_elements["table_accordion"],
|
235 |
-
ui_elements["chart_accordion"],
|
236 |
-
ui_elements["prev_button"],
|
237 |
-
ui_elements["next_button"],
|
238 |
-
ui_elements["pagination_display"],
|
239 |
-
ui_elements["table_names_display"],
|
240 |
-
],
|
241 |
-
).then(
|
242 |
-
update_pagination,
|
243 |
-
inputs=[index_state, sql_queries_state],
|
244 |
-
outputs=[ui_elements["pagination_display"]],
|
245 |
-
).then(
|
246 |
-
display_table_names,
|
247 |
-
inputs=[table_names_list],
|
248 |
-
outputs=[ui_elements["table_names_display"]],
|
249 |
-
)
|
250 |
-
|
251 |
-
# Handle direct question submission
|
252 |
-
ui_elements["drias_direct_question"].submit(
|
253 |
-
lambda: gr.Accordion(open=False),
|
254 |
inputs=None,
|
255 |
-
outputs=
|
256 |
).then(
|
257 |
ask_drias_query,
|
258 |
-
inputs=[ui_elements["
|
259 |
outputs=[
|
260 |
ui_elements["drias_sql_query"],
|
261 |
ui_elements["drias_table"],
|
262 |
ui_elements["drias_display"],
|
|
|
263 |
sql_queries_state,
|
264 |
dataframes_state,
|
265 |
plots_state,
|
|
|
266 |
index_state,
|
267 |
table_names_list,
|
268 |
ui_elements["result_text"],
|
269 |
],
|
270 |
).then(
|
271 |
show_results,
|
272 |
-
inputs=[sql_queries_state, dataframes_state, plots_state],
|
273 |
outputs=[
|
274 |
ui_elements["result_text"],
|
275 |
ui_elements["query_accordion"],
|
276 |
ui_elements["table_accordion"],
|
277 |
ui_elements["chart_accordion"],
|
278 |
-
ui_elements["prev_button"],
|
279 |
-
ui_elements["next_button"],
|
280 |
-
ui_elements["pagination_display"],
|
281 |
ui_elements["table_names_display"],
|
282 |
],
|
283 |
-
).then(
|
284 |
-
update_pagination,
|
285 |
-
inputs=[index_state, sql_queries_state],
|
286 |
-
outputs=[ui_elements["pagination_display"]],
|
287 |
-
).then(
|
288 |
-
display_table_names,
|
289 |
-
inputs=[table_names_list],
|
290 |
-
outputs=[ui_elements["table_names_display"]],
|
291 |
)
|
292 |
|
|
|
293 |
# Handle model selection change
|
294 |
ui_elements["model_selection"].change(
|
295 |
filter_by_model,
|
@@ -297,36 +232,12 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
|
|
297 |
outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
|
298 |
)
|
299 |
|
300 |
-
# Handle pagination buttons
|
301 |
-
ui_elements["prev_button"].click(
|
302 |
-
show_previous,
|
303 |
-
inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
|
304 |
-
outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
|
305 |
-
).then(
|
306 |
-
update_pagination,
|
307 |
-
inputs=[index_state, sql_queries_state],
|
308 |
-
outputs=[ui_elements["pagination_display"]],
|
309 |
-
)
|
310 |
-
|
311 |
-
ui_elements["next_button"].click(
|
312 |
-
show_next,
|
313 |
-
inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
|
314 |
-
outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
|
315 |
-
).then(
|
316 |
-
update_pagination,
|
317 |
-
inputs=[index_state, sql_queries_state],
|
318 |
-
outputs=[ui_elements["pagination_display"]],
|
319 |
-
)
|
320 |
|
321 |
# Handle table selection
|
322 |
-
ui_elements["table_names_display"].
|
323 |
fn=on_table_click,
|
324 |
-
inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
|
325 |
-
outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
|
326 |
-
).then(
|
327 |
-
update_pagination,
|
328 |
-
inputs=[index_state, sql_queries_state],
|
329 |
-
outputs=[ui_elements["pagination_display"]],
|
330 |
)
|
331 |
|
332 |
def create_drias_tab(share_client=None, user_id=None):
|
|
|
4 |
import pandas as pd
|
5 |
|
6 |
from climateqa.engine.talk_to_data.main import ask_drias
|
7 |
+
from climateqa.engine.talk_to_data.drias.config import DRIAS_MODELS, DRIAS_UI_TEXT
|
8 |
|
9 |
class DriasUIElements(TypedDict):
|
10 |
tab: gr.Tab
|
11 |
details_accordion: gr.Accordion
|
12 |
examples_hidden: gr.Textbox
|
13 |
examples: gr.Examples
|
14 |
+
image_examples: gr.Row
|
15 |
drias_direct_question: gr.Textbox
|
16 |
result_text: gr.Textbox
|
17 |
+
table_names_display: gr.Radio
|
18 |
query_accordion: gr.Accordion
|
19 |
drias_sql_query: gr.Textbox
|
20 |
chart_accordion: gr.Accordion
|
21 |
+
plot_information: gr.Markdown
|
22 |
model_selection: gr.Dropdown
|
23 |
drias_display: gr.Plot
|
24 |
table_accordion: gr.Accordion
|
25 |
drias_table: gr.DataFrame
|
|
|
|
|
|
|
26 |
|
27 |
|
28 |
async def ask_drias_query(query: str, index_state: int, user_id: str):
|
|
|
30 |
return result
|
31 |
|
32 |
|
33 |
+
def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
|
34 |
if not sql_queries_state or not dataframes_state or not plots_state:
|
35 |
# If all results are empty, show "No result"
|
36 |
return (
|
|
|
39 |
gr.update(visible=False),
|
40 |
gr.update(visible=False),
|
41 |
gr.update(visible=False),
|
|
|
|
|
|
|
42 |
)
|
43 |
else:
|
44 |
# Show the appropriate components with their data
|
|
|
47 |
gr.update(visible=True),
|
48 |
gr.update(visible=True),
|
49 |
gr.update(visible=True),
|
50 |
+
gr.update(choices=table_names, value=table_names[0], visible=True),
|
|
|
|
|
|
|
51 |
)
|
52 |
|
53 |
|
|
|
65 |
return df, figure
|
66 |
|
67 |
|
68 |
+
def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
|
69 |
+
index = table_names.index(selected_label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
figure = plots[index](dataframes[index])
|
71 |
return (
|
72 |
sql_queries[index],
|
73 |
dataframes[index],
|
74 |
figure,
|
75 |
+
plot_informations[index],
|
76 |
index,
|
77 |
)
|
78 |
|
|
|
80 |
def create_drias_ui() -> DriasUIElements:
|
81 |
"""Create and return all UI elements for the DRIAS tab."""
|
82 |
with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
|
83 |
+
with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
|
84 |
gr.Markdown(DRIAS_UI_TEXT)
|
85 |
|
86 |
# Add examples for common questions
|
|
|
104 |
elem_id="direct-question",
|
105 |
interactive=True,
|
106 |
)
|
107 |
+
|
108 |
+
|
109 |
+
with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
|
110 |
+
gr.Markdown("### Examples of possible visualizations")
|
111 |
+
|
112 |
+
with gr.Row():
|
113 |
+
gr.Image("./front/assets/talk_to_drias_winter_temp_paris_example.png", label="Evolution of Mean Winter Temperature in Paris", elem_classes=["example-img"])
|
114 |
+
gr.Image("./front/assets/talk_to_drias_annual_temperature_france_example.png", label="Mean Annual Temperature in 2030 in France", elem_classes=["example-img"])
|
115 |
+
gr.Image("./front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png", label="Frequency of Remarkable Daily Precipitation in Lyon", elem_classes=["example-img"])
|
116 |
|
117 |
result_text = gr.Textbox(
|
118 |
label="", elem_id="no-result-label", interactive=False, visible=True
|
119 |
)
|
120 |
+
|
121 |
+
with gr.Row():
|
122 |
+
table_names_display = gr.Radio(
|
123 |
+
choices=[],
|
124 |
+
label="Relevant figures created",
|
125 |
+
interactive=True,
|
126 |
+
elem_id="table-names",
|
127 |
+
visible=False
|
128 |
)
|
129 |
|
130 |
+
with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
|
131 |
+
drias_sql_query = gr.Textbox(
|
132 |
+
label="", elem_id="sql-query", interactive=False
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
with gr.Accordion(label="Chart", visible=False) as chart_accordion:
|
137 |
+
with gr.Row():
|
138 |
+
model_selection = gr.Dropdown(
|
139 |
+
label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
|
140 |
+
)
|
141 |
+
with gr.Accordion(label="Informations about the plot", open=False):
|
142 |
+
plot_information = gr.Markdown(value = "")
|
143 |
+
|
144 |
drias_display = gr.Plot(elem_id="vanna-plot")
|
145 |
|
146 |
with gr.Accordion(
|
|
|
148 |
) as table_accordion:
|
149 |
drias_table = gr.DataFrame([], elem_id="vanna-table")
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
return DriasUIElements(
|
152 |
tab=tab,
|
153 |
details_accordion=details_accordion,
|
154 |
examples_hidden=examples_hidden,
|
155 |
examples=examples,
|
156 |
+
image_examples=image_examples,
|
157 |
drias_direct_question=drias_direct_question,
|
158 |
result_text=result_text,
|
159 |
table_names_display=table_names_display,
|
160 |
query_accordion=query_accordion,
|
161 |
drias_sql_query=drias_sql_query,
|
162 |
chart_accordion=chart_accordion,
|
163 |
+
plot_information=plot_information,
|
164 |
model_selection=model_selection,
|
165 |
drias_display=drias_display,
|
166 |
table_accordion=table_accordion,
|
167 |
drias_table=drias_table,
|
|
|
|
|
|
|
168 |
)
|
169 |
|
170 |
|
|
|
175 |
sql_queries_state = gr.State([])
|
176 |
dataframes_state = gr.State([])
|
177 |
plots_state = gr.State([])
|
178 |
+
plot_informations_state = gr.State([])
|
179 |
index_state = gr.State(0)
|
180 |
table_names_list = gr.State([])
|
181 |
user_id = gr.State(user_id)
|
182 |
|
183 |
+
# Handle direct question submission - trigger the same workflow by setting examples_hidden
|
184 |
+
ui_elements["drias_direct_question"].submit(
|
185 |
+
lambda x: gr.update(value=x),
|
186 |
+
inputs=[ui_elements["drias_direct_question"]],
|
187 |
+
outputs=[ui_elements["examples_hidden"]],
|
188 |
+
)
|
189 |
+
|
190 |
# Handle example selection
|
191 |
ui_elements["examples_hidden"].change(
|
192 |
lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
|
193 |
inputs=[ui_elements["examples_hidden"]],
|
194 |
outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
|
195 |
).then(
|
196 |
+
lambda : gr.update(visible=False),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
inputs=None,
|
198 |
+
outputs=ui_elements["image_examples"]
|
199 |
).then(
|
200 |
ask_drias_query,
|
201 |
+
inputs=[ui_elements["examples_hidden"], index_state, user_id],
|
202 |
outputs=[
|
203 |
ui_elements["drias_sql_query"],
|
204 |
ui_elements["drias_table"],
|
205 |
ui_elements["drias_display"],
|
206 |
+
ui_elements["plot_information"],
|
207 |
sql_queries_state,
|
208 |
dataframes_state,
|
209 |
plots_state,
|
210 |
+
plot_informations_state,
|
211 |
index_state,
|
212 |
table_names_list,
|
213 |
ui_elements["result_text"],
|
214 |
],
|
215 |
).then(
|
216 |
show_results,
|
217 |
+
inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
|
218 |
outputs=[
|
219 |
ui_elements["result_text"],
|
220 |
ui_elements["query_accordion"],
|
221 |
ui_elements["table_accordion"],
|
222 |
ui_elements["chart_accordion"],
|
|
|
|
|
|
|
223 |
ui_elements["table_names_display"],
|
224 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
)
|
226 |
|
227 |
+
|
228 |
# Handle model selection change
|
229 |
ui_elements["model_selection"].change(
|
230 |
filter_by_model,
|
|
|
232 |
outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
|
233 |
)
|
234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
|
236 |
# Handle table selection
|
237 |
+
ui_elements["table_names_display"].change(
|
238 |
fn=on_table_click,
|
239 |
+
inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
|
240 |
+
outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], ui_elements["plot_information"], index_state],
|
|
|
|
|
|
|
|
|
241 |
)
|
242 |
|
243 |
def create_drias_tab(share_client=None, user_id=None):
|
front/tabs/tab_ipcc.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from random import choices
|
2 |
+
import gradio as gr
|
3 |
+
from typing import TypedDict
|
4 |
+
from climateqa.engine.talk_to_data.main import ask_ipcc
|
5 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_MODELS, IPCC_SCENARIO, IPCC_UI_TEXT
|
6 |
+
import uuid
|
7 |
+
|
8 |
+
class ipccUIElements(TypedDict):
|
9 |
+
tab: gr.Tab
|
10 |
+
details_accordion: gr.Accordion
|
11 |
+
examples_hidden: gr.Textbox
|
12 |
+
examples: gr.Examples
|
13 |
+
image_examples: gr.Row
|
14 |
+
ipcc_direct_question: gr.Textbox
|
15 |
+
result_text: gr.Textbox
|
16 |
+
table_names_display: gr.Radio
|
17 |
+
query_accordion: gr.Accordion
|
18 |
+
ipcc_sql_query: gr.Textbox
|
19 |
+
chart_accordion: gr.Accordion
|
20 |
+
plot_information: gr.Markdown
|
21 |
+
scenario_selection: gr.Dropdown
|
22 |
+
ipcc_display: gr.Plot
|
23 |
+
table_accordion: gr.Accordion
|
24 |
+
ipcc_table: gr.DataFrame
|
25 |
+
|
26 |
+
|
27 |
+
async def ask_ipcc_query(query: str, index_state: int, user_id: str):
|
28 |
+
result = await ask_ipcc(query, index_state, user_id)
|
29 |
+
return result
|
30 |
+
|
31 |
+
def hide_outputs():
|
32 |
+
"""Hide all outputs initially."""
|
33 |
+
return (
|
34 |
+
gr.update(visible=True), # Show the result text
|
35 |
+
gr.update(visible=False), # Hide the query accordion
|
36 |
+
gr.update(visible=False), # Hide the table accordion
|
37 |
+
gr.update(visible=False), # Hide the chart accordion
|
38 |
+
gr.update(visible=False), # Hide table names
|
39 |
+
)
|
40 |
+
|
41 |
+
def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
|
42 |
+
if not sql_queries_state or not dataframes_state or not plots_state:
|
43 |
+
# If all results are empty, show "No result"
|
44 |
+
return (
|
45 |
+
gr.update(visible=True),
|
46 |
+
gr.update(visible=False),
|
47 |
+
gr.update(visible=False),
|
48 |
+
gr.update(visible=False),
|
49 |
+
gr.update(visible=False),
|
50 |
+
)
|
51 |
+
else:
|
52 |
+
# Show the appropriate components with their data
|
53 |
+
return (
|
54 |
+
gr.update(visible=False),
|
55 |
+
gr.update(visible=True),
|
56 |
+
gr.update(visible=True),
|
57 |
+
gr.update(visible=True),
|
58 |
+
gr.update(choices=table_names, value=table_names[0], visible=True),
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def show_filter_by_scenario(table_names, index_state, dataframes):
|
63 |
+
if len(table_names) > 0 and table_names[index_state].startswith("Map"):
|
64 |
+
df = dataframes[index_state]
|
65 |
+
scenarios = sorted(df["scenario"].unique())
|
66 |
+
return gr.update(visible=True, choices=scenarios, value=scenarios[0])
|
67 |
+
else:
|
68 |
+
return gr.update(visible=False)
|
69 |
+
|
70 |
+
def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
|
71 |
+
df = dataframes[index_state]
|
72 |
+
if not table_names[index_state].startswith("Map"):
|
73 |
+
return df, figures[index_state](df)
|
74 |
+
if df.empty:
|
75 |
+
return df, None
|
76 |
+
if "scenario" not in df.columns:
|
77 |
+
return df, figures[index_state](df)
|
78 |
+
else:
|
79 |
+
df = df[df["scenario"] == scenario]
|
80 |
+
if df.empty:
|
81 |
+
return df, None
|
82 |
+
figure = figures[index_state](df)
|
83 |
+
return df, figure
|
84 |
+
|
85 |
+
|
86 |
+
def display_table_names(table_names, index_state):
|
87 |
+
return [
|
88 |
+
[name]
|
89 |
+
for name in table_names
|
90 |
+
]
|
91 |
+
|
92 |
+
def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
|
93 |
+
index = table_names.index(selected_label)
|
94 |
+
figure = plots[index](dataframes[index])
|
95 |
+
|
96 |
+
return (
|
97 |
+
sql_queries[index],
|
98 |
+
dataframes[index],
|
99 |
+
figure,
|
100 |
+
plot_informations[index],
|
101 |
+
index,
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
def create_ipcc_ui() -> ipccUIElements:
|
106 |
+
|
107 |
+
"""Create and return all UI elements for the ipcc tab."""
|
108 |
+
with gr.Tab("(Beta) Talk to IPCC", elem_id="tab-vanna", id=7) as tab:
|
109 |
+
with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
|
110 |
+
gr.Markdown(IPCC_UI_TEXT)
|
111 |
+
|
112 |
+
# Add examples for common questions
|
113 |
+
examples_hidden = gr.Textbox(visible=False, elem_id="ipcc-examples-hidden")
|
114 |
+
examples = gr.Examples(
|
115 |
+
examples=[
|
116 |
+
["What will the temperature be like in Paris?"],
|
117 |
+
["What will be the total rainfall in the USA in 2030?"],
|
118 |
+
["How will the average temperature evolve in China?"],
|
119 |
+
["What will be the average total precipitation in London ?"]
|
120 |
+
],
|
121 |
+
label="Example Questions",
|
122 |
+
inputs=[examples_hidden],
|
123 |
+
outputs=[examples_hidden],
|
124 |
+
)
|
125 |
+
|
126 |
+
with gr.Row():
|
127 |
+
ipcc_direct_question = gr.Textbox(
|
128 |
+
label="Direct Question",
|
129 |
+
placeholder="You can write direct question here",
|
130 |
+
elem_id="direct-question",
|
131 |
+
interactive=True,
|
132 |
+
)
|
133 |
+
|
134 |
+
with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
|
135 |
+
gr.Markdown("### Examples of possible visualizations")
|
136 |
+
|
137 |
+
with gr.Row():
|
138 |
+
gr.Image("./front/assets/talk_to_ipcc_france_example.png", label="Total Precipitation in 2030 in France", elem_classes=["example-img"])
|
139 |
+
gr.Image("./front/assets/talk_to_ipcc_new_york_example.png", label="Yearly Evolution of Mean Temperature in New York (Historical + SSP Scenarios)", elem_classes=["example-img"])
|
140 |
+
gr.Image("./front/assets/talk_to_ipcc_china_example.png", label="Mean Temperature in 2050 in China", elem_classes=["example-img"])
|
141 |
+
|
142 |
+
result_text = gr.Textbox(
|
143 |
+
label="", elem_id="no-result-label", interactive=False, visible=True
|
144 |
+
)
|
145 |
+
with gr.Row():
|
146 |
+
table_names_display = gr.Radio(
|
147 |
+
choices=[],
|
148 |
+
label="Relevant figures created",
|
149 |
+
interactive=True,
|
150 |
+
elem_id="table-names",
|
151 |
+
visible=False
|
152 |
+
)
|
153 |
+
|
154 |
+
with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
|
155 |
+
ipcc_sql_query = gr.Textbox(
|
156 |
+
label="", elem_id="sql-query", interactive=False
|
157 |
+
)
|
158 |
+
|
159 |
+
with gr.Accordion(label="Chart", visible=False) as chart_accordion:
|
160 |
+
|
161 |
+
with gr.Row():
|
162 |
+
scenario_selection = gr.Dropdown(
|
163 |
+
label="Scenario", choices=IPCC_SCENARIO, value=IPCC_SCENARIO[0], interactive=True, visible=False
|
164 |
+
)
|
165 |
+
|
166 |
+
with gr.Accordion(label="Informations about the plot", open=False):
|
167 |
+
plot_information = gr.Markdown(value = "")
|
168 |
+
|
169 |
+
ipcc_display = gr.Plot(elem_id="vanna-plot")
|
170 |
+
|
171 |
+
with gr.Accordion(
|
172 |
+
label="Data used", open=False, visible=False
|
173 |
+
) as table_accordion:
|
174 |
+
ipcc_table = gr.DataFrame([], elem_id="vanna-table")
|
175 |
+
|
176 |
+
|
177 |
+
return ipccUIElements(
|
178 |
+
tab=tab,
|
179 |
+
details_accordion=details_accordion,
|
180 |
+
examples_hidden=examples_hidden,
|
181 |
+
examples=examples,
|
182 |
+
image_examples=image_examples,
|
183 |
+
ipcc_direct_question=ipcc_direct_question,
|
184 |
+
result_text=result_text,
|
185 |
+
table_names_display=table_names_display,
|
186 |
+
query_accordion=query_accordion,
|
187 |
+
ipcc_sql_query=ipcc_sql_query,
|
188 |
+
chart_accordion=chart_accordion,
|
189 |
+
plot_information=plot_information,
|
190 |
+
scenario_selection=scenario_selection,
|
191 |
+
ipcc_display=ipcc_display,
|
192 |
+
table_accordion=table_accordion,
|
193 |
+
ipcc_table=ipcc_table,
|
194 |
+
)
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
def setup_ipcc_events(ui_elements: ipccUIElements, share_client=None, user_id=None) -> None:
|
199 |
+
"""Set up all event handlers for the ipcc tab."""
|
200 |
+
# Create state variables
|
201 |
+
sql_queries_state = gr.State([])
|
202 |
+
dataframes_state = gr.State([])
|
203 |
+
plots_state = gr.State([])
|
204 |
+
plot_informations_state = gr.State([])
|
205 |
+
index_state = gr.State(0)
|
206 |
+
table_names_list = gr.State([])
|
207 |
+
user_id = gr.State(user_id)
|
208 |
+
|
209 |
+
# Handle direct question submission - trigger the same workflow by setting examples_hidden
|
210 |
+
ui_elements["ipcc_direct_question"].submit(
|
211 |
+
lambda x: gr.update(value=x),
|
212 |
+
inputs=[ui_elements["ipcc_direct_question"]],
|
213 |
+
outputs=[ui_elements["examples_hidden"]],
|
214 |
+
)
|
215 |
+
|
216 |
+
# Handle example selection
|
217 |
+
ui_elements["examples_hidden"].change(
|
218 |
+
lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
|
219 |
+
inputs=[ui_elements["examples_hidden"]],
|
220 |
+
outputs=[ui_elements["details_accordion"], ui_elements["ipcc_direct_question"]]
|
221 |
+
).then(
|
222 |
+
lambda : gr.update(visible=False),
|
223 |
+
inputs=None,
|
224 |
+
outputs=ui_elements["image_examples"]
|
225 |
+
).then(
|
226 |
+
hide_outputs,
|
227 |
+
inputs=None,
|
228 |
+
outputs=[
|
229 |
+
ui_elements["result_text"],
|
230 |
+
ui_elements["query_accordion"],
|
231 |
+
ui_elements["table_accordion"],
|
232 |
+
ui_elements["chart_accordion"],
|
233 |
+
ui_elements["table_names_display"],
|
234 |
+
]
|
235 |
+
).then(
|
236 |
+
ask_ipcc_query,
|
237 |
+
inputs=[ui_elements["examples_hidden"], index_state, user_id],
|
238 |
+
outputs=[
|
239 |
+
ui_elements["ipcc_sql_query"],
|
240 |
+
ui_elements["ipcc_table"],
|
241 |
+
ui_elements["ipcc_display"],
|
242 |
+
ui_elements["plot_information"],
|
243 |
+
sql_queries_state,
|
244 |
+
dataframes_state,
|
245 |
+
plots_state,
|
246 |
+
plot_informations_state,
|
247 |
+
index_state,
|
248 |
+
table_names_list,
|
249 |
+
ui_elements["result_text"],
|
250 |
+
],
|
251 |
+
).then(
|
252 |
+
show_results,
|
253 |
+
inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
|
254 |
+
outputs=[
|
255 |
+
ui_elements["result_text"],
|
256 |
+
ui_elements["query_accordion"],
|
257 |
+
ui_elements["table_accordion"],
|
258 |
+
ui_elements["chart_accordion"],
|
259 |
+
ui_elements["table_names_display"],
|
260 |
+
],
|
261 |
+
).then(
|
262 |
+
show_filter_by_scenario,
|
263 |
+
inputs=[table_names_list, index_state, dataframes_state],
|
264 |
+
outputs=[ui_elements["scenario_selection"]],
|
265 |
+
).then(
|
266 |
+
filter_by_scenario,
|
267 |
+
inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
|
268 |
+
outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
|
269 |
+
)
|
270 |
+
|
271 |
+
|
272 |
+
# Handle model selection change
|
273 |
+
ui_elements["scenario_selection"].change(
|
274 |
+
filter_by_scenario,
|
275 |
+
inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
|
276 |
+
outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
|
277 |
+
)
|
278 |
+
|
279 |
+
# Handle table selection
|
280 |
+
ui_elements["table_names_display"].change(
|
281 |
+
fn=on_table_click,
|
282 |
+
inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
|
283 |
+
outputs=[ui_elements["ipcc_sql_query"], ui_elements["ipcc_table"], ui_elements["ipcc_display"], ui_elements["plot_information"], index_state],
|
284 |
+
).then(
|
285 |
+
show_filter_by_scenario,
|
286 |
+
inputs=[table_names_list, index_state, dataframes_state],
|
287 |
+
outputs=[ui_elements["scenario_selection"]],
|
288 |
+
).then(
|
289 |
+
filter_by_scenario,
|
290 |
+
inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
|
291 |
+
outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
|
292 |
+
)
|
293 |
+
|
294 |
+
|
295 |
+
def create_ipcc_tab(share_client=None, user_id=None):
|
296 |
+
"""Create the ipcc tab with all its components and event handlers."""
|
297 |
+
ui_elements = create_ipcc_ui()
|
298 |
+
setup_ipcc_events(ui_elements, share_client=share_client, user_id=user_id)
|
299 |
+
|
300 |
+
|
requirements.txt
CHANGED
@@ -25,4 +25,5 @@ geopy==2.4.1
|
|
25 |
duckdb==1.2.1
|
26 |
openai==1.61.1
|
27 |
pydantic==2.9.2
|
28 |
-
pydantic-settings==2.2.1
|
|
|
|
25 |
duckdb==1.2.1
|
26 |
openai==1.61.1
|
27 |
pydantic==2.9.2
|
28 |
+
pydantic-settings==2.2.1
|
29 |
+
geojson==3.2.0
|
style.css
CHANGED
@@ -656,12 +656,11 @@ a {
|
|
656 |
/* overflow-y: scroll; */
|
657 |
}
|
658 |
#sql-query{
|
659 |
-
max-height:
|
660 |
-
overflow-y:scroll;
|
661 |
}
|
662 |
|
663 |
#sql-query textarea{
|
664 |
-
min-height:
|
665 |
}
|
666 |
|
667 |
#sql-query span{
|
@@ -671,8 +670,11 @@ div#tab-vanna{
|
|
671 |
max-height: 100¨vh;
|
672 |
overflow-y: hidden;
|
673 |
}
|
|
|
|
|
|
|
674 |
#vanna-plot{
|
675 |
-
max-height:
|
676 |
}
|
677 |
|
678 |
#pagination-display{
|
@@ -681,13 +683,33 @@ div#tab-vanna{
|
|
681 |
font-size: 16px;
|
682 |
}
|
683 |
|
684 |
-
|
685 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
686 |
}
|
687 |
-
|
|
|
|
|
|
|
|
|
|
|
688 |
display: none;
|
689 |
}
|
690 |
|
|
|
|
|
|
|
|
|
|
|
691 |
/* DRIAS Data Table Styles */
|
692 |
#vanna-table {
|
693 |
height: 400px !important;
|
@@ -710,3 +732,13 @@ div#tab-vanna{
|
|
710 |
background: white;
|
711 |
z-index: 1;
|
712 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
656 |
/* overflow-y: scroll; */
|
657 |
}
|
658 |
#sql-query{
|
659 |
+
max-height: 100%;
|
|
|
660 |
}
|
661 |
|
662 |
#sql-query textarea{
|
663 |
+
min-height: 200px !important;
|
664 |
}
|
665 |
|
666 |
#sql-query span{
|
|
|
670 |
max-height: 100¨vh;
|
671 |
overflow-y: hidden;
|
672 |
}
|
673 |
+
#details button span{
|
674 |
+
font-weight: bold;
|
675 |
+
}
|
676 |
#vanna-plot{
|
677 |
+
max-height:1000px
|
678 |
}
|
679 |
|
680 |
#pagination-display{
|
|
|
683 |
font-size: 16px;
|
684 |
}
|
685 |
|
686 |
+
|
687 |
+
#table-names label {
|
688 |
+
display: block;
|
689 |
+
width: 100%;
|
690 |
+
box-sizing: border-box;
|
691 |
+
padding: 8px 12px;
|
692 |
+
margin-bottom: 4px;
|
693 |
+
border: 1px solid #ccc;
|
694 |
+
border-radius: 6px;
|
695 |
+
background-color: white;
|
696 |
+
cursor: pointer;
|
697 |
+
text-align: center;
|
698 |
}
|
699 |
+
|
700 |
+
#table-names label:hover {
|
701 |
+
background-color: #f0f8ff;
|
702 |
+
}
|
703 |
+
|
704 |
+
#table-names input[type="radio"] {
|
705 |
display: none;
|
706 |
}
|
707 |
|
708 |
+
#table-names input[type="radio"]:checked + label {
|
709 |
+
background-color: #d0eaff;
|
710 |
+
border-color: #2196f3;
|
711 |
+
}
|
712 |
+
|
713 |
/* DRIAS Data Table Styles */
|
714 |
#vanna-table {
|
715 |
height: 400px !important;
|
|
|
732 |
background: white;
|
733 |
z-index: 1;
|
734 |
}
|
735 |
+
|
736 |
+
.example-img{
|
737 |
+
height: 250px;
|
738 |
+
object-fit: contain;
|
739 |
+
}
|
740 |
+
|
741 |
+
#example-img-container {
|
742 |
+
flex-direction: column;
|
743 |
+
align-items: left;
|
744 |
+
}
|