armanddemasson commited on
Commit
11ab5fb
·
2 Parent(s): e92e8dc 819e3c0

Merged in feature/talk_to_data (pull request #19)

Browse files
Files changed (29) hide show
  1. app.py +3 -2
  2. climateqa/engine/chains/retrieve_documents.py +6 -4
  3. climateqa/engine/talk_to_data/config.py +8 -96
  4. climateqa/engine/talk_to_data/drias/config.py +124 -0
  5. climateqa/engine/talk_to_data/drias/plot_informations.py +88 -0
  6. climateqa/engine/talk_to_data/{plot.py → drias/plots.py} +72 -56
  7. climateqa/engine/talk_to_data/{sql_query.py → drias/queries.py} +5 -36
  8. climateqa/engine/talk_to_data/{utils.py → input_processing.py} +144 -168
  9. climateqa/engine/talk_to_data/ipcc/config.py +98 -0
  10. climateqa/engine/talk_to_data/ipcc/plot_informations.py +50 -0
  11. climateqa/engine/talk_to_data/ipcc/plots.py +189 -0
  12. climateqa/engine/talk_to_data/ipcc/queries.py +143 -0
  13. climateqa/engine/talk_to_data/main.py +77 -71
  14. climateqa/engine/talk_to_data/objects/llm_outputs.py +13 -0
  15. climateqa/engine/talk_to_data/objects/location.py +12 -0
  16. climateqa/engine/talk_to_data/objects/plot.py +23 -0
  17. climateqa/engine/talk_to_data/objects/states.py +19 -0
  18. climateqa/engine/talk_to_data/prompt.py +44 -0
  19. climateqa/engine/talk_to_data/query.py +57 -0
  20. climateqa/engine/talk_to_data/talk_to_drias.py +0 -317
  21. climateqa/engine/talk_to_data/ui_config.py +27 -0
  22. climateqa/engine/talk_to_data/{myVanna.py → vanna/myVanna.py} +0 -0
  23. climateqa/engine/talk_to_data/{vanna_class.py → vanna/vanna_class.py} +0 -0
  24. climateqa/engine/talk_to_data/workflow/drias.py +163 -0
  25. climateqa/engine/talk_to_data/workflow/ipcc.py +161 -0
  26. front/tabs/tab_drias.py +60 -149
  27. front/tabs/tab_ipcc.py +300 -0
  28. requirements.txt +2 -1
  29. style.css +39 -7
app.py CHANGED
@@ -16,6 +16,7 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
16
  from front.tabs import create_config_modal, cqa_tab, create_about_tab
17
  from front.tabs import MainTabPanel, ConfigPanel
18
  from front.tabs.tab_drias import create_drias_tab
 
19
  from front.utils import process_figures
20
  from gradio_modal import Modal
21
 
@@ -532,8 +533,8 @@ def main_ui():
532
  with gr.Tabs():
533
  cqa_components = cqa_tab(tab_name="ClimateQ&A")
534
  local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
535
- create_drias_tab(share_client=share_client, user_id=user_id)
536
-
537
  create_about_tab()
538
 
539
  event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
 
16
  from front.tabs import create_config_modal, cqa_tab, create_about_tab
17
  from front.tabs import MainTabPanel, ConfigPanel
18
  from front.tabs.tab_drias import create_drias_tab
19
+ from front.tabs.tab_ipcc import create_ipcc_tab
20
  from front.utils import process_figures
21
  from gradio_modal import Modal
22
 
 
533
  with gr.Tabs():
534
  cqa_components = cqa_tab(tab_name="ClimateQ&A")
535
  local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
536
+ drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
537
+ ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
538
  create_about_tab()
539
 
540
  event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -21,7 +21,7 @@ from langchain_core.prompts import ChatPromptTemplate
21
  from langchain_core.output_parsers import StrOutputParser
22
  from ..vectorstore import get_pinecone_vectorstore
23
  from ..embeddings import get_embeddings_function
24
-
25
 
26
  import asyncio
27
 
@@ -477,8 +477,10 @@ async def retrieve_documents(
477
  docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
478
  else:
479
  # Add a default reranking score
480
- for doc in docs_question:
481
- doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
 
 
482
 
483
  # Keep the right number of documents
484
  docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
@@ -580,7 +582,7 @@ async def get_relevant_toc_level_for_query(
580
  response = chain.invoke({"query": query, "doc_list": doc_list})
581
 
582
  try:
583
- relevant_tocs = eval(response)
584
  except Exception as e:
585
  print(f" Failed to parse the result because of : {e}")
586
 
 
21
  from langchain_core.output_parsers import StrOutputParser
22
  from ..vectorstore import get_pinecone_vectorstore
23
  from ..embeddings import get_embeddings_function
24
+ import ast
25
 
26
  import asyncio
27
 
 
477
  docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
478
  else:
479
  # Add a default reranking score
480
+ for key in docs_question_dict.keys():
481
+ if isinstance(docs_question_dict[key], list) and len(docs_question_dict[key]) > 0:
482
+ for doc in docs_question_dict[key]:
483
+ doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
484
 
485
  # Keep the right number of documents
486
  docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
 
582
  response = chain.invoke({"query": query, "doc_list": doc_list})
583
 
584
  try:
585
+ relevant_tocs = ast.literal_eval(response)
586
  except Exception as e:
587
  print(f" Failed to parse the result because of : {e}")
588
 
climateqa/engine/talk_to_data/config.py CHANGED
@@ -1,99 +1,11 @@
1
- DRIAS_TABLES = [
2
- "total_winter_precipitation",
3
- "total_summer_precipiation",
4
- "total_annual_precipitation",
5
- "total_remarkable_daily_precipitation",
6
- "frequency_of_remarkable_daily_precipitation",
7
- "extreme_precipitation_intensity",
8
- "mean_winter_temperature",
9
- "mean_summer_temperature",
10
- "mean_annual_temperature",
11
- "number_of_tropical_nights",
12
- "maximum_summer_temperature",
13
- "number_of_days_with_tx_above_30",
14
- "number_of_days_with_tx_above_35",
15
- "number_of_days_with_a_dry_ground",
16
- ]
17
 
18
- INDICATOR_COLUMNS_PER_TABLE = {
19
- "total_winter_precipitation": "total_winter_precipitation",
20
- "total_summer_precipiation": "total_summer_precipitation",
21
- "total_annual_precipitation": "total_annual_precipitation",
22
- "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
23
- "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
24
- "extreme_precipitation_intensity": "extreme_precipitation_intensity",
25
- "mean_winter_temperature": "mean_winter_temperature",
26
- "mean_summer_temperature": "mean_summer_temperature",
27
- "mean_annual_temperature": "mean_annual_temperature",
28
- "number_of_tropical_nights": "number_tropical_nights",
29
- "maximum_summer_temperature": "maximum_summer_temperature",
30
- "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
31
- "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
32
- "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
33
- }
34
 
35
- DRIAS_MODELS = [
36
- 'ALL',
37
- 'RegCM4-6_MPI-ESM-LR',
38
- 'RACMO22E_EC-EARTH',
39
- 'RegCM4-6_HadGEM2-ES',
40
- 'HadREM3-GA7_EC-EARTH',
41
- 'HadREM3-GA7_CNRM-CM5',
42
- 'REMO2015_NorESM1-M',
43
- 'SMHI-RCA4_EC-EARTH',
44
- 'WRF381P_NorESM1-M',
45
- 'ALADIN63_CNRM-CM5',
46
- 'CCLM4-8-17_MPI-ESM-LR',
47
- 'HIRHAM5_IPSL-CM5A-MR',
48
- 'HadREM3-GA7_HadGEM2-ES',
49
- 'SMHI-RCA4_IPSL-CM5A-MR',
50
- 'HIRHAM5_NorESM1-M',
51
- 'REMO2009_MPI-ESM-LR',
52
- 'CCLM4-8-17_HadGEM2-ES'
53
- ]
54
- # Mapping between indicator columns and their units
55
- INDICATOR_TO_UNIT = {
56
- "total_winter_precipitation": "mm",
57
- "total_summer_precipitation": "mm",
58
- "total_annual_precipitation": "mm",
59
- "total_remarkable_daily_precipitation": "mm",
60
- "frequency_of_remarkable_daily_precipitation": "days",
61
- "extreme_precipitation_intensity": "mm",
62
- "mean_winter_temperature": "°C",
63
- "mean_summer_temperature": "°C",
64
- "mean_annual_temperature": "°C",
65
- "number_tropical_nights": "days",
66
- "maximum_summer_temperature": "°C",
67
- "number_of_days_with_tx_above_30": "days",
68
- "number_of_days_with_tx_above_35": "days",
69
- "number_of_days_with_dry_ground": "days"
70
- }
71
 
72
- DRIAS_UI_TEXT = """
73
- Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
74
- I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
75
-
76
- ❓ **How to use?**
77
- You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
78
- You can specify **location** and/or **year**.
79
- You can choose from a list of climate models. By default, we take the **average of each model**.
80
-
81
- For example, you can ask:
82
- - What will the temperature be like in Paris?
83
- - What will be the total rainfall in France in 2030?
84
- - How frequent will extreme events be in Lyon?
85
-
86
- **Example of indicators in the data**:
87
- - Mean temperature (annual, winter, summer)
88
- - Total precipitation (annual, winter, summer)
89
- - Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
90
-
91
- ⚠️ **Limitations**:
92
- - You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
93
- - You can only ask about **locations in France**.
94
- - If you specify a year, there may be **no data for that year for some models**.
95
- - You **cannot compare two models**.
96
-
97
- 🛈 **Information**
98
- Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
99
- """
 
1
+ # Path configuration for climateqa project
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ # IPCC dataset path
4
+ IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ # DRIAS dataset paths
7
+ DRIAS_DATASET_URL = "hf://datasets/timeki/drias_db"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Table paths
10
+ DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH = f"{DRIAS_DATASET_URL}/mean_annual_temperature.parquet"
11
+ IPCC_COORDINATES_PATH = f"{IPCC_DATASET_URL}/coordinates.parquet"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/drias/config.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
3
+
4
+
5
+ DRIAS_TABLES = [
6
+ "total_winter_precipitation",
7
+ "total_summer_precipitation",
8
+ "total_annual_precipitation",
9
+ "total_remarkable_daily_precipitation",
10
+ "frequency_of_remarkable_daily_precipitation",
11
+ "extreme_precipitation_intensity",
12
+ "mean_winter_temperature",
13
+ "mean_summer_temperature",
14
+ "mean_annual_temperature",
15
+ "number_of_tropical_nights",
16
+ "maximum_summer_temperature",
17
+ "number_of_days_with_tx_above_30",
18
+ "number_of_days_with_tx_above_35",
19
+ "number_of_days_with_a_dry_ground",
20
+ ]
21
+
22
+ DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
23
+ "total_winter_precipitation": "total_winter_precipitation",
24
+ "total_summer_precipitation": "total_summer_precipitation",
25
+ "total_annual_precipitation": "total_annual_precipitation",
26
+ "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
27
+ "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
28
+ "extreme_precipitation_intensity": "extreme_precipitation_intensity",
29
+ "mean_winter_temperature": "mean_winter_temperature",
30
+ "mean_summer_temperature": "mean_summer_temperature",
31
+ "mean_annual_temperature": "mean_annual_temperature",
32
+ "number_of_tropical_nights": "number_tropical_nights",
33
+ "maximum_summer_temperature": "maximum_summer_temperature",
34
+ "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
35
+ "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
36
+ "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
37
+ }
38
+
39
+ DRIAS_MODELS = [
40
+ 'ALL',
41
+ 'RegCM4-6_MPI-ESM-LR',
42
+ 'RACMO22E_EC-EARTH',
43
+ 'RegCM4-6_HadGEM2-ES',
44
+ 'HadREM3-GA7_EC-EARTH',
45
+ 'HadREM3-GA7_CNRM-CM5',
46
+ 'REMO2015_NorESM1-M',
47
+ 'SMHI-RCA4_EC-EARTH',
48
+ 'WRF381P_NorESM1-M',
49
+ 'ALADIN63_CNRM-CM5',
50
+ 'CCLM4-8-17_MPI-ESM-LR',
51
+ 'HIRHAM5_IPSL-CM5A-MR',
52
+ 'HadREM3-GA7_HadGEM2-ES',
53
+ 'SMHI-RCA4_IPSL-CM5A-MR',
54
+ 'HIRHAM5_NorESM1-M',
55
+ 'REMO2009_MPI-ESM-LR',
56
+ 'CCLM4-8-17_HadGEM2-ES'
57
+ ]
58
+ # Mapping between indicator columns and their units
59
+ DRIAS_INDICATOR_TO_UNIT = {
60
+ "total_winter_precipitation": "mm",
61
+ "total_summer_precipitation": "mm",
62
+ "total_annual_precipitation": "mm",
63
+ "total_remarkable_daily_precipitation": "mm",
64
+ "frequency_of_remarkable_daily_precipitation": "days",
65
+ "extreme_precipitation_intensity": "mm",
66
+ "mean_winter_temperature": "°C",
67
+ "mean_summer_temperature": "°C",
68
+ "mean_annual_temperature": "°C",
69
+ "number_tropical_nights": "days",
70
+ "maximum_summer_temperature": "°C",
71
+ "number_of_days_with_tx_above_30": "days",
72
+ "number_of_days_with_tx_above_35": "days",
73
+ "number_of_days_with_dry_ground": "days"
74
+ }
75
+
76
+ DRIAS_PLOT_PARAMETERS = [
77
+ 'year',
78
+ 'location'
79
+ ]
80
+
81
+ DRIAS_INDICATOR_TO_COLORSCALE = {
82
+ "total_winter_precipitation": PRECIPITATION_COLORSCALE,
83
+ "total_summer_precipitation": PRECIPITATION_COLORSCALE,
84
+ "total_annual_precipitation": PRECIPITATION_COLORSCALE,
85
+ "total_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
86
+ "frequency_of_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
87
+ "extreme_precipitation_intensity": PRECIPITATION_COLORSCALE,
88
+ "mean_winter_temperature":TEMPERATURE_COLORSCALE,
89
+ "mean_summer_temperature":TEMPERATURE_COLORSCALE,
90
+ "mean_annual_temperature":TEMPERATURE_COLORSCALE,
91
+ "number_tropical_nights": TEMPERATURE_COLORSCALE,
92
+ "maximum_summer_temperature":TEMPERATURE_COLORSCALE,
93
+ "number_of_days_with_tx_above_30": TEMPERATURE_COLORSCALE,
94
+ "number_of_days_with_tx_above_35": TEMPERATURE_COLORSCALE,
95
+ "number_of_days_with_dry_ground": TEMPERATURE_COLORSCALE
96
+ }
97
+
98
+ DRIAS_UI_TEXT = """
99
+ Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
100
+ I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
101
+
102
+ You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
103
+ You can specify **location** and/or **year**.
104
+ You can choose from a list of climate models. By default, we take the **average of each model**.
105
+
106
+ For example, you can ask:
107
+ - What will the temperature be like in Paris?
108
+ - What will be the total rainfall in France in 2030?
109
+ - How frequent will extreme events be in Lyon?
110
+
111
+ **Example of indicators in the data**:
112
+ - Mean temperature (annual, winter, summer)
113
+ - Total precipitation (annual, winter, summer)
114
+ - Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
115
+
116
+ ⚠️ **Limitations**:
117
+ - You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
118
+ - You can only ask about **locations in France**.
119
+ - If you specify a year, there may be **no data for that year for some models**.
120
+ - You **cannot compare two models**.
121
+
122
+ 🛈 **Information**
123
+ Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
124
+ """
climateqa/engine/talk_to_data/drias/plot_informations.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
2
+
3
+ def indicator_evolution_informations(
4
+ indicator: str,
5
+ params: dict[str, str]
6
+ ) -> str:
7
+ unit = DRIAS_INDICATOR_TO_UNIT[indicator]
8
+ if "location" not in params:
9
+ raise ValueError('"location" must be provided in params')
10
+ location = params["location"]
11
+ return f"""
12
+ This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.
13
+
14
+ It combines both historical observations and future projections according to the climate scenario RCP8.5.
15
+
16
+ The x-axis represents the years, and the y-axis shows the value of the indicator ({unit}).
17
+
18
+ A 10-year rolling average curve is displayed to give a better idea of the overall trend.
19
+
20
+ **Data source:**
21
+ - The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
22
+ - For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
23
+ - The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
24
+ - The indicator values shown are those for the selected climate model.
25
+ - If ALL climate model is selected, the average value of the indicator between all the climate models is used.
26
+ """
27
+
28
+ def indicator_number_of_days_per_year_informations(
29
+ indicator: str,
30
+ params: dict[str, str]
31
+ ) -> str:
32
+ unit = DRIAS_INDICATOR_TO_UNIT[indicator]
33
+ if "location" not in params:
34
+ raise ValueError('"location" must be provided in params')
35
+ location = params["location"]
36
+ return f"""
37
+ This plot displays a bar chart showing the yearly frequency of the climate indicator **{indicator}** in **{location}**.
38
+
39
+ The x-axis represents the years, and the y-axis shows the frequency of {indicator} ({unit}) per year.
40
+
41
+ **Data source:**
42
+ - The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
43
+ - For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
44
+ - The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
45
+ - The indicator values shown are those for the selected climate model.
46
+ - If ALL climate model is selected, the average value of the indicator between all the climate models is used.
47
+ """
48
+
49
+ def distribution_of_indicator_for_given_year_informations(
50
+ indicator: str,
51
+ params: dict[str, str]
52
+ ) -> str:
53
+ unit = DRIAS_INDICATOR_TO_UNIT[indicator]
54
+ year = params["year"]
55
+ if year is None:
56
+ year = 2030
57
+ return f"""
58
+ This plot shows a histogram of the distribution of the climate indicator **{indicator}** across all locations for the year **{year}**.
59
+
60
+ It allows you to visualize how the values of {indicator} ({unit}) are spread for a given year.
61
+
62
+ **Data source:**
63
+ - The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
64
+ - For each grid point in the dataset and climate model, the value of {indicator} for the year {year} is extracted.
65
+ - The indicator values shown are those for the selected climate model.
66
+ - If ALL climate model is selected, the average value of the indicator between all the climate models is used.
67
+ """
68
+
69
+ def map_of_france_of_indicator_for_given_year_informations(
70
+ indicator: str,
71
+ params: dict[str, str]
72
+ ) -> str:
73
+ unit = DRIAS_INDICATOR_TO_UNIT[indicator]
74
+ year = params["year"]
75
+ if year is None:
76
+ year = 2030
77
+ return f"""
78
+ This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of France for the year **{year}**.
79
+
80
+ Each region is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within France for the selected year and climate model.
81
+
82
+ **Data source:**
83
+ - The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
84
+ - For each region of France, the value of {indicator} in {year} and for the selected climate model is extracted and mapped to its geographic coordinates.
85
+ - The regions correspond to 8 km squares centered on the grid points of the DRIAS dataset.
86
+ - The indicator values shown are those for the selected climate model.
87
+ - If ALL climate model is selected, the average value of the indicator between all the climate models is used.
88
+ """
climateqa/engine/talk_to_data/{plot.py → drias/plots.py} RENAMED
@@ -1,38 +1,39 @@
1
- from typing import Callable, TypedDict
2
- from matplotlib.figure import figaspect
 
 
3
  import pandas as pd
4
  from plotly.graph_objects import Figure
5
  import plotly.graph_objects as go
6
- import plotly.express as px
7
-
8
- from climateqa.engine.talk_to_data.sql_query import (
9
  indicator_for_given_year_query,
10
  indicator_per_year_at_location_query,
11
  )
12
- from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
13
-
14
-
15
-
16
-
17
- class Plot(TypedDict):
18
- """Represents a plot configuration in the DRIAS system.
19
-
20
- This class defines the structure for configuring different types of plots
21
- that can be generated from climate data.
22
-
23
- Attributes:
24
- name (str): The name of the plot type
25
- description (str): A description of what the plot shows
26
- params (list[str]): List of required parameters for the plot
27
- plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
28
- sql_query (Callable[..., str]): Function to generate the SQL query for the plot
29
- """
30
- name: str
31
- description: str
32
- params: list[str]
33
- plot_function: Callable[..., Callable[..., Figure]]
34
- sql_query: Callable[..., str]
35
-
36
 
37
  def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
38
  """Generates a function to plot indicator evolution over time at a location.
@@ -61,7 +62,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
61
  indicator = params["indicator_column"]
62
  location = params["location"]
63
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
64
- unit = INDICATOR_TO_UNIT.get(indicator, "")
65
 
66
  def plot_data(df: pd.DataFrame) -> Figure:
67
  """Generates the actual plot from the data.
@@ -145,10 +146,11 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
145
  hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
146
  )
147
  fig.update_layout(
148
- title=f"Plot of {indicator_label} in {location} ({model_label})",
149
  xaxis_title="Year",
150
  yaxis_title=f"{indicator_label} ({unit})",
151
  template="plotly_white",
 
152
  )
153
  return fig
154
 
@@ -161,6 +163,8 @@ indicator_evolution_at_location: Plot = {
161
  "params": ["indicator_column", "location", "model"],
162
  "plot_function": plot_indicator_evolution_at_location,
163
  "sql_query": indicator_per_year_at_location_query,
 
 
164
  }
165
 
166
 
@@ -184,7 +188,7 @@ def plot_indicator_number_of_days_per_year_at_location(
184
  indicator = params["indicator_column"]
185
  location = params["location"]
186
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
187
- unit = INDICATOR_TO_UNIT.get(indicator, "")
188
 
189
  def plot_data(df: pd.DataFrame) -> Figure:
190
  """Generate the figure thanks to the dataframe
@@ -229,6 +233,7 @@ def plot_indicator_number_of_days_per_year_at_location(
229
  yaxis_title=f"{indicator_label} ({unit})",
230
  yaxis=dict(range=[0, max(indicators)]),
231
  bargap=0.5,
 
232
  template="plotly_white",
233
  )
234
 
@@ -243,6 +248,8 @@ indicator_number_of_days_per_year_at_location: Plot = {
243
  "params": ["indicator_column", "location", "model"],
244
  "plot_function": plot_indicator_number_of_days_per_year_at_location,
245
  "sql_query": indicator_per_year_at_location_query,
 
 
246
  }
247
 
248
 
@@ -265,8 +272,10 @@ def plot_distribution_of_indicator_for_given_year(
265
  """
266
  indicator = params["indicator_column"]
267
  year = params["year"]
 
 
268
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
269
- unit = INDICATOR_TO_UNIT.get(indicator, "")
270
 
271
  def plot_data(df: pd.DataFrame) -> Figure:
272
  """Generate the figure thanks to the dataframe
@@ -311,6 +320,7 @@ def plot_distribution_of_indicator_for_given_year(
311
  yaxis_title="Frequency (%)",
312
  plot_bgcolor="rgba(0, 0, 0, 0)",
313
  showlegend=False,
 
314
  )
315
 
316
  return fig
@@ -324,6 +334,8 @@ distribution_of_indicator_for_given_year: Plot = {
324
  "params": ["indicator_column", "model", "year"],
325
  "plot_function": plot_distribution_of_indicator_for_given_year,
326
  "sql_query": indicator_for_given_year_query,
 
 
327
  }
328
 
329
 
@@ -346,8 +358,10 @@ def plot_map_of_france_of_indicator_for_given_year(
346
  """
347
  indicator = params["indicator_column"]
348
  year = params["year"]
 
 
349
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
350
- unit = INDICATOR_TO_UNIT.get(indicator, "")
351
 
352
  def plot_data(df: pd.DataFrame) -> Figure:
353
  fig = go.Figure()
@@ -371,27 +385,28 @@ def plot_map_of_france_of_indicator_for_given_year(
371
  model_label = f"Model : {df['model'].unique()[0]}"
372
 
373
 
374
- fig.add_trace(
375
- go.Scattermapbox(
376
- lat=latitudes,
377
- lon=longitudes,
378
- mode="markers",
379
- marker=dict(
380
- size=10,
381
- color=indicators, # Color mapped to values
382
- colorscale="Turbo", # Color scale (can be 'Plasma', 'Jet', etc.)
383
- cmin=min(indicators), # Minimum color range
384
- cmax=max(indicators), # Maximum color range
385
- showscale=True, # Show colorbar
386
- ),
387
- text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
388
- hoverinfo="text" # Only show the custom text on hover
389
- )
390
- )
391
 
392
  fig.update_layout(
393
  mapbox_style="open-street-map", # Use OpenStreetMap
394
- mapbox_zoom=3,
 
395
  mapbox_center={"lat": 46.6, "lon": 2.0},
396
  coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
397
  title=f"{indicator_label} in {year} in France ({model_label}) " # Title
@@ -403,16 +418,17 @@ def plot_map_of_france_of_indicator_for_given_year(
403
 
404
  map_of_france_of_indicator_for_given_year: Plot = {
405
  "name": "Map of France of an indicator for a given year",
406
- "description": "Heatmap on the map of France of the values of an in indicator for a given year",
407
  "params": ["indicator_column", "year", "model"],
408
  "plot_function": plot_map_of_france_of_indicator_for_given_year,
409
  "sql_query": indicator_for_given_year_query,
 
 
410
  }
411
 
412
-
413
- PLOTS = [
414
  indicator_evolution_at_location,
415
  indicator_number_of_days_per_year_at_location,
416
  distribution_of_indicator_for_given_year,
417
  map_of_france_of_indicator_for_given_year,
418
- ]
 
1
+ import os
2
+ import geojson
3
+ from math import cos, radians
4
+ from typing import Callable
5
  import pandas as pd
6
  from plotly.graph_objects import Figure
7
  import plotly.graph_objects as go
8
+ from climateqa.engine.talk_to_data.drias.plot_informations import distribution_of_indicator_for_given_year_informations, indicator_evolution_informations, indicator_number_of_days_per_year_informations, map_of_france_of_indicator_for_given_year_informations
9
+ from climateqa.engine.talk_to_data.objects.plot import Plot
10
+ from climateqa.engine.talk_to_data.drias.queries import (
11
  indicator_for_given_year_query,
12
  indicator_per_year_at_location_query,
13
  )
14
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_COLORSCALE, DRIAS_INDICATOR_TO_UNIT
15
+
16
+ def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
17
+ side_km = 8
18
+ delta_lat = side_km / 111
19
+ features = []
20
+ for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
21
+ delta_lon = side_km / (111 * cos(radians(lat)))
22
+ half_lat = delta_lat / 2
23
+ half_lon = delta_lon / 2
24
+ features.append(geojson.Feature(
25
+ geometry=geojson.Polygon([[
26
+ [lon - half_lon, lat - half_lat],
27
+ [lon + half_lon, lat - half_lat],
28
+ [lon + half_lon, lat + half_lat],
29
+ [lon - half_lon, lat + half_lat],
30
+ [lon - half_lon, lat - half_lat]
31
+ ]]),
32
+ properties={"value": val},
33
+ id=str(idx)
34
+ ))
35
+
36
+ return geojson.FeatureCollection(features)
 
37
 
38
  def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
39
  """Generates a function to plot indicator evolution over time at a location.
 
62
  indicator = params["indicator_column"]
63
  location = params["location"]
64
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
65
+ unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
66
 
67
  def plot_data(df: pd.DataFrame) -> Figure:
68
  """Generates the actual plot from the data.
 
146
  hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
147
  )
148
  fig.update_layout(
149
+ title=f"Evolution of {indicator_label} in {location} ({model_label})",
150
  xaxis_title="Year",
151
  yaxis_title=f"{indicator_label} ({unit})",
152
  template="plotly_white",
153
+ height=900,
154
  )
155
  return fig
156
 
 
163
  "params": ["indicator_column", "location", "model"],
164
  "plot_function": plot_indicator_evolution_at_location,
165
  "sql_query": indicator_per_year_at_location_query,
166
+ "plot_information": indicator_evolution_informations,
167
+ 'short_name': 'Evolution'
168
  }
169
 
170
 
 
188
  indicator = params["indicator_column"]
189
  location = params["location"]
190
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
191
+ unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
192
 
193
  def plot_data(df: pd.DataFrame) -> Figure:
194
  """Generate the figure thanks to the dataframe
 
233
  yaxis_title=f"{indicator_label} ({unit})",
234
  yaxis=dict(range=[0, max(indicators)]),
235
  bargap=0.5,
236
+ height=900,
237
  template="plotly_white",
238
  )
239
 
 
248
  "params": ["indicator_column", "location", "model"],
249
  "plot_function": plot_indicator_number_of_days_per_year_at_location,
250
  "sql_query": indicator_per_year_at_location_query,
251
+ "plot_information": indicator_number_of_days_per_year_informations,
252
+ "short_name": "Yearly Frequency",
253
  }
254
 
255
 
 
272
  """
273
  indicator = params["indicator_column"]
274
  year = params["year"]
275
+ if year is None:
276
+ year = 2030
277
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
278
+ unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
279
 
280
  def plot_data(df: pd.DataFrame) -> Figure:
281
  """Generate the figure thanks to the dataframe
 
320
  yaxis_title="Frequency (%)",
321
  plot_bgcolor="rgba(0, 0, 0, 0)",
322
  showlegend=False,
323
+ height=900,
324
  )
325
 
326
  return fig
 
334
  "params": ["indicator_column", "model", "year"],
335
  "plot_function": plot_distribution_of_indicator_for_given_year,
336
  "sql_query": indicator_for_given_year_query,
337
+ "plot_information": distribution_of_indicator_for_given_year_informations,
338
+ 'short_name': 'Distribution'
339
  }
340
 
341
 
 
358
  """
359
  indicator = params["indicator_column"]
360
  year = params["year"]
361
+ if year is None:
362
+ year = 2030
363
  indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
364
+ unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
365
 
366
  def plot_data(df: pd.DataFrame) -> Figure:
367
  fig = go.Figure()
 
385
  model_label = f"Model : {df['model'].unique()[0]}"
386
 
387
 
388
+
389
+ geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
390
+
391
+ fig = go.Figure(go.Choroplethmapbox(
392
+ geojson=geojson_data,
393
+ locations=[str(i) for i in range(len(indicators))],
394
+ featureidkey="id",
395
+ z=indicators,
396
+ colorscale=DRIAS_INDICATOR_TO_COLORSCALE[indicator],
397
+ zmin=min(indicators),
398
+ zmax=max(indicators),
399
+ marker_opacity=0.7,
400
+ marker_line_width=0,
401
+ colorbar_title=f"{indicator_label} ({unit})",
402
+ text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
403
+ hoverinfo="text"
404
+ ))
405
 
406
  fig.update_layout(
407
  mapbox_style="open-street-map", # Use OpenStreetMap
408
+ mapbox_zoom=5,
409
+ height=900,
410
  mapbox_center={"lat": 46.6, "lon": 2.0},
411
  coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
412
  title=f"{indicator_label} in {year} in France ({model_label}) " # Title
 
418
 
419
  map_of_france_of_indicator_for_given_year: Plot = {
420
  "name": "Map of France of an indicator for a given year",
421
+ "description": "Heatmap on the map of France of the values of an indicator for a given year",
422
  "params": ["indicator_column", "year", "model"],
423
  "plot_function": plot_map_of_france_of_indicator_for_given_year,
424
  "sql_query": indicator_for_given_year_query,
425
+ "plot_information": map_of_france_of_indicator_for_given_year_informations,
426
+ 'short_name': 'Map of France'
427
  }
428
 
429
+ DRIAS_PLOTS = [
 
430
  indicator_evolution_at_location,
431
  indicator_number_of_days_per_year_at_location,
432
  distribution_of_indicator_for_given_year,
433
  map_of_france_of_indicator_for_given_year,
434
+ ]
climateqa/engine/talk_to_data/{sql_query.py → drias/queries.py} RENAMED
@@ -1,37 +1,5 @@
1
- import asyncio
2
- from concurrent.futures import ThreadPoolExecutor
3
  from typing import TypedDict
4
- import duckdb
5
- import pandas as pd
6
-
7
- async def execute_sql_query(sql_query: str) -> pd.DataFrame:
8
- """Executes a SQL query on the DRIAS database and returns the results.
9
-
10
- This function connects to the DuckDB database containing DRIAS climate data
11
- and executes the provided SQL query. It handles the database connection and
12
- returns the results as a pandas DataFrame.
13
-
14
- Args:
15
- sql_query (str): The SQL query to execute
16
-
17
- Returns:
18
- pd.DataFrame: A DataFrame containing the query results
19
-
20
- Raises:
21
- duckdb.Error: If there is an error executing the SQL query
22
- """
23
- def _execute_query():
24
- # Execute the query
25
- con = duckdb.connect()
26
- results = con.sql(sql_query).fetchdf()
27
- # return fetched data
28
- return results
29
-
30
- # Run the query in a thread pool to avoid blocking
31
- loop = asyncio.get_event_loop()
32
- with ThreadPoolExecutor() as executor:
33
- return await loop.run_in_executor(executor, _execute_query)
34
-
35
 
36
  class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
37
  """Parameters for querying an indicator's values over time at a location.
@@ -50,7 +18,6 @@ class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
50
  longitude: str
51
  model: str
52
 
53
-
54
  def indicator_per_year_at_location_query(
55
  table: str, params: IndicatorPerYearAtLocationQueryParams
56
  ) -> str:
@@ -70,7 +37,7 @@ def indicator_per_year_at_location_query(
70
  if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
71
  return ""
72
 
73
- table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
74
 
75
  sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
76
 
@@ -105,10 +72,12 @@ def indicator_for_given_year_query(
105
  """
106
  indicator_column = params.get("indicator_column")
107
  year = params.get('year')
 
 
108
  if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
109
  return ""
110
 
111
- table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
112
 
113
  sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
114
  return sql_query
 
 
 
1
  from typing import TypedDict
2
+ from climateqa.engine.talk_to_data.config import DRIAS_DATASET_URL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
5
  """Parameters for querying an indicator's values over time at a location.
 
18
  longitude: str
19
  model: str
20
 
 
21
  def indicator_per_year_at_location_query(
22
  table: str, params: IndicatorPerYearAtLocationQueryParams
23
  ) -> str:
 
37
  if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
38
  return ""
39
 
40
+ table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
41
 
42
  sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
43
 
 
72
  """
73
  indicator_column = params.get("indicator_column")
74
  year = params.get('year')
75
+ if year is None:
76
+ year = 2050
77
  if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
78
  return ""
79
 
80
+ table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
81
 
82
  sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
83
  return sql_query
climateqa/engine/talk_to_data/{utils.py → input_processing.py} RENAMED
@@ -1,15 +1,17 @@
1
- import re
2
- from typing import Annotated, TypedDict
3
- import duckdb
4
- from geopy.geocoders import Nominatim
5
  import ast
6
- from climateqa.engine.llm import get_llm
7
- from climateqa.engine.talk_to_data.config import DRIAS_TABLES
8
- from climateqa.engine.talk_to_data.plot import PLOTS, Plot
9
  from langchain_core.prompts import ChatPromptTemplate
10
-
11
-
12
- async def detect_location_with_openai(sentence):
 
 
 
 
 
 
 
 
13
  """
14
  Detects locations in a sentence using OpenAI's API via LangChain.
15
  """
@@ -29,63 +31,7 @@ async def detect_location_with_openai(sentence):
29
  else:
30
  return ""
31
 
32
- class ArrayOutput(TypedDict):
33
- """Represents the output of a function that returns an array.
34
-
35
- This class is used to type-hint functions that return arrays,
36
- ensuring consistent return types across the codebase.
37
-
38
- Attributes:
39
- array (str): A syntactically valid Python array string
40
- """
41
- array: Annotated[str, "Syntactically valid python array."]
42
-
43
- async def detect_year_with_openai(sentence: str) -> str:
44
- """
45
- Detects years in a sentence using OpenAI's API via LangChain.
46
- """
47
- llm = get_llm()
48
-
49
- prompt = """
50
- Extract all years mentioned in the following sentence.
51
- Return the result as a Python list. If no year are mentioned, return an empty list.
52
-
53
- Sentence: "{sentence}"
54
- """
55
-
56
- prompt = ChatPromptTemplate.from_template(prompt)
57
- structured_llm = llm.with_structured_output(ArrayOutput)
58
- chain = prompt | structured_llm
59
- response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
60
- years_list = eval(response['array'])
61
- if len(years_list) > 0:
62
- return years_list[0]
63
- else:
64
- return ""
65
-
66
-
67
- def detectTable(sql_query: str) -> list[str]:
68
- """Extracts table names from a SQL query.
69
-
70
- This function uses regular expressions to find all table names
71
- referenced in a SQL query's FROM clause.
72
-
73
- Args:
74
- sql_query (str): The SQL query to analyze
75
-
76
- Returns:
77
- list[str]: A list of table names found in the query
78
-
79
- Example:
80
- >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
81
- ['temperature_data']
82
- """
83
- pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
84
- matches = re.findall(pattern, sql_query)
85
- return matches
86
-
87
-
88
- def loc2coords(location: str) -> tuple[float, float]:
89
  """Converts a location name to geographic coordinates.
90
 
91
  This function uses the Nominatim geocoding service to convert
@@ -104,49 +50,77 @@ def loc2coords(location: str) -> tuple[float, float]:
104
  coords = geolocator.geocode(location)
105
  return (coords.latitude, coords.longitude)
106
 
107
-
108
- def coords2loc(coords: tuple[float, float]) -> str:
109
- """Converts geographic coordinates to a location name.
110
 
111
  This function uses the Nominatim reverse geocoding service to convert
112
- latitude and longitude coordinates to a human-readable location name.
113
 
114
  Args:
115
  coords (tuple[float, float]): A tuple containing (latitude, longitude)
116
 
117
  Returns:
118
- str: The address of the location, or "Unknown Location" if not found
119
 
120
- Example:
121
- >>> coords2loc((48.8566, 2.3522))
122
- 'Paris, France'
123
  """
124
- geolocator = Nominatim(user_agent="coords_to_city")
125
- try:
126
- location = geolocator.reverse(coords)
127
- return location.address
128
- except Exception as e:
129
- print(f"Error: {e}")
130
- return "Unknown Location"
131
-
132
 
133
- def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
134
  long = round(location[1], 3)
135
  lat = round(location[0], 3)
 
136
 
137
- table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
138
-
139
- results = duckdb.sql(
140
- f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
141
- ).fetchdf()
 
 
 
 
 
 
142
 
143
  if len(results) == 0:
144
- return "", ""
145
- # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
146
- return results['latitude'].iloc[0], results['longitude'].iloc[0]
 
 
 
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
150
  """Identifies relevant tables for a plot based on user input.
151
 
152
  This function uses an LLM to analyze the user's question and the plot
@@ -170,7 +144,6 @@ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[st
170
  ['mean_annual_temperature', 'mean_summer_temperature']
171
  """
172
  # Get all table names
173
- table_names_list = DRIAS_TABLES
174
 
175
  prompt = (
176
  f"You are helping to build a plot following this description : {plot['description']}."
@@ -187,95 +160,98 @@ async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[st
187
  )
188
  return table_names
189
 
190
-
191
- def replace_coordonates(coords, query, coords_tables):
192
- n = query.count(str(coords[0]))
193
-
194
- for i in range(n):
195
- query = query.replace(str(coords[0]), str(coords_tables[i][0]), 1)
196
- query = query.replace(str(coords[1]), str(coords_tables[i][1]), 1)
197
- return query
198
-
199
-
200
- async def detect_relevant_plots(user_question: str, llm):
201
  plots_description = ""
202
- for plot in PLOTS:
203
  plots_description += "Name: " + plot["name"]
204
  plots_description += " - Description: " + plot["description"] + "\n"
205
 
206
  prompt = (
207
- f"You are helping to answer a quesiton with insightful visualizations."
208
- f"You are given an user question and a list of plots with their name and description."
209
- f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
210
- f"Write the most relevant tables to use. Answer only a python list of plot name."
 
 
211
  f"### Descriptions of the plots : {plots_description}"
212
- f"### User question : {user_question}"
213
- f"### Name of the plot : "
214
  )
215
- # prompt = (
216
- # f"You are helping to answer a question with insightful visualizations. "
217
- # f"Given a list of plots with their name and description: "
218
- # f"{plots_description} "
219
- # f"The user question is: {user_question}. "
220
- # f"Choose the most relevant plots to answer the question. "
221
- # f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
222
- # f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
223
- # )
224
 
225
  plot_names = ast.literal_eval(
226
  (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
227
  )
228
  return plot_names
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
- # Next Version
232
- # class QueryOutput(TypedDict):
233
- # """Generated SQL query."""
234
-
235
- # query: Annotated[str, ..., "Syntactically valid SQL query."]
236
-
237
-
238
- # class PlotlyCodeOutput(TypedDict):
239
- # """Generated Plotly code"""
240
-
241
- # code: Annotated[str, ..., "Synatically valid Plotly python code."]
242
- # def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
243
- # """Generate SQL query to fetch information."""
244
- # prompt_params = {
245
- # "dialect": db.dialect,
246
- # "table_info": db.get_table_info(),
247
- # "input": user_input,
248
- # "relevant_tables": relevant_tables,
249
- # "model": "ALADIN63_CNRM-CM5",
250
- # }
251
-
252
- # prompt = ChatPromptTemplate.from_template(query_prompt_template)
253
- # structured_llm = llm.with_structured_output(QueryOutput)
254
- # chain = prompt | structured_llm
255
- # result = chain.invoke(prompt_params)
256
-
257
- # return result["query"]
258
-
259
-
260
- # def fetch_data_from_sql_query(db: str, sql_query: str):
261
- # conn = sqlite3.connect(db)
262
- # cursor = conn.cursor()
263
- # cursor.execute(sql_query)
264
- # column_names = [desc[0] for desc in cursor.description]
265
- # values = cursor.fetchall()
266
- # return {"column_names": column_names, "data": values}
267
-
268
 
269
- # def generate_chart_code(user_input: str, sql_query: list[str], llm):
270
- # """ "Generate plotly python code for the chart based on the sql query and the user question"""
 
 
271
 
272
- # class PlotlyCodeOutput(TypedDict):
273
- # """Generated Plotly code"""
274
 
275
- # code: Annotated[str, ..., "Synatically valid Plotly python code."]
 
 
 
276
 
277
- # prompt = ChatPromptTemplate.from_template(plot_prompt_template)
278
- # structured_llm = llm.with_structured_output(PlotlyCodeOutput)
279
- # chain = prompt | structured_llm
280
- # result = chain.invoke({"input": user_input, "sql_query": sql_query})
281
- # return result["code"]
 
 
 
 
 
 
1
+ from typing import Any, Literal, Optional, cast
 
 
 
2
  import ast
 
 
 
3
  from langchain_core.prompts import ChatPromptTemplate
4
+ from geopy.geocoders import Nominatim
5
+ from climateqa.engine.llm import get_llm
6
+ import duckdb
7
+ import os
8
+ from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
9
+ from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
10
+ from climateqa.engine.talk_to_data.objects.location import Location
11
+ from climateqa.engine.talk_to_data.objects.plot import Plot
12
+ from climateqa.engine.talk_to_data.objects.states import State
13
+
14
+ async def detect_location_with_openai(sentence: str) -> str:
15
  """
16
  Detects locations in a sentence using OpenAI's API via LangChain.
17
  """
 
31
  else:
32
  return ""
33
 
34
+ def loc_to_coords(location: str) -> tuple[float, float]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  """Converts a location name to geographic coordinates.
36
 
37
  This function uses the Nominatim geocoding service to convert
 
50
  coords = geolocator.geocode(location)
51
  return (coords.latitude, coords.longitude)
52
 
53
def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
    """Converts geographic coordinates to a country.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to a country code and country name.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str, str]: A tuple containing (country_code, country_name),
        with the country code upper-cased (e.g. ("FR", "France"))

    Raises:
        AttributeError: If the coordinates cannot be reverse geocoded
            (Nominatim returns None and the `.raw` access fails)
    """
    geolocator = Nominatim(user_agent="latlong_to_country")
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']
 
 
 
 
72
 
73
def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Finds the dataset grid point closest to the given coordinates.

    Queries the parquet file of the selected dataset for grid points whose
    latitude/longitude fall inside a small box around `location`
    (+/- 0.3 degree for DRIAS, +/- 0.5 degree for IPCC) and returns the
    first match.

    Args:
        location (tuple): (latitude, longitude) pair to resolve
        mode (Literal['DRIAS', 'IPCC']): Which dataset grid to search

    Returns:
        tuple: (latitude, longitude, admin1) of the matched grid point, or
        ("", "", "") when no grid point lies inside the search box.
        NOTE(review): on a match the first two values come straight from the
        DataFrame and are floats despite the `str` annotation — confirm
        callers only interpolate them into SQL strings.
        admin1 is only selected for IPCC; it is None in DRIAS mode.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)
    conn = duckdb.connect()

    if mode == 'DRIAS':
        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
        results = conn.sql(
            f"SELECT latitude, longitude FROM {table_path} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
        ).fetchdf()
    else:
        table_path = f"'{IPCC_COORDINATES_PATH}'"
        results = conn.sql(
            f"SELECT latitude, longitude, admin1 FROM {table_path} WHERE latitude BETWEEN {lat - 0.5} AND {lat + 0.5} AND longitude BETWEEN {long - 0.5} AND {long + 0.5}"
        ).fetchdf()


    if len(results) == 0:
        return "", "", ""

    # Only the IPCC query selects admin1; DRIAS results lack the column.
    if 'admin1' in results.columns:
        admin1 = results['admin1'].iloc[0]
    else:
        admin1 = None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
98
 
99
async def detect_year_with_openai(sentence: str) -> str:
    """Detects years in a sentence using OpenAI's API via LangChain.

    Args:
        sentence (str): The sentence to scan for year mentions.

    Returns:
        str: The first year found, or an empty string when none is mentioned.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # literal_eval (not eval) safely parses the LLM-produced Python list.
    years_list = ast.literal_eval(response['array'])
    if not years_list:
        return ""
    # The LLM may return ints; coerce so the annotated return type (and the
    # callers' `== ""` checks / f-string interpolation) behave consistently.
    return str(years_list[0])
121
+
122
 
123
+ async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
124
  """Identifies relevant tables for a plot based on user input.
125
 
126
  This function uses an LLM to analyze the user's question and the plot
 
144
  ['mean_annual_temperature', 'mean_summer_temperature']
145
  """
146
  # Get all table names
 
147
 
148
  prompt = (
149
  f"You are helping to build a plot following this description : {plot['description']}."
 
160
  )
161
  return table_names
162
 
163
async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Selects the plots relevant to a user question via an LLM.

    Builds a catalogue of the available plots (name + description), asks the
    LLM which ones could answer the question, and parses its answer as a
    Python list of plot names sorted by decreasing relevance.

    Args:
        user_question (str): The user's query text
        llm: LangChain chat model used for the selection
        plot_list (list[Plot]): Candidate plots with their descriptions

    Returns:
        list[str]: Names of the relevant plots, most relevant first

    Raises:
        ValueError, SyntaxError: If the LLM answer is not a valid Python list
    """
    plots_description = ""
    for plot in plot_list:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )

    answer = (await llm.ainvoke(prompt)).content.strip()
    # Strip Markdown code fences robustly: str.strip("```python\n") treats its
    # argument as a *character set* and could eat legitimate leading/trailing
    # characters, so peel the fences off explicitly instead.
    answer = answer.removeprefix("```python").removeprefix("```").removesuffix("```").strip()
    plot_names = ast.literal_eval(answer)
    return plot_names
185
 
186
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Extracts a location from the user input and enriches it.

    The location name is detected by the LLM, geocoded, and mapped to the
    nearest grid point of the selected dataset (DRIAS or IPCC) plus country
    metadata. All fields except 'location' stay None when nothing is detected.
    """
    print(f"---- Find location in user input ----")
    detected = await detect_location_with_openai(user_input)
    result: Location = {
        'location': detected,
        'longitude': None,
        'latitude': None,
        'country_code': None,
        'country_name': None,
        'admin1': None,
    }

    if detected:
        point = loc_to_coords(detected)
        code, name = coords_to_country(point)
        grid_lat, grid_lon, admin1 = nearest_neighbour_sql(point, mode)
        result["latitude"] = grid_lat
        result["longitude"] = grid_lon
        result["country_code"] = code
        result["country_name"] = name
        result["admin1"] = admin1
    return cast(Location, result)
211
+
212
async def find_year(user_input: str) -> str | None:
    """Extracts year information from user input using LLM.

    Uses an LLM to identify year information in the user's query, which is
    then used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The extracted year, or None when no year was found
    """
    print(f"---- Find year ---")
    detected = await detect_year_with_openai(user_input)
    return detected if detected != "" else None
229
 
230
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Returns the names of the plots relevant to the question held in state."""
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm, plots)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Returns the tables (among `tables`) relevant for building `plot`."""
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm, tables)
239
 
240
async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
    """Dispatches to the right extraction method for the desired parameter.

    Args:
        state (State): state of the workflow (holds the user input)
        param_name (str): name of the desired parameter ('location' or 'year')
        mode (Literal['DRIAS', 'IPCC']): dataset whose grid is used to resolve
            a location to its nearest available coordinates

    Returns:
        Location | dict[str, Optional[str]] | None: the resolved parameter,
        or None when `param_name` is not supported
    """
    if param_name == 'location':
        location = await find_location(state['user_input'], mode)
        return location
    if param_name == 'year':
        year = await find_year(state['user_input'])
        return {'year': year}
    return None
climateqa/engine/talk_to_data/ipcc/config.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
2
+ from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
3
+
4
+
5
+ # IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
6
+ IPCC_TABLES = [
7
+ "mean_temperature",
8
+ "total_precipitation",
9
+ ]
10
+
11
+ IPCC_INDICATOR_COLUMNS_PER_TABLE = {
12
+ "mean_temperature": "mean_temperature",
13
+ "total_precipitation": "total_precipitation"
14
+ }
15
+
16
+ IPCC_INDICATOR_TO_UNIT = {
17
+ "mean_temperature": "°C",
18
+ "total_precipitation": "mm/day"
19
+ }
20
+
21
+ IPCC_SCENARIO = [
22
+ "historical",
23
+ "ssp126",
24
+ "ssp245",
25
+ "ssp370",
26
+ "ssp585",
27
+ ]
28
+
29
+ IPCC_MODELS = []
30
+
31
+ IPCC_PLOT_PARAMETERS = [
32
+ 'year',
33
+ 'location'
34
+ ]
35
+
36
+ MACRO_COUNTRIES = ['JP',
37
+ 'IN',
38
+ 'MH',
39
+ 'PT',
40
+ 'ID',
41
+ 'SJ',
42
+ 'MX',
43
+ 'CN',
44
+ 'GL',
45
+ 'PN',
46
+ 'AR',
47
+ 'AQ',
48
+ 'PF',
49
+ 'BR',
50
+ 'SH',
51
+ 'GS',
52
+ 'ZA',
53
+ 'NZ',
54
+ 'TF',
55
+ ]
56
+
57
+ HUGE_MACRO_COUNTRIES = ['CL',
58
+ 'CA',
59
+ 'AU',
60
+ 'US',
61
+ 'RU'
62
+ ]
63
+
64
+ IPCC_INDICATOR_TO_COLORSCALE = {
65
+ "mean_temperature": TEMPERATURE_COLORSCALE,
66
+ "total_precipitation": PRECIPITATION_COLORSCALE
67
+ }
68
+
69
+ IPCC_UI_TEXT = """
70
+ Hi, I'm **Talk to IPCC**, designed to answer your questions using [**IPCC - ATLAS**](https://interactive-atlas.ipcc.ch/regional-information#eyJ0eXBlIjoiQVRMQVMiLCJjb21tb25zIjp7ImxhdCI6OTc3MiwibG5nIjo0MDA2OTIsInpvb20iOjQsInByb2oiOiJFUFNHOjU0MDMwIiwibW9kZSI6ImNvbXBsZXRlX2F0bGFzIn0sInByaW1hcnkiOnsic2NlbmFyaW8iOiJzc3A1ODUiLCJwZXJpb2QiOiIyIiwic2Vhc29uIjoieWVhciIsImRhdGFzZXQiOiJDTUlQNiIsInZhcmlhYmxlIjoidGFzIiwidmFsdWVUeXBlIjoiQU5PTUFMWSIsImhhdGNoaW5nIjoiU0lNUExFIiwicmVnaW9uU2V0IjoiYXI2IiwiYmFzZWxpbmUiOiJwcmVJbmR1c3RyaWFsIiwicmVnaW9uc1NlbGVjdGVkIjpbXX0sInBsb3QiOnsiYWN0aXZlVGFiIjoicGx1bWUiLCJtYXNrIjoibm9uZSIsInNjYXR0ZXJZTWFnIjpudWxsLCJzY2F0dGVyWVZhciI6bnVsbCwic2hvd2luZyI6ZmFsc2V9fQ==) data.
71
+ I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
72
+
73
+ You can ask me anything about these climate indicators: **temperature** or **precipitation**.
74
+ You can specify **location** and/or **year**.
75
+ By default, we take the **median of each climate model**.
76
+
77
+ Current available charts :
78
+ - Yearly evolution of an indicator at a specific location (historical + SSP Projections)
79
+ - Yearly spatial distribution of an indicator in a specific country
80
+
81
+ Current available indicators :
82
+ - Mean temperature
83
+ - Total precipitation
84
+
85
+ For example, you can ask:
86
+ - What will the temperature be like in Paris?
87
+ - What will be the total rainfall in the USA in 2030?
88
+ - How will the average temperature evolve in China ?
89
+
90
+ ⚠️ **Limitations**:
91
+ - You can't ask anything that isn't related to **IPCC - ATLAS** data.
92
+ - You can not ask about **several locations at the same time**.
93
+ - If you specify a year **before 1850 or over 2100**, there will be **no data**.
94
+ - You **cannot compare two models**.
95
+
96
+ 🛈 **Information**
97
+ Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
98
+ """
climateqa/engine/talk_to_data/ipcc/plot_informations.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_UNIT
2
+
3
def indicator_evolution_informations(
    indicator: str,
    params: dict[str,str],
) -> str:
    """Builds the explanatory Markdown text for the IPCC evolution plot.

    Args:
        indicator (str): Name of the climate indicator column
        params (dict[str, str]): Must contain "location"

    Returns:
        str: Markdown description of the plot and its data source

    Raises:
        ValueError: If "location" is missing from params
    """
    if "location" not in params:
        raise ValueError('"location" must be provided in params')
    location = params["location"]

    # .get keeps the text renderable for unknown indicators, consistently
    # with the plotting code (IPCC_INDICATOR_TO_UNIT.get(indicator, "")).
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
    return f"""
    This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.

    It combines both historical (from 1950 to 2015) observations and future (from 2016 to 2100) projections for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).

    The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of the {indicator} ({unit}).

    Each line corresponds to a different scenario, allowing you to compare how {indicator} might change under various future conditions.

    **Data source:**
    - The data comes from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
    - The underlying data is retrieved by aggregating yearly values of {indicator} for the selected location, across all available scenarios. This means the system collects, for each year, the value of {indicator} in {location}, both for the historical period and for each scenario, to build the time series.
    - The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
    """
26
+
27
def choropleth_map_informations(
    indicator: str,
    params: dict[str, str],
) -> str:
    """Builds the explanatory Markdown text for the IPCC choropleth map plot.

    Args:
        indicator (str): Name of the climate indicator column
        params (dict[str, str]): Must contain "location"; "country_name" and
            "year" are optional (year defaults to 2050, matching the SQL query)

    Returns:
        str: Markdown description of the plot and its data source

    Raises:
        ValueError: If "location" is missing from params
    """
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
    if "location" not in params:
        raise ValueError('"location" must be provided in params')
    location = params["location"]
    # .get avoids a KeyError when the caller did not resolve these fields;
    # only "location" is mandatory.
    country_name = params.get("country_name")
    year = params.get("year")
    if year is None:
        year = 2050

    return f"""
    This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of **{location}** country ({country_name}) for the year **{year}** and the chosen scenario.

    Each grid point is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within the country for the selected year and scenario.

    **Data source:**
    - The data come from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
    - For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
    - The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
    - The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
    """
climateqa/engine/talk_to_data/ipcc/plots.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ from plotly.graph_objects import Figure
3
+ import plotly.graph_objects as go
4
+ import pandas as pd
5
+ import geojson
6
+
7
+ from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
8
+ from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
9
+ from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
10
+ from climateqa.engine.talk_to_data.objects.plot import Plot
11
+
12
def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
    """Builds a FeatureCollection of 1-degree squares centred on each grid point.

    Each (lat, lon) pair becomes a square polygon spanning +/- 0.5 degree,
    carrying its indicator value in `properties` and its list index as id.
    """
    features = []
    for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
        square = [[
            [lon - 0.5, lat - 0.5],
            [lon + 0.5, lat - 0.5],
            [lon + 0.5, lat + 0.5],
            [lon - 0.5, lat + 0.5],
            [lon - 0.5, lat - 0.5],
        ]]
        features.append(
            geojson.Feature(
                geometry=geojson.Polygon(square),
                properties={"value": val},
                id=str(idx),
            )
        )

    return geojson.FeatureCollection(features)
30
+
31
def plot_indicator_evolution_at_location_historical_and_projections(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """
    Returns a function that generates a line plot showing the evolution of a climate indicator
    (e.g., temperature, rainfall) over time at a specific location, including both historical data
    and future projections for different climate scenarios.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the climate indicator column to plot.
            - location (str): Location (e.g., country, city) for which to plot the indicator.

    Returns:
        Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
        showing the indicator's evolution over time, with scenario lines and historical data.
    """
    indicator = params["indicator_column"]
    location = params["location"]
    # Human-readable label, e.g. "mean_temperature" -> "Mean Temperature".
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        # Expects columns 'year', 'scenario' and the indicator column
        # (as produced by indicator_per_year_at_location_query).
        df = df.sort_values(by='year')
        years = df['year'].astype(int).tolist()
        indicators = df[indicator].astype(float).tolist()
        scenarios = df['scenario'].astype(str).tolist()

        # Find last historical value for continuity
        last_historical = [(y, v) for y, v, s in zip(years, indicators, scenarios) if s == 'historical']
        last_historical_year, last_historical_indicator = last_historical[-1] if last_historical else (None, None)

        fig = go.Figure()
        for scenario in IPCC_SCENARIO:
            x = [y for y, s in zip(years, scenarios) if s == scenario]
            y = [v for v, s in zip(indicators, scenarios) if s == scenario]
            # Connect historical to scenario: prepend the last historical point
            # so each projection line starts where the historical line ends.
            if scenario != 'historical' and last_historical_indicator is not None:
                x = [last_historical_year] + x
                y = [last_historical_indicator] + y
            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                mode='lines',
                name=scenario
            ))

        fig.update_layout(
            title=f'Yearly Evolution of {indicator_label} in {location} (Historical + SSP Scenarios)',
            xaxis_title='Year',
            yaxis_title=f'{indicator_label} ({unit})',
            legend_title='Scenario',
            height=800,
        )
        return fig

    return plot_data
88
+
89
# Plot registry entry: time-series view of one indicator at a single location,
# wiring the line-plot builder to the per-year SQL query and its description text.
indicator_evolution_at_location_historical_and_projections: Plot = {
    "name": "Indicator Evolution at Location (Historical + Projections)",
    "description": (
        "Shows how a climate indicator (e.g., rainfall, temperature) changes over time at a specific location, "
        "including historical data and future projections. "
        "Useful for questions about the value or trend of an indicator at a location for any year, "
        "such as 'What will be the total rainfall in China in 2050?' or 'How does rainfall evolve in China over time?'. "
        "Parameters: indicator_column (the climate variable), location (e.g., country, city)."
    ),
    "params": ["indicator_column", "location"],
    "plot_function": plot_indicator_evolution_at_location_historical_and_projections,
    "sql_query": indicator_per_year_at_location_query,
    "plot_information": indicator_evolution_informations,
    "short_name": "Evolution"
}
104
+
105
def plot_choropleth_map_of_country_indicator_for_specific_year(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """Build a choropleth-map generator for one indicator over a country.

    The returned callable turns a DataFrame (columns: latitude, longitude and the
    indicator column) into a Plotly choropleth mapbox showing the spatial
    distribution of the indicator for the requested year.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the climate indicator column to plot.
            - year (str or int, optional): Year to display (defaults to 2050).
            - country_name (str): Name of the country.
            - location (str): Location (country or region) for the map.

    Returns:
        Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns
        a Plotly Figure with the indicator's spatial distribution.
    """
    indicator = params["indicator_column"]
    year = params.get('year')
    if year is None:
        year = 2050
    country_name = params['country_name']
    location = params['location']
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        values = df[indicator].astype(float).tolist()
        lats = df["latitude"].astype(float).tolist()
        lons = df["longitude"].astype(float).tolist()

        polygons = generate_geojson_polygons(lats, lons, values)

        # Centre the map roughly on the middle data point; fall back to (0, 0).
        center_lat = lats[len(lats) // 2] if lats else 0
        center_lon = lons[len(lons) // 2] if lons else 0

        fig = go.Figure(go.Choroplethmapbox(
            geojson=polygons,
            locations=[str(i) for i in range(len(values))],
            featureidkey="id",
            z=values,
            colorscale=IPCC_INDICATOR_TO_COLORSCALE[indicator],
            zmin=min(values),
            zmax=max(values),
            marker_opacity=0.7,
            marker_line_width=0,
            colorbar_title=f"{indicator_label} ({unit})",
            # Hover text shows the formatted indicator value for each cell.
            text=[f"{indicator_label}: {value:.2f} {unit}" for value in values],
            hoverinfo="text"
        ))

        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox_zoom=2,
            height=800,
            mapbox_center={"lat": center_lat, "lon": center_lon},
            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
            title=f"{indicator_label} in {year} in {location} ({country_name})"
        )
        return fig

    return plot_data
169
+
170
# Plot registry entry: country-wide spatial view of one indicator for one year,
# wiring the choropleth-map builder to the per-year SQL query and its description text.
choropleth_map_of_country_indicator_for_specific_year: Plot = {
    "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
    "description": (
        "Displays a map showing the spatial distribution of a climate indicator (e.g., rainfall, temperature) "
        "across all regions of a country for a specific year. "
        "Can answer questions about the value of an indicator in a country or region for a given year, "
        "such as 'What will be the total rainfall in China in 2050?' or 'How is rainfall distributed across China in 2050?'. "
        "Parameters: indicator_column (the climate variable), year, location (country name)."
    ),
    "params": ["indicator_column", "year", "location"],
    "plot_function": plot_choropleth_map_of_country_indicator_for_specific_year,
    "sql_query": indicator_for_given_year_query,
    "plot_information": choropleth_map_informations,
    "short_name": "Map",
}
185
+
186
# Registry of the IPCC plot configurations available to the workflow.
IPCC_PLOTS = [
    indicator_evolution_at_location_historical_and_projections,
    choropleth_map_of_country_indicator_for_specific_year
]
climateqa/engine/talk_to_data/ipcc/queries.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict, Optional
2
+
3
+ from climateqa.engine.talk_to_data.ipcc.config import HUGE_MACRO_COUNTRIES, MACRO_COUNTRIES
4
+ from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
5
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
    """Query parameters for the yearly evolution of an indicator at one point.

    All keys are optional (total=False); the query builder validates presence.

    Attributes:
        indicator_column (str): Name of the climate indicator column.
        latitude (str): Latitude of the location.
        longitude (str): Longitude of the location.
        country_code (str): Country code.
        admin1 (str): Administrative region (optional).
    """
    indicator_column: str
    latitude: str
    longitude: str
    country_code: str
    admin1: Optional[str]
21
+
22
def indicator_per_year_at_location_query(
    table: str, params: IndicatorPerYearAtLocationQueryParams
) -> str:
    """
    Builds an SQL query to get the evolution of an indicator per year at a specific location.

    Depending on the country, data is read either from a pre-aggregated "macro"
    parquet file (averaged per year, or already yearly for huge countries) or
    from the full-resolution file, where monthly medians are averaged per year.

    Args:
        table (str): SQL table of the indicator.
        params (IndicatorPerYearAtLocationQueryParams): Dictionary with the required params for the query.

    Returns:
        str: The SQL query string, or an empty string if required parameters are missing.
    """
    indicator_column = params.get("indicator_column")
    latitude = params.get("latitude")
    longitude = params.get("longitude")
    country_code = params.get("country_code")

    if not all([indicator_column, latitude, longitude, country_code]):
        return ""

    if country_code in MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
        SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
        FROM {table_path}
        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
        GROUP BY scenario, year
        ORDER BY year, scenario
        """
    elif country_code in HUGE_MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        # Huge-country macro files are already aggregated per year, so no GROUP BY.
        # Fixed: removed the trailing comma in the SELECT list and added the
        # missing BY in the ORDER BY clause, which made the query invalid.
        sql_query = f"""
        SELECT year, scenario, {indicator_column}
        FROM {table_path}
        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
        ORDER BY year, scenario
        """
    else:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
        sql_query = f"""
        WITH medians_per_month AS (
            SELECT year, scenario, month, MEDIAN({indicator_column}) AS median_value
            FROM {table_path}
            WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
            GROUP BY scenario, year, month
        )
        SELECT year, scenario, AVG(median_value) AS {indicator_column}
        FROM medians_per_month
        GROUP BY scenario, year
        ORDER BY year, scenario
        """
    return sql_query.strip()
76
+
77
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
    """Query parameters for an indicator's values across locations in one year.

    All keys are optional (total=False); the query builder validates presence.

    Attributes:
        indicator_column (str): The column name for the climate indicator.
        year (str): The year to query.
        country_code (str): The country code.
    """
    indicator_column: str
    year: str
    country_code: str
89
+
90
def indicator_for_given_year_query(
    table: str, params: IndicatorForGivenYearQueryParams
) -> str:
    """
    Builds an SQL query to get the values of an indicator with their latitudes, longitudes,
    and scenarios for a given year.

    Depending on the country, data is read either from a pre-aggregated "macro"
    parquet file (averaged, or already yearly for huge countries) or from the
    full-resolution file, where monthly medians are averaged per location.

    Args:
        table (str): SQL table of the indicator.
        params (IndicatorForGivenYearQueryParams): Dictionary with the required params for the query.

    Returns:
        str: The SQL query string, or an empty string if required parameters are missing.
    """
    indicator_column = params.get("indicator_column")
    year = params.get("year") or 2050  # default to mid-century when unspecified
    country_code = params.get("country_code")

    if not all([indicator_column, year, country_code]):
        return ""

    if country_code in MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
        SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
        FROM {table_path}
        WHERE year = {year}
        GROUP BY latitude, longitude, scenario
        ORDER BY latitude, longitude, scenario
        """
    elif country_code in HUGE_MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        # Huge-country macro files are already aggregated, so no GROUP BY.
        # Fixed: removed the trailing comma before FROM in the SELECT list.
        sql_query = f"""
        SELECT latitude, longitude, scenario, {indicator_column}
        FROM {table_path}
        WHERE year = {year}
        ORDER BY latitude, longitude, scenario
        """
    else:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
        sql_query = f"""
        WITH medians_per_month AS (
            SELECT latitude, longitude, scenario, month, MEDIAN({indicator_column}) AS median_value
            FROM {table_path}
            WHERE year = {year}
            GROUP BY latitude, longitude, scenario, month
        )
        SELECT latitude, longitude, scenario, AVG(median_value) AS {indicator_column}
        FROM medians_per_month
        GROUP BY latitude, longitude, scenario
        ORDER BY latitude, longitude, scenario
        """

    return sql_query.strip()
climateqa/engine/talk_to_data/main.py CHANGED
@@ -1,44 +1,70 @@
1
- from climateqa.engine.talk_to_data.talk_to_drias import drias_workflow
2
- from climateqa.engine.llm import get_llm
3
  from climateqa.logging import log_drias_interaction_to_huggingface
4
- import ast
5
 
6
- llm = get_llm(provider="openai")
7
-
8
- def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
9
- """Adds table names to the SQL query result rows using LLM.
10
 
11
- This function modifies the SQL query to include the source table name in each row
12
- of the result set, making it easier to track which data comes from which table.
 
13
 
14
  Args:
15
- sql_query (str): The original SQL query to modify
16
- llm: The language model instance to use for generating the modified query
17
 
18
  Returns:
19
- str: The modified SQL query with table names included in the result rows
 
 
 
 
 
 
 
 
 
20
  """
21
- sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
22
- return sql_with_table_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- def ask_llm_column_names(sql_query: str, llm) -> list[str]:
25
- """Extracts column names from a SQL query using LLM.
 
 
26
 
27
- This function analyzes a SQL query to identify which columns are being selected
28
- in the result set.
29
 
30
- Args:
31
- sql_query (str): The SQL query to analyze
32
- llm: The language model instance to use for column extraction
33
-
34
- Returns:
35
- list[str]: A list of column names being selected in the query
36
- """
37
- columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
38
- columns_list = ast.literal_eval(columns.strip("```python\n").strip())
39
- return columns_list
40
 
41
- async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tuple:
 
 
42
  """Main function to process a DRIAS query and return results.
43
 
44
  This function orchestrates the DRIAS workflow, processing a user query to generate
@@ -61,58 +87,38 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
61
  - table_list (list): List of table names used
62
  - error (str): Error message if any
63
  """
64
- final_state = await drias_workflow(query)
65
  sql_queries = []
66
  result_dataframes = []
67
  figures = []
68
- table_list = []
 
69
 
70
- for plot_state in final_state['plot_states'].values():
71
- for table_state in plot_state['table_states'].values():
72
- if table_state['status'] == 'OK':
73
- if 'table_name' in table_state:
74
- table_list.append(' '.join(table_state['table_name'].capitalize().split('_')))
75
- if 'sql_query' in table_state and table_state['sql_query'] is not None:
76
- sql_queries.append(table_state['sql_query'])
77
-
78
- if 'dataframe' in table_state and table_state['dataframe'] is not None:
79
- result_dataframes.append(table_state['dataframe'])
80
- if 'figure' in table_state and table_state['figure'] is not None:
81
- figures.append(table_state['figure'])
 
 
 
82
 
83
  if "error" in final_state and final_state["error"] != "":
84
- return None, None, None, [], [], [], 0, final_state["error"]
 
85
 
86
  sql_query = sql_queries[index_state]
87
  dataframe = result_dataframes[index_state]
88
  figure = figures[index_state](dataframe)
 
89
 
90
  log_drias_interaction_to_huggingface(query, sql_query, user_id)
91
 
92
- return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
93
-
94
- # def ask_vanna(vn,db_vanna_path, query):
95
-
96
- # try :
97
- # location = detect_location_with_openai(query)
98
- # if location:
99
-
100
- # coords = loc2coords(location)
101
- # user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
102
-
103
- # relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
104
- # coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
105
- # user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
106
-
107
- # sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
108
-
109
- # return sql_query, result_dataframe, figure
110
- # else :
111
- # empty_df = pd.DataFrame()
112
- # empty_fig = None
113
- # return "", empty_df, empty_fig
114
- # except Exception as e:
115
- # print(f"Error: {e}")
116
- # empty_df = pd.DataFrame()
117
- # empty_fig = None
118
- # return "", empty_df, empty_fig
 
1
+ from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
2
+ from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
3
  from climateqa.logging import log_drias_interaction_to_huggingface
 
4
 
5
async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
    """Main function to process a DRIAS query and return results.

    This function orchestrates the DRIAS workflow, processing a user query to generate
    SQL queries, dataframes, and visualizations. It handles multiple results and allows
    pagination through them.

    Args:
        query (str): The user's question about climate data
        index_state (int, optional): The index of the result to return. Defaults to 0.
        user_id (str, optional): Identifier used when logging the interaction.

    Returns:
        tuple: An 11-element tuple containing:
            - sql_query (str | None): The SQL query used
            - dataframe (pd.DataFrame | None): The resulting data
            - figure: The generated visualization
            - plot_information (str | None): Description of the selected plot
            - sql_queries (list): All generated SQL queries
            - result_dataframes (list): All resulting dataframes
            - figures (list): All figure generation functions
            - plot_informations (list): All plot descriptions
            - index_state (int): Current result index
            - plot_title_list (list): Titles of the generated plots
            - error (str): Error message if any
    """
    final_state = await drias_workflow(query)
    sql_queries = []
    result_dataframes = []
    figures = []
    plot_title_list = []
    plot_informations = []

    for output_title, output in final_state['outputs'].items():
        if output['status'] == 'OK':
            if output['table'] is not None:
                plot_title_list.append(output_title)

            if output['plot_information'] is not None:
                plot_informations.append(output['plot_information'])

            if output['sql_query'] is not None:
                sql_queries.append(output['sql_query'])

            if output['dataframe'] is not None:
                result_dataframes.append(output['dataframe'])
            if output['figure'] is not None:
                figures.append(output['figure'])

    if "error" in final_state and final_state["error"] != "":
        # Error path must mirror the 11-element success tuple so callers can
        # unpack it. Fixed: the previous version returned only 10 elements
        # (the empty plot_informations list was missing).
        return None, None, None, None, [], [], [], [], 0, [], final_state["error"]

    sql_query = sql_queries[index_state]
    dataframe = result_dataframes[index_state]
    figure = figures[index_state](dataframe)
    plot_information = plot_informations[index_state]

    log_drias_interaction_to_huggingface(query, sql_query, user_id)

    return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
 
 
 
 
 
 
 
64
 
65
+
66
+
67
async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
    """Main function to process an IPCC query and return results.

    This function orchestrates the IPCC workflow, processing a user query to generate
    SQL queries, dataframes, and visualizations. It handles multiple results and allows
    pagination through them.

    Args:
        query (str): The user's question about climate data
        index_state (int, optional): The index of the result to return. Defaults to 0.
        user_id (str, optional): Identifier used when logging the interaction.

    Returns:
        tuple: An 11-element tuple containing:
            - sql_query (str | None): The SQL query used
            - dataframe (pd.DataFrame | None): The resulting data
            - figure: The generated visualization
            - plot_information (str | None): Description of the selected plot
            - sql_queries (list): All generated SQL queries
            - result_dataframes (list): All resulting dataframes
            - figures (list): All figure generation functions
            - plot_informations (list): All plot descriptions
            - index_state (int): Current result index
            - plot_title_list (list): Titles of the generated plots
            - error (str): Error message if any
    """
    final_state = await ipcc_workflow(query)
    sql_queries = []
    result_dataframes = []
    figures = []
    plot_title_list = []
    plot_informations = []

    for output_title, output in final_state['outputs'].items():
        if output['status'] == 'OK':
            if output['table'] is not None:
                plot_title_list.append(output_title)

            if output['plot_information'] is not None:
                plot_informations.append(output['plot_information'])

            if output['sql_query'] is not None:
                sql_queries.append(output['sql_query'])

            if output['dataframe'] is not None:
                result_dataframes.append(output['dataframe'])
            if output['figure'] is not None:
                figures.append(output['figure'])

    if "error" in final_state and final_state["error"] != "":
        # Error path must mirror the 11-element success tuple so callers can
        # unpack it. Fixed: the previous version returned only 10 elements
        # (the empty plot_informations list was missing).
        return None, None, None, None, [], [], [], [], 0, [], final_state["error"]

    sql_query = sql_queries[index_state]
    dataframe = result_dataframes[index_state]
    figure = figures[index_state](dataframe)
    plot_information = plot_informations[index_state]

    log_drias_interaction_to_huggingface(query, sql_query, user_id)

    return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/objects/llm_outputs.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Annotated, TypedDict


class ArrayOutput(TypedDict):
    """Structured LLM output wrapping a Python array literal.

    Type-hints functions whose result is an array serialized as a string,
    keeping return types consistent across the codebase.

    Attributes:
        array (str): A syntactically valid Python array string
    """
    array: Annotated[str, "Syntactically valid python array."]
+ array: Annotated[str, "Syntactically valid python array."]
climateqa/engine/talk_to_data/objects/location.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Optional, TypedDict


class Location(TypedDict):
    """Resolved location details for a place mentioned in a user query.

    Attributes:
        location: The raw location name extracted from the user query.
        latitude: Latitude of the resolved location, if found.
        longitude: Longitude of the resolved location, if found.
        country_code: Country code of the resolved location, if found.
        country_name: Country name of the resolved location, if found.
        admin1: First-level administrative region, if found.
    """
    location: str
    latitude: Optional[str]
    longitude: Optional[str]
    country_code: Optional[str]
    country_name: Optional[str]
    admin1: Optional[str]
climateqa/engine/talk_to_data/objects/plot.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, TypedDict, Optional
2
+ from plotly.graph_objects import Figure
3
+
4
class Plot(TypedDict):
    """Represents a plot configuration in the DRIAS system.

    This class defines the structure for configuring different types of plots
    that can be generated from climate data.

    Attributes:
        name (str): The name of the plot type
        description (str): A description of what the plot shows
        params (list[str]): List of required parameters for the plot
        plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
        sql_query (Callable[..., str]): Function to generate the SQL query for the plot
        plot_information (Callable[..., str]): Function to generate the explanatory text shown with the plot
        short_name (str): Short display name for the plot
    """
    name: str
    description: str
    params: list[str]
    plot_function: Callable[..., Callable[..., Figure]]
    sql_query: Callable[..., str]
    plot_information: Callable[..., str]
    short_name: str
climateqa/engine/talk_to_data/objects/states.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Optional, TypedDict
2
+ from plotly.graph_objects import Figure
3
+ import pandas as pd
4
+ from climateqa.engine.talk_to_data.objects.plot import Plot
5
+
6
class TTDOutput(TypedDict):
    """Result of processing one plot in a Talk-to-Data workflow.

    Attributes:
        status: Processing outcome; downstream code treats 'OK' as success.
        plot: The plot configuration that produced this output.
        table: Name of the table the data was read from.
        sql_query: The generated SQL query, if any.
        dataframe: The query result, if the query was executed.
        figure: Function building the Plotly figure from the dataframe.
        plot_information: Explanatory text associated with the plot.
    """
    status: str
    plot: Plot
    table: str
    sql_query: Optional[str]
    dataframe: Optional[pd.DataFrame]
    figure: Optional[Callable[..., Figure]]
    plot_information: Optional[str]


class State(TypedDict):
    """Overall state of a Talk-to-Data workflow run.

    Attributes:
        user_input: The user's natural-language question.
        plots: Names of the plots selected for the question.
        outputs: Per-plot outputs, keyed by plot title.
        error: Error message if the workflow failed.
    """
    user_input: str
    plots: list[str]
    outputs: dict[str, TTDOutput]
    error: Optional[str]
climateqa/engine/talk_to_data/prompt.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Prompt sent to the LLM to generate a SQL query from a user question.
# Placeholders: {relevant_tables}, {dialect}, {model}, {table_info}, {input}.
query_prompt_template = """You are an expert SQL query generator. Given an input question, database schema, SQL dialect and relevant tables to answer the question, generate an optimized and syntactically correct SQL query which can provide useful insights to the question.

### Instructions:
1. **Use only relevant tables**: The following tables are relevant to answering the question: {relevant_tables}. Do not use any other tables.
2. **Relevant columns only**: Never select `*`. Only include necessary columns based on the input question.
3. **Schema Awareness**:
- Use only columns present in the given schema.
- **If a column name appears in multiple tables, always use the format `table_name.column_name` to avoid ambiguity.**
- Select only the column which are insightful for the question.
4. **Dialect Compliance**: Follow `{dialect}` syntax rules.
5. **Ordering**: Order the results by a relevant column if applicable (e.g., timestamp for recent records).
6. **Valid query**: Make sure the query is syntactically and functionally correct.
7. **Conditions** : For the common columns, the same condition should be applied to all the tables (e.g. latitude, longitude, model, year...)
8. **Join tables** : If you need to join table, you should join them with year feature.
9. **Model** : For each table, you need to add a condition on the model to be equal to {model}

### Provided Database Schema:
{table_info}

### Relevant Tables:
{relevant_tables}

**Question:** {input}

**SQL Query:**"""

# Prompt sent to the LLM to generate Plotly code for a given SQL query.
# Placeholders: {sql_query}, {input}.
plot_prompt_template = """You are a data visualization expert. Given an input question and an SQL Query, generate an insightful plot according to the question.

### Instructions
1. **Use only the column names provided**. The data will be provided as a Pandas DataFrame `df` with the columns present in the SELECT.
2. Generate the Python Plotly code to chart the results using `df` and the column names.
3. Make as complete a graph as possible to answer the question, and make it as easy to understand as possible.
4. **Respond with only Python code**. Do not answer with any explanations -- just the code.
5. **Specific cases** :
- For a question about the evolution of something, it is also relevant to plot the data with also the sliding average for a period of 20 years for example.

### SQL Query:
{sql_query}

**Question:** {input}

**Python code:**
"""
+
climateqa/engine/talk_to_data/query.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import duckdb
4
+ import pandas as pd
5
+ import os
6
+
7
def find_indicator_column(table: str, indicator_columns_per_table: dict[str, str]) -> str:
    """Look up the indicator column associated with a table.

    Args:
        table (str): Name of the table in the database
        indicator_columns_per_table (dict[str, str]): Mapping from table name
            to its indicator column name.

    Returns:
        str: Name of the indicator column for the specified table

    Raises:
        KeyError: If the table name is not found in the mapping
    """
    print(f"---- Find indicator column in table {table} ----")
    column = indicator_columns_per_table[table]
    return column
24
+
25
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
    """Executes a SQL query on the DRIAS database and returns the results.

    Connects to DuckDB, registers the Hugging Face token as a DuckDB secret,
    runs the query in a worker thread (so the event loop stays responsive)
    and returns the results as a pandas DataFrame.

    Args:
        sql_query (str): The SQL query to execute

    Returns:
        pd.DataFrame: A DataFrame containing the query results

    Raises:
        duckdb.Error: If there is an error executing the SQL query
    """
    def _execute_query() -> pd.DataFrame:
        con = duckdb.connect()
        try:
            hf_token = os.getenv("HF_TOKEN")
            con.execute(f"""CREATE SECRET hf_token (
                TYPE huggingface,
                TOKEN '{hf_token}'
            );""")
            return con.execute(sql_query).fetchdf()
        finally:
            # Always release the connection, even when the query fails
            # (the previous version leaked it on error).
            con.close()

    # asyncio.to_thread replaces the deprecated get_event_loop() +
    # per-call ThreadPoolExecutor pattern of the previous version.
    return await asyncio.to_thread(_execute_query)
+
climateqa/engine/talk_to_data/talk_to_drias.py DELETED
@@ -1,317 +0,0 @@
1
- import os
2
-
3
- from typing import Any, Callable, TypedDict, Optional
4
- from numpy import sort
5
- import pandas as pd
6
- import asyncio
7
- from plotly.graph_objects import Figure
8
- from climateqa.engine.llm import get_llm
9
- from climateqa.engine.talk_to_data import sql_query
10
- from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
11
- from climateqa.engine.talk_to_data.plot import PLOTS, Plot
12
- from climateqa.engine.talk_to_data.sql_query import execute_sql_query
13
- from climateqa.engine.talk_to_data.utils import (
14
- detect_relevant_plots,
15
- detect_year_with_openai,
16
- loc2coords,
17
- detect_location_with_openai,
18
- nearestNeighbourSQL,
19
- detect_relevant_tables,
20
- )
21
-
22
-
23
- ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
24
-
25
- class TableState(TypedDict):
26
- """Represents the state of a table in the DRIAS workflow.
27
-
28
- This class defines the structure for tracking the state of a table during the
29
- data processing workflow, including its name, parameters, SQL query, and results.
30
-
31
- Attributes:
32
- table_name (str): The name of the table in the database
33
- params (dict[str, Any]): Parameters used for querying the table
34
- sql_query (str, optional): The SQL query used to fetch data
35
- dataframe (pd.DataFrame | None, optional): The resulting data
36
- figure (Callable[..., Figure], optional): Function to generate visualization
37
- status (str): The current status of the table processing ('OK' or 'ERROR')
38
- """
39
- table_name: str
40
- params: dict[str, Any]
41
- sql_query: Optional[str]
42
- dataframe: Optional[pd.DataFrame | None]
43
- figure: Optional[Callable[..., Figure]]
44
- status: str
45
-
46
- class PlotState(TypedDict):
47
- """Represents the state of a plot in the DRIAS workflow.
48
-
49
- This class defines the structure for tracking the state of a plot during the
50
- data processing workflow, including its name and associated tables.
51
-
52
- Attributes:
53
- plot_name (str): The name of the plot
54
- tables (list[str]): List of tables used in the plot
55
- table_states (dict[str, TableState]): States of the tables used in the plot
56
- """
57
- plot_name: str
58
- tables: list[str]
59
- table_states: dict[str, TableState]
60
-
61
- class State(TypedDict):
62
- user_input: str
63
- plots: list[str]
64
- plot_states: dict[str, PlotState]
65
- error: Optional[str]
66
-
67
- async def find_relevant_plots(state: State, llm) -> list[str]:
68
- print("---- Find relevant plots ----")
69
- relevant_plots = await detect_relevant_plots(state['user_input'], llm)
70
- return relevant_plots
71
-
72
- async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
73
- print(f"---- Find relevant tables for {plot['name']} ----")
74
- relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
75
- return relevant_tables
76
-
77
- async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
78
- """Perform the good method to retrieve the desired parameter
79
-
80
- Args:
81
- state (State): state of the workflow
82
- param_name (str): name of the desired parameter
83
- table (str): name of the table
84
-
85
- Returns:
86
- dict[str, Any] | None:
87
- """
88
- if param_name == 'location':
89
- location = await find_location(state['user_input'], table)
90
- return location
91
- if param_name == 'year':
92
- year = await find_year(state['user_input'])
93
- return {'year': year}
94
- return None
95
-
96
- class Location(TypedDict):
97
- location: str
98
- latitude: Optional[str]
99
- longitude: Optional[str]
100
-
101
- async def find_location(user_input: str, table: str) -> Location:
102
- print(f"---- Find location in table {table} ----")
103
- location = await detect_location_with_openai(user_input)
104
- output: Location = {'location' : location}
105
- if location:
106
- coords = loc2coords(location)
107
- neighbour = nearestNeighbourSQL(coords, table)
108
- output.update({
109
- "latitude": neighbour[0],
110
- "longitude": neighbour[1],
111
- })
112
- return output
113
-
114
- async def find_year(user_input: str) -> str:
115
- """Extracts year information from user input using LLM.
116
-
117
- This function uses an LLM to identify and extract year information from the
118
- user's query, which is used to filter data in subsequent queries.
119
-
120
- Args:
121
- user_input (str): The user's query text
122
-
123
- Returns:
124
- str: The extracted year, or empty string if no year found
125
- """
126
- print(f"---- Find year ---")
127
- year = await detect_year_with_openai(user_input)
128
- return year
129
-
130
- def find_indicator_column(table: str) -> str:
131
- """Retrieves the name of the indicator column within a table.
132
-
133
- This function maps table names to their corresponding indicator columns
134
- using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.
135
-
136
- Args:
137
- table (str): Name of the table in the database
138
-
139
- Returns:
140
- str: Name of the indicator column for the specified table
141
-
142
- Raises:
143
- KeyError: If the table name is not found in the mapping
144
- """
145
- print(f"---- Find indicator column in table {table} ----")
146
- return INDICATOR_COLUMNS_PER_TABLE[table]
147
-
148
-
149
- async def process_table(
150
- table: str,
151
- params: dict[str, Any],
152
- plot: Plot,
153
- ) -> TableState:
154
- """Processes a table to extract relevant data and generate visualizations.
155
-
156
- This function retrieves the SQL query for the specified table, executes it,
157
- and generates a visualization based on the results.
158
-
159
- Args:
160
- table (str): The name of the table to process
161
- params (dict[str, Any]): Parameters used for querying the table
162
- plot (Plot): The plot object containing SQL query and visualization function
163
-
164
- Returns:
165
- TableState: The state of the processed table
166
- """
167
- table_state: TableState = {
168
- 'table_name': table,
169
- 'params': params.copy(),
170
- 'status': 'OK',
171
- 'dataframe': None,
172
- 'sql_query': None,
173
- 'figure': None
174
- }
175
-
176
- table_state['params']['indicator_column'] = find_indicator_column(table)
177
- sql_query = plot['sql_query'](table, table_state['params'])
178
-
179
- if sql_query == "":
180
- table_state['status'] = 'ERROR'
181
- return table_state
182
- table_state['sql_query'] = sql_query
183
- df = await execute_sql_query(sql_query)
184
-
185
- table_state['dataframe'] = df
186
- table_state['figure'] = plot['plot_function'](table_state['params'])
187
-
188
- return table_state
189
-
190
- async def drias_workflow(user_input: str) -> State:
191
- """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
192
-
193
- Args:
194
- user_input (str): initial user input
195
-
196
- Returns:
197
- State: Final state with all the results
198
- """
199
- state: State = {
200
- 'user_input': user_input,
201
- 'plots': [],
202
- 'plot_states': {},
203
- 'error': ''
204
- }
205
-
206
- llm = get_llm(provider="openai")
207
-
208
- plots = await find_relevant_plots(state, llm)
209
-
210
- state['plots'] = plots
211
-
212
- if len(state['plots']) < 1:
213
- state['error'] = 'There is no plot to answer to the question'
214
- return state
215
-
216
- have_relevant_table = False
217
- have_sql_query = False
218
- have_dataframe = False
219
-
220
- for plot_name in state['plots']:
221
-
222
- plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
223
- if plot is None:
224
- continue
225
-
226
- plot_state: PlotState = {
227
- 'plot_name': plot_name,
228
- 'tables': [],
229
- 'table_states': {}
230
- }
231
-
232
- plot_state['plot_name'] = plot_name
233
-
234
- relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
235
-
236
- if len(relevant_tables) > 0 :
237
- have_relevant_table = True
238
-
239
- plot_state['tables'] = relevant_tables
240
-
241
- params = {}
242
- for param_name in plot['params']:
243
- param = await find_param(state, param_name, relevant_tables[0])
244
- if param:
245
- params.update(param)
246
-
247
- tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
248
- results = await asyncio.gather(*tasks)
249
-
250
- # Store results back in plot_state
251
- have_dataframe = False
252
- have_sql_query = False
253
- for table_state in results:
254
- if table_state['sql_query']:
255
- have_sql_query = True
256
- if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
257
- have_dataframe = True
258
- plot_state['table_states'][table_state['table_name']] = table_state
259
-
260
- state['plot_states'][plot_name] = plot_state
261
-
262
- if not have_relevant_table:
263
- state['error'] = "There is no relevant table in our database to answer your question"
264
- elif not have_sql_query:
265
- state['error'] = "There is no relevant sql query on our database that can help to answer your question"
266
- elif not have_dataframe:
267
- state['error'] = "There is no data in our table that can answer to your question"
268
-
269
- return state
270
-
271
- # def make_write_query_node():
272
-
273
- # def write_query(state):
274
- # print("---- Write query ----")
275
- # for table in state["tables"]:
276
- # sql_query = QUERIES[state[table]['query_type']](
277
- # table=table,
278
- # indicator_column=state[table]["columns"],
279
- # longitude=state[table]["longitude"],
280
- # latitude=state[table]["latitude"],
281
- # )
282
- # state[table].update({"sql_query": sql_query})
283
-
284
- # return state
285
-
286
- # return write_query
287
-
288
- # def make_fetch_data_node(db_path):
289
-
290
- # def fetch_data(state):
291
- # print("---- Fetch data ----")
292
- # for table in state["tables"]:
293
- # results = execute_sql_query(db_path, state[table]['sql_query'])
294
- # state[table].update(results)
295
-
296
- # return state
297
-
298
- # return fetch_data
299
-
300
-
301
-
302
- ## V2
303
-
304
-
305
- # def make_fetch_data_node(db_path: str, llm):
306
- # def fetch_data(state):
307
- # print("---- Fetch data ----")
308
- # db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
309
- # output = {}
310
- # sql_query = write_sql_query(state["query"], db, state["tables"], llm)
311
- # # TO DO : Add query checker
312
- # print(f"SQL query : {sql_query}")
313
- # output["sql_query"] = sql_query
314
- # output.update(fetch_data_from_sql_query(db_path, sql_query))
315
- # return output
316
-
317
- # return fetch_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/ui_config.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TEMPERATURE_COLORSCALE = [
2
+ [0.0, "rgb(5, 48, 97)"],
3
+ [0.10, "rgb(33, 102, 172)"],
4
+ [0.20, "rgb(67, 147, 195)"],
5
+ [0.30, "rgb(146, 197, 222)"],
6
+ [0.40, "rgb(209, 229, 240)"],
7
+ [0.50, "rgb(247, 247, 247)"],
8
+ [0.60, "rgb(253, 219, 199)"],
9
+ [0.75, "rgb(244, 165, 130)"],
10
+ [0.85, "rgb(214, 96, 77)"],
11
+ [0.90, "rgb(178, 24, 43)"],
12
+ [1.0, "rgb(103, 0, 31)"]
13
+ ]
14
+
15
+ PRECIPITATION_COLORSCALE = [
16
+ [0.0, "rgb(84, 48, 5)"],
17
+ [0.10, "rgb(140, 81, 10)"],
18
+ [0.20, "rgb(191, 129, 45)"],
19
+ [0.30, "rgb(223, 194, 125)"],
20
+ [0.40, "rgb(246, 232, 195)"],
21
+ [0.50, "rgb(245, 245, 245)"],
22
+ [0.60, "rgb(199, 234, 229)"],
23
+ [0.75, "rgb(128, 205, 193)"],
24
+ [0.85, "rgb(53, 151, 143)"],
25
+ [0.90, "rgb(1, 102, 94)"],
26
+ [1.0, "rgb(0, 60, 48)"]
27
+ ]
climateqa/engine/talk_to_data/{myVanna.py → vanna/myVanna.py} RENAMED
File without changes
climateqa/engine/talk_to_data/{vanna_class.py → vanna/vanna_class.py} RENAMED
File without changes
climateqa/engine/talk_to_data/workflow/drias.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import Any
4
+ import asyncio
5
+ from climateqa.engine.llm import get_llm
6
+ from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
7
+ from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
8
+ from climateqa.engine.talk_to_data.objects.plot import Plot
9
+ from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
10
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE, DRIAS_PLOT_PARAMETERS
11
+ from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
12
+
13
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
14
+
15
+ async def process_output(
16
+ output_title: str,
17
+ table: str,
18
+ plot: Plot,
19
+ params: dict[str, Any]
20
+ ) -> tuple[str, TTDOutput, dict[str, bool]]:
21
+ """
22
+ Processes a table for a given plot and parameters: builds the SQL query, executes it,
23
+ and generates the corresponding figure.
24
+
25
+ Args:
26
+ output_title (str): Title for the output (used as key in outputs dict).
27
+ table (str): The name of the table to process.
28
+ plot (Plot): The plot object containing SQL query and visualization function.
29
+ params (dict[str, Any]): Parameters used for querying the table.
30
+
31
+ Returns:
32
+ tuple: (output_title, results dict, errors dict)
33
+ """
34
+ results: TTDOutput = {
35
+ 'status': 'OK',
36
+ 'plot': plot,
37
+ 'table': table,
38
+ 'sql_query': None,
39
+ 'dataframe': None,
40
+ 'figure': None,
41
+ 'plot_information': None
42
+ }
43
+ errors = {
44
+ 'have_sql_query': False,
45
+ 'have_dataframe': False
46
+ }
47
+
48
+ # Find the indicator column for this table
49
+ indicator_column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
50
+ if indicator_column:
51
+ params['indicator_column'] = indicator_column
52
+
53
+ # Build the SQL query
54
+ sql_query = plot['sql_query'](table, params)
55
+ if not sql_query:
56
+ results['status'] = 'ERROR'
57
+ return output_title, results, errors
58
+
59
+ results['plot_information'] = plot['plot_information'](table, params)
60
+
61
+ results['sql_query'] = sql_query
62
+ errors['have_sql_query'] = True
63
+
64
+ # Execute the SQL query
65
+ df = await execute_sql_query(sql_query)
66
+ if df is not None and len(df) > 0:
67
+ results['dataframe'] = df
68
+ errors['have_dataframe'] = True
69
+ else:
70
+ results['status'] = 'NO_DATA'
71
+
72
+ # Generate the figure (always, even if df is empty, for consistency)
73
+ results['figure'] = plot['plot_function'](params)
74
+
75
+ return output_title, results, errors
76
+
77
+ async def drias_workflow(user_input: str) -> State:
78
+ """
79
+ Orchestrates the DRIAS workflow: from user input to SQL queries, dataframes, and figures.
80
+
81
+ Args:
82
+ user_input (str): The user's question.
83
+
84
+ Returns:
85
+ State: Final state with all results and error messages if any.
86
+ """
87
+ state: State = {
88
+ 'user_input': user_input,
89
+ 'plots': [],
90
+ 'outputs': {},
91
+ 'error': ''
92
+ }
93
+
94
+ llm = get_llm(provider="openai")
95
+ plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)
96
+
97
+ if not plots:
98
+ state['error'] = 'There is no plot to answer to the question'
99
+ return state
100
+
101
+ plots = plots[:2] # limit to 2 types of plots
102
+ state['plots'] = plots
103
+
104
+ errors = {
105
+ 'have_relevant_table': False,
106
+ 'have_sql_query': False,
107
+ 'have_dataframe': False
108
+ }
109
+ outputs = {}
110
+
111
+ # Find relevant tables for each plot and prepare outputs
112
+ for plot_name in plots:
113
+ plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None)
114
+ if plot is None:
115
+ continue
116
+
117
+ relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
118
+ if relevant_tables:
119
+ errors['have_relevant_table'] = True
120
+
121
+ for table in relevant_tables:
122
+ output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
123
+ outputs[output_title] = {
124
+ 'table': table,
125
+ 'plot': plot,
126
+ 'status': 'OK'
127
+ }
128
+
129
+ # Gather all required parameters
130
+ params = {}
131
+ for param_name in DRIAS_PLOT_PARAMETERS:
132
+ param = await find_param(state, param_name, mode='DRIAS')
133
+ if param:
134
+ params.update(param)
135
+
136
+ # Process all outputs in parallel using process_output
137
+ tasks = [
138
+ process_output(output_title, output['table'], output['plot'], params.copy())
139
+ for output_title, output in outputs.items()
140
+ ]
141
+ results = await asyncio.gather(*tasks)
142
+
143
+ # Update outputs with results and error flags
144
+ for output_title, task_results, task_errors in results:
145
+ outputs[output_title]['sql_query'] = task_results['sql_query']
146
+ outputs[output_title]['dataframe'] = task_results['dataframe']
147
+ outputs[output_title]['figure'] = task_results['figure']
148
+ outputs[output_title]['plot_information'] = task_results['plot_information']
149
+ outputs[output_title]['status'] = task_results['status']
150
+ errors['have_sql_query'] |= task_errors['have_sql_query']
151
+ errors['have_dataframe'] |= task_errors['have_dataframe']
152
+
153
+ state['outputs'] = outputs
154
+
155
+ # Set error messages if needed
156
+ if not errors['have_relevant_table']:
157
+ state['error'] = "There is no relevant table in our database to answer your question"
158
+ elif not errors['have_sql_query']:
159
+ state['error'] = "There is no relevant sql query on our database that can help to answer your question"
160
+ elif not errors['have_dataframe']:
161
+ state['error'] = "There is no data in our table that can answer to your question"
162
+
163
+ return state
climateqa/engine/talk_to_data/workflow/ipcc.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import Any
4
+ import asyncio
5
+ from climateqa.engine.llm import get_llm
6
+ from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
7
+ from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
8
+ from climateqa.engine.talk_to_data.objects.plot import Plot
9
+ from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
10
+ from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS
11
+ from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS
12
+
13
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
14
+
15
+ async def process_output(
16
+ output_title: str,
17
+ table: str,
18
+ plot: Plot,
19
+ params: dict[str, Any]
20
+ ) -> tuple[str, TTDOutput, dict[str, bool]]:
21
+ """
22
+ Process a table for a given plot and parameters: builds the SQL query, executes it,
23
+ and generates the corresponding figure.
24
+
25
+ Args:
26
+ output_title (str): Title for the output (used as key in outputs dict).
27
+ table (str): The name of the table to process.
28
+ plot (Plot): The plot object containing SQL query and visualization function.
29
+ params (dict[str, Any]): Parameters used for querying the table.
30
+
31
+ Returns:
32
+ tuple: (output_title, results dict, errors dict)
33
+ """
34
+ results: TTDOutput = {
35
+ 'status': 'OK',
36
+ 'plot': plot,
37
+ 'table': table,
38
+ 'sql_query': None,
39
+ 'dataframe': None,
40
+ 'figure': None,
41
+ 'plot_information': None,
42
+ }
43
+ errors = {
44
+ 'have_sql_query': False,
45
+ 'have_dataframe': False
46
+ }
47
+
48
+ # Find the indicator column for this table
49
+ indicator_column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
50
+ if indicator_column:
51
+ params['indicator_column'] = indicator_column
52
+
53
+ # Build the SQL query
54
+ sql_query = plot['sql_query'](table, params)
55
+ if not sql_query:
56
+ results['status'] = 'ERROR'
57
+ return output_title, results, errors
58
+
59
+ results['plot_information'] = plot['plot_information'](table, params)
60
+
61
+ results['sql_query'] = sql_query
62
+ errors['have_sql_query'] = True
63
+
64
+ # Execute the SQL query
65
+ df = await execute_sql_query(sql_query)
66
+ if df is not None and not df.empty:
67
+ results['dataframe'] = df
68
+ errors['have_dataframe'] = True
69
+ else:
70
+ results['status'] = 'NO_DATA'
71
+
72
+ # Generate the figure (always, even if df is empty, for consistency)
73
+ results['figure'] = plot['plot_function'](params)
74
+
75
+ return output_title, results, errors
76
+
77
+ async def ipcc_workflow(user_input: str) -> State:
78
+ """
79
+ Performs the complete workflow of Talk To IPCC: from user input to SQL queries, dataframes, and figures.
80
+
81
+ Args:
82
+ user_input (str): The user's question.
83
+
84
+ Returns:
85
+ State: Final state with all the results and error messages if any.
86
+ """
87
+ state: State = {
88
+ 'user_input': user_input,
89
+ 'plots': [],
90
+ 'outputs': {},
91
+ 'error': ''
92
+ }
93
+
94
+ llm = get_llm(provider="openai")
95
+ plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
96
+ state['plots'] = plots
97
+
98
+ if not plots:
99
+ state['error'] = 'There is no plot to answer to the question'
100
+ return state
101
+
102
+ errors = {
103
+ 'have_relevant_table': False,
104
+ 'have_sql_query': False,
105
+ 'have_dataframe': False
106
+ }
107
+ outputs = {}
108
+
109
+ # Find relevant tables for each plot and prepare outputs
110
+ for plot_name in plots:
111
+ plot = next((p for p in IPCC_PLOTS if p['name'] == plot_name), None)
112
+ if plot is None:
113
+ continue
114
+
115
+ relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
116
+ if relevant_tables:
117
+ errors['have_relevant_table'] = True
118
+
119
+ for table in relevant_tables:
120
+ output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
121
+ outputs[output_title] = {
122
+ 'table': table,
123
+ 'plot': plot,
124
+ 'status': 'OK'
125
+ }
126
+
127
+ # Gather all required parameters
128
+ params = {}
129
+ for param_name in IPCC_PLOT_PARAMETERS:
130
+ param = await find_param(state, param_name, mode='IPCC')
131
+ if param:
132
+ params.update(param)
133
+
134
+ # Process all outputs in parallel using process_output
135
+ tasks = [
136
+ process_output(output_title, output['table'], output['plot'], params.copy())
137
+ for output_title, output in outputs.items()
138
+ ]
139
+ results = await asyncio.gather(*tasks)
140
+
141
+ # Update outputs with results and error flags
142
+ for output_title, task_results, task_errors in results:
143
+ outputs[output_title]['sql_query'] = task_results['sql_query']
144
+ outputs[output_title]['dataframe'] = task_results['dataframe']
145
+ outputs[output_title]['figure'] = task_results['figure']
146
+ outputs[output_title]['plot_information'] = task_results['plot_information']
147
+ outputs[output_title]['status'] = task_results['status']
148
+ errors['have_sql_query'] |= task_errors['have_sql_query']
149
+ errors['have_dataframe'] |= task_errors['have_dataframe']
150
+
151
+ state['outputs'] = outputs
152
+
153
+ # Set error messages if needed
154
+ if not errors['have_relevant_table']:
155
+ state['error'] = "There is no relevant table in our database to answer your question"
156
+ elif not errors['have_sql_query']:
157
+ state['error'] = "There is no relevant sql query on our database that can help to answer your question"
158
+ elif not errors['have_dataframe']:
159
+ state['error'] = "There is no data in our table that can answer to your question"
160
+
161
+ return state
front/tabs/tab_drias.py CHANGED
@@ -4,26 +4,25 @@ import os
4
  import pandas as pd
5
 
6
  from climateqa.engine.talk_to_data.main import ask_drias
7
- from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
8
 
9
  class DriasUIElements(TypedDict):
10
  tab: gr.Tab
11
  details_accordion: gr.Accordion
12
  examples_hidden: gr.Textbox
13
  examples: gr.Examples
 
14
  drias_direct_question: gr.Textbox
15
  result_text: gr.Textbox
16
- table_names_display: gr.DataFrame
17
  query_accordion: gr.Accordion
18
  drias_sql_query: gr.Textbox
19
  chart_accordion: gr.Accordion
 
20
  model_selection: gr.Dropdown
21
  drias_display: gr.Plot
22
  table_accordion: gr.Accordion
23
  drias_table: gr.DataFrame
24
- pagination_display: gr.Markdown
25
- prev_button: gr.Button
26
- next_button: gr.Button
27
 
28
 
29
  async def ask_drias_query(query: str, index_state: int, user_id: str):
@@ -31,7 +30,7 @@ async def ask_drias_query(query: str, index_state: int, user_id: str):
31
  return result
32
 
33
 
34
- def show_results(sql_queries_state, dataframes_state, plots_state):
35
  if not sql_queries_state or not dataframes_state or not plots_state:
36
  # If all results are empty, show "No result"
37
  return (
@@ -40,9 +39,6 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
40
  gr.update(visible=False),
41
  gr.update(visible=False),
42
  gr.update(visible=False),
43
- gr.update(visible=False),
44
- gr.update(visible=False),
45
- gr.update(visible=False),
46
  )
47
  else:
48
  # Show the appropriate components with their data
@@ -51,10 +47,7 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
51
  gr.update(visible=True),
52
  gr.update(visible=True),
53
  gr.update(visible=True),
54
- gr.update(visible=True),
55
- gr.update(visible=True),
56
- gr.update(visible=True),
57
- gr.update(visible=True),
58
  )
59
 
60
 
@@ -72,44 +65,14 @@ def filter_by_model(dataframes, figures, index_state, model_selection):
72
  return df, figure
73
 
74
 
75
- def update_pagination(index, sql_queries):
76
- pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
77
- return pagination
78
-
79
-
80
- def show_previous(index, sql_queries, dataframes, plots):
81
- if index > 0:
82
- index -= 1
83
- return (
84
- sql_queries[index],
85
- dataframes[index],
86
- plots[index](dataframes[index]),
87
- index,
88
- )
89
-
90
-
91
- def show_next(index, sql_queries, dataframes, plots):
92
- if index < len(sql_queries) - 1:
93
- index += 1
94
- return (
95
- sql_queries[index],
96
- dataframes[index],
97
- plots[index](dataframes[index]),
98
- index,
99
- )
100
-
101
-
102
- def display_table_names(table_names):
103
- return [table_names]
104
-
105
-
106
- def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
107
- index = evt.index[1]
108
  figure = plots[index](dataframes[index])
109
  return (
110
  sql_queries[index],
111
  dataframes[index],
112
  figure,
 
113
  index,
114
  )
115
 
@@ -117,7 +80,7 @@ def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plo
117
  def create_drias_ui() -> DriasUIElements:
118
  """Create and return all UI elements for the DRIAS tab."""
119
  with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
120
- with gr.Accordion(label="Details") as details_accordion:
121
  gr.Markdown(DRIAS_UI_TEXT)
122
 
123
  # Add examples for common questions
@@ -141,24 +104,43 @@ def create_drias_ui() -> DriasUIElements:
141
  elem_id="direct-question",
142
  interactive=True,
143
  )
 
 
 
 
 
 
 
 
 
144
 
145
  result_text = gr.Textbox(
146
  label="", elem_id="no-result-label", interactive=False, visible=True
147
  )
148
-
149
- table_names_display = gr.DataFrame(
150
- [], label="List of relevant indicators", headers=None, interactive=False, elem_id="table-names", visible=False
151
- )
152
-
153
- with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
154
- drias_sql_query = gr.Textbox(
155
- label="", elem_id="sql-query", interactive=False
156
  )
157
 
 
 
 
 
 
 
158
  with gr.Accordion(label="Chart", visible=False) as chart_accordion:
159
- model_selection = gr.Dropdown(
160
- label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
161
- )
 
 
 
 
162
  drias_display = gr.Plot(elem_id="vanna-plot")
163
 
164
  with gr.Accordion(
@@ -166,32 +148,23 @@ def create_drias_ui() -> DriasUIElements:
166
  ) as table_accordion:
167
  drias_table = gr.DataFrame([], elem_id="vanna-table")
168
 
169
- pagination_display = gr.Markdown(
170
- value="", visible=False, elem_id="pagination-display"
171
- )
172
-
173
- with gr.Row():
174
- prev_button = gr.Button("Previous", visible=False)
175
- next_button = gr.Button("Next", visible=False)
176
-
177
  return DriasUIElements(
178
  tab=tab,
179
  details_accordion=details_accordion,
180
  examples_hidden=examples_hidden,
181
  examples=examples,
 
182
  drias_direct_question=drias_direct_question,
183
  result_text=result_text,
184
  table_names_display=table_names_display,
185
  query_accordion=query_accordion,
186
  drias_sql_query=drias_sql_query,
187
  chart_accordion=chart_accordion,
 
188
  model_selection=model_selection,
189
  drias_display=drias_display,
190
  table_accordion=table_accordion,
191
  drias_table=drias_table,
192
- pagination_display=pagination_display,
193
- prev_button=prev_button,
194
- next_button=next_button
195
  )
196
 
197
 
@@ -202,94 +175,56 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
202
  sql_queries_state = gr.State([])
203
  dataframes_state = gr.State([])
204
  plots_state = gr.State([])
 
205
  index_state = gr.State(0)
206
  table_names_list = gr.State([])
207
  user_id = gr.State(user_id)
208
 
 
 
 
 
 
 
 
209
  # Handle example selection
210
  ui_elements["examples_hidden"].change(
211
  lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
212
  inputs=[ui_elements["examples_hidden"]],
213
  outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
214
  ).then(
215
- ask_drias_query,
216
- inputs=[ui_elements["examples_hidden"], index_state, user_id],
217
- outputs=[
218
- ui_elements["drias_sql_query"],
219
- ui_elements["drias_table"],
220
- ui_elements["drias_display"],
221
- sql_queries_state,
222
- dataframes_state,
223
- plots_state,
224
- index_state,
225
- table_names_list,
226
- ui_elements["result_text"],
227
- ],
228
- ).then(
229
- show_results,
230
- inputs=[sql_queries_state, dataframes_state, plots_state],
231
- outputs=[
232
- ui_elements["result_text"],
233
- ui_elements["query_accordion"],
234
- ui_elements["table_accordion"],
235
- ui_elements["chart_accordion"],
236
- ui_elements["prev_button"],
237
- ui_elements["next_button"],
238
- ui_elements["pagination_display"],
239
- ui_elements["table_names_display"],
240
- ],
241
- ).then(
242
- update_pagination,
243
- inputs=[index_state, sql_queries_state],
244
- outputs=[ui_elements["pagination_display"]],
245
- ).then(
246
- display_table_names,
247
- inputs=[table_names_list],
248
- outputs=[ui_elements["table_names_display"]],
249
- )
250
-
251
- # Handle direct question submission
252
- ui_elements["drias_direct_question"].submit(
253
- lambda: gr.Accordion(open=False),
254
  inputs=None,
255
- outputs=[ui_elements["details_accordion"]]
256
  ).then(
257
  ask_drias_query,
258
- inputs=[ui_elements["drias_direct_question"], index_state, user_id],
259
  outputs=[
260
  ui_elements["drias_sql_query"],
261
  ui_elements["drias_table"],
262
  ui_elements["drias_display"],
 
263
  sql_queries_state,
264
  dataframes_state,
265
  plots_state,
 
266
  index_state,
267
  table_names_list,
268
  ui_elements["result_text"],
269
  ],
270
  ).then(
271
  show_results,
272
- inputs=[sql_queries_state, dataframes_state, plots_state],
273
  outputs=[
274
  ui_elements["result_text"],
275
  ui_elements["query_accordion"],
276
  ui_elements["table_accordion"],
277
  ui_elements["chart_accordion"],
278
- ui_elements["prev_button"],
279
- ui_elements["next_button"],
280
- ui_elements["pagination_display"],
281
  ui_elements["table_names_display"],
282
  ],
283
- ).then(
284
- update_pagination,
285
- inputs=[index_state, sql_queries_state],
286
- outputs=[ui_elements["pagination_display"]],
287
- ).then(
288
- display_table_names,
289
- inputs=[table_names_list],
290
- outputs=[ui_elements["table_names_display"]],
291
  )
292
 
 
293
  # Handle model selection change
294
  ui_elements["model_selection"].change(
295
  filter_by_model,
@@ -297,36 +232,12 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
297
  outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
298
  )
299
 
300
- # Handle pagination buttons
301
- ui_elements["prev_button"].click(
302
- show_previous,
303
- inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
304
- outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
305
- ).then(
306
- update_pagination,
307
- inputs=[index_state, sql_queries_state],
308
- outputs=[ui_elements["pagination_display"]],
309
- )
310
-
311
- ui_elements["next_button"].click(
312
- show_next,
313
- inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
314
- outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
315
- ).then(
316
- update_pagination,
317
- inputs=[index_state, sql_queries_state],
318
- outputs=[ui_elements["pagination_display"]],
319
- )
320
 
321
  # Handle table selection
322
- ui_elements["table_names_display"].select(
323
  fn=on_table_click,
324
- inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
325
- outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
326
- ).then(
327
- update_pagination,
328
- inputs=[index_state, sql_queries_state],
329
- outputs=[ui_elements["pagination_display"]],
330
  )
331
 
332
  def create_drias_tab(share_client=None, user_id=None):
 
4
  import pandas as pd
5
 
6
  from climateqa.engine.talk_to_data.main import ask_drias
7
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_MODELS, DRIAS_UI_TEXT
8
 
9
  class DriasUIElements(TypedDict):
10
  tab: gr.Tab
11
  details_accordion: gr.Accordion
12
  examples_hidden: gr.Textbox
13
  examples: gr.Examples
14
+ image_examples: gr.Row
15
  drias_direct_question: gr.Textbox
16
  result_text: gr.Textbox
17
+ table_names_display: gr.Radio
18
  query_accordion: gr.Accordion
19
  drias_sql_query: gr.Textbox
20
  chart_accordion: gr.Accordion
21
+ plot_information: gr.Markdown
22
  model_selection: gr.Dropdown
23
  drias_display: gr.Plot
24
  table_accordion: gr.Accordion
25
  drias_table: gr.DataFrame
 
 
 
26
 
27
 
28
  async def ask_drias_query(query: str, index_state: int, user_id: str):
 
30
  return result
31
 
32
 
33
+ def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
34
  if not sql_queries_state or not dataframes_state or not plots_state:
35
  # If all results are empty, show "No result"
36
  return (
 
39
  gr.update(visible=False),
40
  gr.update(visible=False),
41
  gr.update(visible=False),
 
 
 
42
  )
43
  else:
44
  # Show the appropriate components with their data
 
47
  gr.update(visible=True),
48
  gr.update(visible=True),
49
  gr.update(visible=True),
50
+ gr.update(choices=table_names, value=table_names[0], visible=True),
 
 
 
51
  )
52
 
53
 
 
65
  return df, figure
66
 
67
 
68
+ def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
69
+ index = table_names.index(selected_label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  figure = plots[index](dataframes[index])
71
  return (
72
  sql_queries[index],
73
  dataframes[index],
74
  figure,
75
+ plot_informations[index],
76
  index,
77
  )
78
 
 
80
  def create_drias_ui() -> DriasUIElements:
81
  """Create and return all UI elements for the DRIAS tab."""
82
  with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
83
+ with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
84
  gr.Markdown(DRIAS_UI_TEXT)
85
 
86
  # Add examples for common questions
 
104
  elem_id="direct-question",
105
  interactive=True,
106
  )
107
+
108
+
109
+ with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
110
+ gr.Markdown("### Examples of possible visualizations")
111
+
112
+ with gr.Row():
113
+ gr.Image("./front/assets/talk_to_drias_winter_temp_paris_example.png", label="Evolution of Mean Winter Temperature in Paris", elem_classes=["example-img"])
114
+ gr.Image("./front/assets/talk_to_drias_annual_temperature_france_example.png", label="Mean Annual Temperature in 2030 in France", elem_classes=["example-img"])
115
+ gr.Image("./front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png", label="Frequency of Remarkable Daily Precipitation in Lyon", elem_classes=["example-img"])
116
 
117
  result_text = gr.Textbox(
118
  label="", elem_id="no-result-label", interactive=False, visible=True
119
  )
120
+
121
+ with gr.Row():
122
+ table_names_display = gr.Radio(
123
+ choices=[],
124
+ label="Relevant figures created",
125
+ interactive=True,
126
+ elem_id="table-names",
127
+ visible=False
128
  )
129
 
130
+ with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
131
+ drias_sql_query = gr.Textbox(
132
+ label="", elem_id="sql-query", interactive=False
133
+ )
134
+
135
+
136
  with gr.Accordion(label="Chart", visible=False) as chart_accordion:
137
+ with gr.Row():
138
+ model_selection = gr.Dropdown(
139
+ label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
140
+ )
141
+ with gr.Accordion(label="Informations about the plot", open=False):
142
+ plot_information = gr.Markdown(value = "")
143
+
144
  drias_display = gr.Plot(elem_id="vanna-plot")
145
 
146
  with gr.Accordion(
 
148
  ) as table_accordion:
149
  drias_table = gr.DataFrame([], elem_id="vanna-table")
150
 
 
 
 
 
 
 
 
 
151
  return DriasUIElements(
152
  tab=tab,
153
  details_accordion=details_accordion,
154
  examples_hidden=examples_hidden,
155
  examples=examples,
156
+ image_examples=image_examples,
157
  drias_direct_question=drias_direct_question,
158
  result_text=result_text,
159
  table_names_display=table_names_display,
160
  query_accordion=query_accordion,
161
  drias_sql_query=drias_sql_query,
162
  chart_accordion=chart_accordion,
163
+ plot_information=plot_information,
164
  model_selection=model_selection,
165
  drias_display=drias_display,
166
  table_accordion=table_accordion,
167
  drias_table=drias_table,
 
 
 
168
  )
169
 
170
 
 
175
  sql_queries_state = gr.State([])
176
  dataframes_state = gr.State([])
177
  plots_state = gr.State([])
178
+ plot_informations_state = gr.State([])
179
  index_state = gr.State(0)
180
  table_names_list = gr.State([])
181
  user_id = gr.State(user_id)
182
 
183
+ # Handle direct question submission - trigger the same workflow by setting examples_hidden
184
+ ui_elements["drias_direct_question"].submit(
185
+ lambda x: gr.update(value=x),
186
+ inputs=[ui_elements["drias_direct_question"]],
187
+ outputs=[ui_elements["examples_hidden"]],
188
+ )
189
+
190
  # Handle example selection
191
  ui_elements["examples_hidden"].change(
192
  lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
193
  inputs=[ui_elements["examples_hidden"]],
194
  outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
195
  ).then(
196
+ lambda : gr.update(visible=False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  inputs=None,
198
+ outputs=ui_elements["image_examples"]
199
  ).then(
200
  ask_drias_query,
201
+ inputs=[ui_elements["examples_hidden"], index_state, user_id],
202
  outputs=[
203
  ui_elements["drias_sql_query"],
204
  ui_elements["drias_table"],
205
  ui_elements["drias_display"],
206
+ ui_elements["plot_information"],
207
  sql_queries_state,
208
  dataframes_state,
209
  plots_state,
210
+ plot_informations_state,
211
  index_state,
212
  table_names_list,
213
  ui_elements["result_text"],
214
  ],
215
  ).then(
216
  show_results,
217
+ inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
218
  outputs=[
219
  ui_elements["result_text"],
220
  ui_elements["query_accordion"],
221
  ui_elements["table_accordion"],
222
  ui_elements["chart_accordion"],
 
 
 
223
  ui_elements["table_names_display"],
224
  ],
 
 
 
 
 
 
 
 
225
  )
226
 
227
+
228
  # Handle model selection change
229
  ui_elements["model_selection"].change(
230
  filter_by_model,
 
232
  outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
233
  )
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  # Handle table selection
237
+ ui_elements["table_names_display"].change(
238
  fn=on_table_click,
239
+ inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
240
+ outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], ui_elements["plot_information"], index_state],
 
 
 
 
241
  )
242
 
243
  def create_drias_tab(share_client=None, user_id=None):
front/tabs/tab_ipcc.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from random import choices
2
+ import gradio as gr
3
+ from typing import TypedDict
4
+ from climateqa.engine.talk_to_data.main import ask_ipcc
5
+ from climateqa.engine.talk_to_data.ipcc.config import IPCC_MODELS, IPCC_SCENARIO, IPCC_UI_TEXT
6
+ import uuid
7
+
8
+ class ipccUIElements(TypedDict):
9
+ tab: gr.Tab
10
+ details_accordion: gr.Accordion
11
+ examples_hidden: gr.Textbox
12
+ examples: gr.Examples
13
+ image_examples: gr.Row
14
+ ipcc_direct_question: gr.Textbox
15
+ result_text: gr.Textbox
16
+ table_names_display: gr.Radio
17
+ query_accordion: gr.Accordion
18
+ ipcc_sql_query: gr.Textbox
19
+ chart_accordion: gr.Accordion
20
+ plot_information: gr.Markdown
21
+ scenario_selection: gr.Dropdown
22
+ ipcc_display: gr.Plot
23
+ table_accordion: gr.Accordion
24
+ ipcc_table: gr.DataFrame
25
+
26
+
27
+ async def ask_ipcc_query(query: str, index_state: int, user_id: str):
28
+ result = await ask_ipcc(query, index_state, user_id)
29
+ return result
30
+
31
+ def hide_outputs():
32
+ """Hide all outputs initially."""
33
+ return (
34
+ gr.update(visible=True), # Show the result text
35
+ gr.update(visible=False), # Hide the query accordion
36
+ gr.update(visible=False), # Hide the table accordion
37
+ gr.update(visible=False), # Hide the chart accordion
38
+ gr.update(visible=False), # Hide table names
39
+ )
40
+
41
+ def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
42
+ if not sql_queries_state or not dataframes_state or not plots_state:
43
+ # If all results are empty, show "No result"
44
+ return (
45
+ gr.update(visible=True),
46
+ gr.update(visible=False),
47
+ gr.update(visible=False),
48
+ gr.update(visible=False),
49
+ gr.update(visible=False),
50
+ )
51
+ else:
52
+ # Show the appropriate components with their data
53
+ return (
54
+ gr.update(visible=False),
55
+ gr.update(visible=True),
56
+ gr.update(visible=True),
57
+ gr.update(visible=True),
58
+ gr.update(choices=table_names, value=table_names[0], visible=True),
59
+ )
60
+
61
+
62
+ def show_filter_by_scenario(table_names, index_state, dataframes):
63
+ if len(table_names) > 0 and table_names[index_state].startswith("Map"):
64
+ df = dataframes[index_state]
65
+ scenarios = sorted(df["scenario"].unique())
66
+ return gr.update(visible=True, choices=scenarios, value=scenarios[0])
67
+ else:
68
+ return gr.update(visible=False)
69
+
70
+ def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
71
+ df = dataframes[index_state]
72
+ if not table_names[index_state].startswith("Map"):
73
+ return df, figures[index_state](df)
74
+ if df.empty:
75
+ return df, None
76
+ if "scenario" not in df.columns:
77
+ return df, figures[index_state](df)
78
+ else:
79
+ df = df[df["scenario"] == scenario]
80
+ if df.empty:
81
+ return df, None
82
+ figure = figures[index_state](df)
83
+ return df, figure
84
+
85
+
86
+ def display_table_names(table_names, index_state):
87
+ return [
88
+ [name]
89
+ for name in table_names
90
+ ]
91
+
92
+ def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
93
+ index = table_names.index(selected_label)
94
+ figure = plots[index](dataframes[index])
95
+
96
+ return (
97
+ sql_queries[index],
98
+ dataframes[index],
99
+ figure,
100
+ plot_informations[index],
101
+ index,
102
+ )
103
+
104
+
105
+ def create_ipcc_ui() -> ipccUIElements:
106
+
107
+ """Create and return all UI elements for the ipcc tab."""
108
+ with gr.Tab("(Beta) Talk to IPCC", elem_id="tab-vanna", id=7) as tab:
109
+ with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
110
+ gr.Markdown(IPCC_UI_TEXT)
111
+
112
+ # Add examples for common questions
113
+ examples_hidden = gr.Textbox(visible=False, elem_id="ipcc-examples-hidden")
114
+ examples = gr.Examples(
115
+ examples=[
116
+ ["What will the temperature be like in Paris?"],
117
+ ["What will be the total rainfall in the USA in 2030?"],
118
+ ["How will the average temperature evolve in China?"],
119
+ ["What will be the average total precipitation in London ?"]
120
+ ],
121
+ label="Example Questions",
122
+ inputs=[examples_hidden],
123
+ outputs=[examples_hidden],
124
+ )
125
+
126
+ with gr.Row():
127
+ ipcc_direct_question = gr.Textbox(
128
+ label="Direct Question",
129
+ placeholder="You can write direct question here",
130
+ elem_id="direct-question",
131
+ interactive=True,
132
+ )
133
+
134
+ with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
135
+ gr.Markdown("### Examples of possible visualizations")
136
+
137
+ with gr.Row():
138
+ gr.Image("./front/assets/talk_to_ipcc_france_example.png", label="Total Precipitation in 2030 in France", elem_classes=["example-img"])
139
+ gr.Image("./front/assets/talk_to_ipcc_new_york_example.png", label="Yearly Evolution of Mean Temperature in New York (Historical + SSP Scenarios)", elem_classes=["example-img"])
140
+ gr.Image("./front/assets/talk_to_ipcc_china_example.png", label="Mean Temperature in 2050 in China", elem_classes=["example-img"])
141
+
142
+ result_text = gr.Textbox(
143
+ label="", elem_id="no-result-label", interactive=False, visible=True
144
+ )
145
+ with gr.Row():
146
+ table_names_display = gr.Radio(
147
+ choices=[],
148
+ label="Relevant figures created",
149
+ interactive=True,
150
+ elem_id="table-names",
151
+ visible=False
152
+ )
153
+
154
+ with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
155
+ ipcc_sql_query = gr.Textbox(
156
+ label="", elem_id="sql-query", interactive=False
157
+ )
158
+
159
+ with gr.Accordion(label="Chart", visible=False) as chart_accordion:
160
+
161
+ with gr.Row():
162
+ scenario_selection = gr.Dropdown(
163
+ label="Scenario", choices=IPCC_SCENARIO, value=IPCC_SCENARIO[0], interactive=True, visible=False
164
+ )
165
+
166
+ with gr.Accordion(label="Informations about the plot", open=False):
167
+ plot_information = gr.Markdown(value = "")
168
+
169
+ ipcc_display = gr.Plot(elem_id="vanna-plot")
170
+
171
+ with gr.Accordion(
172
+ label="Data used", open=False, visible=False
173
+ ) as table_accordion:
174
+ ipcc_table = gr.DataFrame([], elem_id="vanna-table")
175
+
176
+
177
+ return ipccUIElements(
178
+ tab=tab,
179
+ details_accordion=details_accordion,
180
+ examples_hidden=examples_hidden,
181
+ examples=examples,
182
+ image_examples=image_examples,
183
+ ipcc_direct_question=ipcc_direct_question,
184
+ result_text=result_text,
185
+ table_names_display=table_names_display,
186
+ query_accordion=query_accordion,
187
+ ipcc_sql_query=ipcc_sql_query,
188
+ chart_accordion=chart_accordion,
189
+ plot_information=plot_information,
190
+ scenario_selection=scenario_selection,
191
+ ipcc_display=ipcc_display,
192
+ table_accordion=table_accordion,
193
+ ipcc_table=ipcc_table,
194
+ )
195
+
196
+
197
+
198
+ def setup_ipcc_events(ui_elements: ipccUIElements, share_client=None, user_id=None) -> None:
199
+ """Set up all event handlers for the ipcc tab."""
200
+ # Create state variables
201
+ sql_queries_state = gr.State([])
202
+ dataframes_state = gr.State([])
203
+ plots_state = gr.State([])
204
+ plot_informations_state = gr.State([])
205
+ index_state = gr.State(0)
206
+ table_names_list = gr.State([])
207
+ user_id = gr.State(user_id)
208
+
209
+ # Handle direct question submission - trigger the same workflow by setting examples_hidden
210
+ ui_elements["ipcc_direct_question"].submit(
211
+ lambda x: gr.update(value=x),
212
+ inputs=[ui_elements["ipcc_direct_question"]],
213
+ outputs=[ui_elements["examples_hidden"]],
214
+ )
215
+
216
+ # Handle example selection
217
+ ui_elements["examples_hidden"].change(
218
+ lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
219
+ inputs=[ui_elements["examples_hidden"]],
220
+ outputs=[ui_elements["details_accordion"], ui_elements["ipcc_direct_question"]]
221
+ ).then(
222
+ lambda : gr.update(visible=False),
223
+ inputs=None,
224
+ outputs=ui_elements["image_examples"]
225
+ ).then(
226
+ hide_outputs,
227
+ inputs=None,
228
+ outputs=[
229
+ ui_elements["result_text"],
230
+ ui_elements["query_accordion"],
231
+ ui_elements["table_accordion"],
232
+ ui_elements["chart_accordion"],
233
+ ui_elements["table_names_display"],
234
+ ]
235
+ ).then(
236
+ ask_ipcc_query,
237
+ inputs=[ui_elements["examples_hidden"], index_state, user_id],
238
+ outputs=[
239
+ ui_elements["ipcc_sql_query"],
240
+ ui_elements["ipcc_table"],
241
+ ui_elements["ipcc_display"],
242
+ ui_elements["plot_information"],
243
+ sql_queries_state,
244
+ dataframes_state,
245
+ plots_state,
246
+ plot_informations_state,
247
+ index_state,
248
+ table_names_list,
249
+ ui_elements["result_text"],
250
+ ],
251
+ ).then(
252
+ show_results,
253
+ inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
254
+ outputs=[
255
+ ui_elements["result_text"],
256
+ ui_elements["query_accordion"],
257
+ ui_elements["table_accordion"],
258
+ ui_elements["chart_accordion"],
259
+ ui_elements["table_names_display"],
260
+ ],
261
+ ).then(
262
+ show_filter_by_scenario,
263
+ inputs=[table_names_list, index_state, dataframes_state],
264
+ outputs=[ui_elements["scenario_selection"]],
265
+ ).then(
266
+ filter_by_scenario,
267
+ inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
268
+ outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
269
+ )
270
+
271
+
272
+ # Handle model selection change
273
+ ui_elements["scenario_selection"].change(
274
+ filter_by_scenario,
275
+ inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
276
+ outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
277
+ )
278
+
279
+ # Handle table selection
280
+ ui_elements["table_names_display"].change(
281
+ fn=on_table_click,
282
+ inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
283
+ outputs=[ui_elements["ipcc_sql_query"], ui_elements["ipcc_table"], ui_elements["ipcc_display"], ui_elements["plot_information"], index_state],
284
+ ).then(
285
+ show_filter_by_scenario,
286
+ inputs=[table_names_list, index_state, dataframes_state],
287
+ outputs=[ui_elements["scenario_selection"]],
288
+ ).then(
289
+ filter_by_scenario,
290
+ inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
291
+ outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
292
+ )
293
+
294
+
295
+ def create_ipcc_tab(share_client=None, user_id=None):
296
+ """Create the ipcc tab with all its components and event handlers."""
297
+ ui_elements = create_ipcc_ui()
298
+ setup_ipcc_events(ui_elements, share_client=share_client, user_id=user_id)
299
+
300
+
requirements.txt CHANGED
@@ -25,4 +25,5 @@ geopy==2.4.1
25
  duckdb==1.2.1
26
  openai==1.61.1
27
  pydantic==2.9.2
28
- pydantic-settings==2.2.1
 
 
25
  duckdb==1.2.1
26
  openai==1.61.1
27
  pydantic==2.9.2
28
+ pydantic-settings==2.2.1
29
+ geojson==3.2.0
style.css CHANGED
@@ -656,12 +656,11 @@ a {
656
  /* overflow-y: scroll; */
657
  }
658
  #sql-query{
659
- max-height: 300px;
660
- overflow-y:scroll;
661
  }
662
 
663
  #sql-query textarea{
664
- min-height: 100px !important;
665
  }
666
 
667
  #sql-query span{
@@ -671,8 +670,11 @@ div#tab-vanna{
671
  max-height: 100¨vh;
672
  overflow-y: hidden;
673
  }
 
 
 
674
  #vanna-plot{
675
- max-height:500px
676
  }
677
 
678
  #pagination-display{
@@ -681,13 +683,33 @@ div#tab-vanna{
681
  font-size: 16px;
682
  }
683
 
684
- #table-names table{
685
- overflow: hidden;
 
 
 
 
 
 
 
 
 
 
686
  }
687
- #table-names thead{
 
 
 
 
 
688
  display: none;
689
  }
690
 
 
 
 
 
 
691
  /* DRIAS Data Table Styles */
692
  #vanna-table {
693
  height: 400px !important;
@@ -710,3 +732,13 @@ div#tab-vanna{
710
  background: white;
711
  z-index: 1;
712
  }
 
 
 
 
 
 
 
 
 
 
 
656
  /* overflow-y: scroll; */
657
  }
658
  #sql-query{
659
+ max-height: 100%;
 
660
  }
661
 
662
  #sql-query textarea{
663
+ min-height: 200px !important;
664
  }
665
 
666
  #sql-query span{
 
670
  max-height: 100¨vh;
671
  overflow-y: hidden;
672
  }
673
+ #details button span{
674
+ font-weight: bold;
675
+ }
676
  #vanna-plot{
677
+ max-height:1000px
678
  }
679
 
680
  #pagination-display{
 
683
  font-size: 16px;
684
  }
685
 
686
+
687
+ #table-names label {
688
+ display: block;
689
+ width: 100%;
690
+ box-sizing: border-box;
691
+ padding: 8px 12px;
692
+ margin-bottom: 4px;
693
+ border: 1px solid #ccc;
694
+ border-radius: 6px;
695
+ background-color: white;
696
+ cursor: pointer;
697
+ text-align: center;
698
  }
699
+
700
+ #table-names label:hover {
701
+ background-color: #f0f8ff;
702
+ }
703
+
704
+ #table-names input[type="radio"] {
705
  display: none;
706
  }
707
 
708
+ #table-names input[type="radio"]:checked + label {
709
+ background-color: #d0eaff;
710
+ border-color: #2196f3;
711
+ }
712
+
713
  /* DRIAS Data Table Styles */
714
  #vanna-table {
715
  height: 400px !important;
 
732
  background: white;
733
  z-index: 1;
734
  }
735
+
736
+ .example-img{
737
+ height: 250px;
738
+ object-fit: contain;
739
+ }
740
+
741
+ #example-img-container {
742
+ flex-direction: column;
743
+ align-items: left;
744
+ }