Commit
·
1eae86b
1
Parent(s):
1700186
feat: added drias model choice and changed TTD UI
Browse files
- app.py +26 -15
- climateqa/engine/talk_to_data/main.py +22 -3
- climateqa/engine/talk_to_data/plot.py +5 -4
- climateqa/engine/talk_to_data/sql_query.py +13 -4
- climateqa/engine/talk_to_data/utils.py +36 -0
- climateqa/engine/talk_to_data/workflow.py +3 -2
- style.css +15 -5
app.py
CHANGED
@@ -12,7 +12,7 @@ from climateqa.engine.reranker import get_reranker
|
|
12 |
from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
|
13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
15 |
-
from climateqa.engine.talk_to_data.main import ask_drias
|
16 |
from climateqa.engine.talk_to_data.myVanna import MyVanna
|
17 |
|
18 |
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
|
@@ -87,8 +87,8 @@ vn.connect_to_sqlite(db_vanna_path)
|
|
87 |
# def ask_vanna_query(query):
|
88 |
# return ask_vanna(vn, db_vanna_path, query)
|
89 |
|
90 |
-
def ask_drias_query(query, index_state):
|
91 |
-
return ask_drias(db_vanna_path, query, index_state)
|
92 |
|
93 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
94 |
print("chat cqa - message received")
|
@@ -139,27 +139,40 @@ def update_sources_number_display(sources_textbox, figures_cards, current_graphs
|
|
139 |
|
140 |
# vanna_display = gr.Plot()
|
141 |
# vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
|
|
|
142 |
def create_drias_tab():
|
143 |
with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
|
144 |
-
|
|
|
|
|
145 |
|
146 |
-
with gr.Accordion("
|
147 |
-
drias_sql_query = gr.Textbox(label="
|
|
|
|
|
148 |
drias_table = gr.DataFrame([], elem_id="vanna-table")
|
149 |
-
drias_display = gr.Plot()
|
150 |
|
151 |
-
|
|
|
|
|
|
|
152 |
prev_button = gr.Button("Previous")
|
153 |
next_button = gr.Button("Next")
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
|
160 |
drias_direct_question.submit(
|
161 |
ask_drias_query,
|
162 |
-
inputs=[drias_direct_question, index_state],
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state]
|
164 |
)
|
165 |
|
@@ -184,8 +197,6 @@ def create_drias_tab():
|
|
184 |
inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
|
185 |
outputs=[drias_sql_query, drias_table, drias_display, index_state]
|
186 |
)
|
187 |
-
|
188 |
-
|
189 |
|
190 |
# # UI Layout Components
|
191 |
def cqa_tab(tab_name):
|
|
|
12 |
from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
|
13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
15 |
+
from climateqa.engine.talk_to_data.main import ask_drias, DRIAS_MODELS
|
16 |
from climateqa.engine.talk_to_data.myVanna import MyVanna
|
17 |
|
18 |
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
|
|
|
87 |
# def ask_vanna_query(query):
|
88 |
# return ask_vanna(vn, db_vanna_path, query)
|
89 |
|
90 |
+
def ask_drias_query(query: str, index_state: int, drias_model: str):
|
91 |
+
return ask_drias(db_vanna_path, query, index_state, drias_model)
|
92 |
|
93 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
94 |
print("chat cqa - message received")
|
|
|
139 |
|
140 |
# vanna_display = gr.Plot()
|
141 |
# vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
|
142 |
+
|
143 |
def create_drias_tab():
|
144 |
with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
|
145 |
+
with gr.Row():
|
146 |
+
drias_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here", elem_id="direct-question", interactive=True)
|
147 |
+
model_selection = gr.Dropdown(label="Model", choices=DRIAS_MODELS ,elem_id="drias-model", value="ALL", interactive=True)
|
148 |
|
149 |
+
with gr.Accordion(label="SQL Query Used"):
|
150 |
+
drias_sql_query = gr.Textbox(label="", elem_id="sql-query", interactive=False)
|
151 |
+
|
152 |
+
with gr.Accordion(label='Data used', open=False):
|
153 |
drias_table = gr.DataFrame([], elem_id="vanna-table")
|
|
|
154 |
|
155 |
+
with gr.Accordion(label="Chart"):
|
156 |
+
drias_display = gr.Plot(elem_id="vanna-plot")
|
157 |
+
|
158 |
+
with gr.Row():
|
159 |
prev_button = gr.Button("Previous")
|
160 |
next_button = gr.Button("Next")
|
161 |
|
162 |
+
sql_queries_state = gr.State([])
|
163 |
+
dataframes_state = gr.State([])
|
164 |
+
plots_state = gr.State([])
|
165 |
+
index_state = gr.State(0)
|
166 |
|
167 |
drias_direct_question.submit(
|
168 |
ask_drias_query,
|
169 |
+
inputs=[drias_direct_question, index_state, model_selection],
|
170 |
+
outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state]
|
171 |
+
)
|
172 |
+
|
173 |
+
model_selection.change(
|
174 |
+
ask_drias_query,
|
175 |
+
inputs=[drias_direct_question, index_state, model_selection],
|
176 |
outputs=[drias_sql_query, drias_table, drias_display, sql_queries_state, dataframes_state, plots_state, index_state]
|
177 |
)
|
178 |
|
|
|
197 |
inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
|
198 |
outputs=[drias_sql_query, drias_table, drias_display, index_state]
|
199 |
)
|
|
|
|
|
200 |
|
201 |
# # UI Layout Components
|
202 |
def cqa_tab(tab_name):
|
climateqa/engine/talk_to_data/main.py
CHANGED
@@ -13,13 +13,12 @@ def ask_llm_column_names(sql_query, llm):
|
|
13 |
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
14 |
return columns_list
|
15 |
|
16 |
-
def ask_drias(db_drias_path:str, query:str
|
17 |
-
final_state = drias_workflow(db_drias_path, query)
|
18 |
sql_queries = []
|
19 |
result_dataframes = []
|
20 |
figures = []
|
21 |
|
22 |
-
|
23 |
for plot_state in final_state['plot_states'].values():
|
24 |
for table_state in plot_state['table_states'].values():
|
25 |
if table_state['status'] == 'OK':
|
@@ -30,9 +29,29 @@ def ask_drias(db_drias_path:str, query:str , index_state: int):
|
|
30 |
result_dataframes.append(table_state['dataframe'])
|
31 |
if 'figure' in table_state and table_state['figure'] is not None:
|
32 |
figures.append(table_state['figure'](table_state['dataframe']))
|
|
|
33 |
|
34 |
return sql_queries[index_state], result_dataframes[index_state], figures[index_state], sql_queries, result_dataframes, figures, index_state
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
# def ask_vanna(vn,db_vanna_path, query):
|
37 |
|
38 |
# try :
|
|
|
13 |
columns_list = ast.literal_eval(columns.strip("```python\n").strip())
|
14 |
return columns_list
|
15 |
|
16 |
+
def ask_drias(db_drias_path:str, query:str, index_state: int = 0, drias_model: str = "ALL"):
|
17 |
+
final_state = drias_workflow(db_drias_path, query, drias_model)
|
18 |
sql_queries = []
|
19 |
result_dataframes = []
|
20 |
figures = []
|
21 |
|
|
|
22 |
for plot_state in final_state['plot_states'].values():
|
23 |
for table_state in plot_state['table_states'].values():
|
24 |
if table_state['status'] == 'OK':
|
|
|
29 |
result_dataframes.append(table_state['dataframe'])
|
30 |
if 'figure' in table_state and table_state['figure'] is not None:
|
31 |
figures.append(table_state['figure'](table_state['dataframe']))
|
32 |
+
|
33 |
|
34 |
return sql_queries[index_state], result_dataframes[index_state], figures[index_state], sql_queries, result_dataframes, figures, index_state
|
35 |
|
36 |
+
DRIAS_MODELS = [
|
37 |
+
'ALL',
|
38 |
+
'RegCM4-6_MPI-ESM-LR',
|
39 |
+
'RACMO22E_EC-EARTH',
|
40 |
+
'RegCM4-6_HadGEM2-ES',
|
41 |
+
'HadREM3-GA7_EC-EARTH',
|
42 |
+
'HadREM3-GA7_CNRM-CM5',
|
43 |
+
'REMO2015_NorESM1-M',
|
44 |
+
'SMHI-RCA4_EC-EARTH',
|
45 |
+
'WRF381P_NorESM1-M',
|
46 |
+
'ALADIN63_CNRM-CM5',
|
47 |
+
'CCLM4-8-17_MPI-ESM-LR',
|
48 |
+
'HIRHAM5_IPSL-CM5A-MR',
|
49 |
+
'HadREM3-GA7_HadGEM2-ES',
|
50 |
+
'SMHI-RCA4_IPSL-CM5A-MR',
|
51 |
+
'HIRHAM5_NorESM1-M',
|
52 |
+
'REMO2009_MPI-ESM-LR',
|
53 |
+
'CCLM4-8-17_HadGEM2-ES'
|
54 |
+
]
|
55 |
# def ask_vanna(vn,db_vanna_path, query):
|
56 |
|
57 |
# try :
|
climateqa/engine/talk_to_data/plot.py
CHANGED
@@ -53,7 +53,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
|
53 |
# Compute the 10-year rolling average
|
54 |
sliding_averages = (
|
55 |
df_avg[indicator]
|
56 |
-
.rolling(window=10, min_periods=
|
57 |
.mean()
|
58 |
.astype(float)
|
59 |
.tolist()
|
@@ -68,7 +68,7 @@ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
|
68 |
# Compute the 10-year rolling average
|
69 |
sliding_averages = (
|
70 |
df_model[indicator]
|
71 |
-
.rolling(window=10, min_periods=
|
72 |
.mean()
|
73 |
.astype(float)
|
74 |
.tolist()
|
@@ -241,6 +241,7 @@ def plot_distribution_of_indicator_for_given_year(
|
|
241 |
yaxis_title="Frequency",
|
242 |
plot_bgcolor="rgba(0, 0, 0, 0)",
|
243 |
showlegend=False,
|
|
|
244 |
)
|
245 |
|
246 |
return fig
|
@@ -313,8 +314,8 @@ def plot_map_of_france_of_indicator_for_given_year(
|
|
313 |
mapbox_style="open-street-map", # Use OpenStreetMap
|
314 |
mapbox_zoom=3,
|
315 |
mapbox_center={"lat": 46.6, "lon": 2.0},
|
316 |
-
coloraxis_colorbar=dict(title=f"{indicator_label}"), # Add legend
|
317 |
-
title=f"{indicator_label} in {year} in France", # Title
|
318 |
)
|
319 |
return fig
|
320 |
|
|
|
53 |
# Compute the 10-year rolling average
|
54 |
sliding_averages = (
|
55 |
df_avg[indicator]
|
56 |
+
.rolling(window=10, min_periods=1)
|
57 |
.mean()
|
58 |
.astype(float)
|
59 |
.tolist()
|
|
|
68 |
# Compute the 10-year rolling average
|
69 |
sliding_averages = (
|
70 |
df_model[indicator]
|
71 |
+
.rolling(window=10, min_periods=1)
|
72 |
.mean()
|
73 |
.astype(float)
|
74 |
.tolist()
|
|
|
241 |
yaxis_title="Frequency",
|
242 |
plot_bgcolor="rgba(0, 0, 0, 0)",
|
243 |
showlegend=False,
|
244 |
+
pan=False
|
245 |
)
|
246 |
|
247 |
return fig
|
|
|
314 |
mapbox_style="open-street-map", # Use OpenStreetMap
|
315 |
mapbox_zoom=3,
|
316 |
mapbox_center={"lat": 46.6, "lon": 2.0},
|
317 |
+
coloraxis_colorbar=dict(title=f"{indicator_label} {'(Model Average)' if model == 'ALL' else '(Model : ' + model + ')'}"), # Add legend
|
318 |
+
title=f"{indicator_label} in {year} in France ", # Title
|
319 |
)
|
320 |
return fig
|
321 |
|
climateqa/engine/talk_to_data/sql_query.py
CHANGED
@@ -60,10 +60,16 @@ def indicator_per_year_at_location_query(
|
|
60 |
indicator_column = params.get("indicator_column")
|
61 |
latitude = params.get("latitude")
|
62 |
longitude = params.get("longitude")
|
|
|
63 |
|
64 |
if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
|
65 |
return ""
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
67 |
return sql_query
|
68 |
|
69 |
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
|
@@ -85,9 +91,12 @@ def indicator_for_given_year_query(
|
|
85 |
"""
|
86 |
indicator_column = params.get("indicator_column")
|
87 |
year = params.get('year')
|
88 |
-
|
89 |
if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
|
90 |
return ""
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
93 |
return sql_query
|
|
|
60 |
indicator_column = params.get("indicator_column")
|
61 |
latitude = params.get("latitude")
|
62 |
longitude = params.get("longitude")
|
63 |
+
model = params.get('model')
|
64 |
|
65 |
if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
|
66 |
return ""
|
67 |
+
|
68 |
+
if model == 'ALL':
|
69 |
+
sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nand longitude={longitude} \nOrder by Year"
|
70 |
+
else:
|
71 |
+
sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nand longitude={longitude} \nand model='{model}' \nOrder by Year"
|
72 |
+
|
73 |
return sql_query
|
74 |
|
75 |
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
|
|
|
91 |
"""
|
92 |
indicator_column = params.get("indicator_column")
|
93 |
year = params.get('year')
|
94 |
+
model = params.get('model')
|
95 |
if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
|
96 |
return ""
|
97 |
+
|
98 |
+
if model == 'ALL':
|
99 |
+
sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
|
100 |
+
else:
|
101 |
+
sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}\nand model = '{model}'"
|
102 |
return sql_query
|
climateqa/engine/talk_to_data/utils.py
CHANGED
@@ -27,6 +27,31 @@ def detect_location_with_openai(sentence):
|
|
27 |
return location_list[0]
|
28 |
else:
|
29 |
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
def detectTable(sql_query):
|
@@ -65,6 +90,17 @@ def nearestNeighbourSQL(db: str, location: tuple, table: str):
|
|
65 |
|
66 |
|
67 |
def detect_relevant_tables(db: str, user_question: str, plot: Plot, llm) -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
conn = sqlite3.connect(db)
|
69 |
cursor = conn.cursor()
|
70 |
|
|
|
27 |
return location_list[0]
|
28 |
else:
|
29 |
return ""
|
30 |
+
|
31 |
+
def detect_year_with_openai(sentence: str):
|
32 |
+
"""
|
33 |
+
Detects years in a sentence using OpenAI's API via LangChain.
|
34 |
+
"""
|
35 |
+
llm = get_llm()
|
36 |
+
|
37 |
+
prompt = f"""
|
38 |
+
Extract all years mentioned in the following sentence.
|
39 |
+
Return the result as a Python list. If no year are mentioned, return an empty list.
|
40 |
+
|
41 |
+
Sentence: "{sentence}"
|
42 |
+
"""
|
43 |
+
|
44 |
+
response = llm.invoke(prompt)
|
45 |
+
if response is None:
|
46 |
+
return None
|
47 |
+
response_split = response.content.strip("```python\n").split('=')
|
48 |
+
years_list = []
|
49 |
+
if len(response_split) > 1:
|
50 |
+
years_list = ast.literal_eval(response_split[1])
|
51 |
+
if years_list and len(years_list) > 0:
|
52 |
+
return years_list[0]
|
53 |
+
else:
|
54 |
+
return None
|
55 |
|
56 |
|
57 |
def detectTable(sql_query):
|
|
|
90 |
|
91 |
|
92 |
def detect_relevant_tables(db: str, user_question: str, plot: Plot, llm) -> list[str]:
|
93 |
+
"""Detect relevant tables regarding the plot and the user input
|
94 |
+
|
95 |
+
Args:
|
96 |
+
db (str): database path
|
97 |
+
user_question (str): initial user input
|
98 |
+
plot (Plot): plot object for which the relevant tables are detected
|
99 |
+
llm (_type_): LLM
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
list[str]: list of table names
|
103 |
+
"""
|
104 |
conn = sqlite3.connect(db)
|
105 |
cursor = conn.cursor()
|
106 |
|
climateqa/engine/talk_to_data/workflow.py
CHANGED
@@ -38,7 +38,7 @@ class State(TypedDict):
|
|
38 |
plots: list[str]
|
39 |
plot_states: dict[str, PlotState]
|
40 |
|
41 |
-
def drias_workflow(db_drias_path: str, user_input: str) -> State:
|
42 |
"""Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
|
43 |
|
44 |
Args:
|
@@ -87,7 +87,7 @@ def drias_workflow(db_drias_path: str, user_input: str) -> State:
|
|
87 |
'status': 'OK'
|
88 |
}
|
89 |
table_state['params'] = {
|
90 |
-
'model':
|
91 |
}
|
92 |
for param_name in plot['params']:
|
93 |
param = find_param(state, param_name, table, db_drias_path)
|
@@ -99,6 +99,7 @@ def drias_workflow(db_drias_path: str, user_input: str) -> State:
|
|
99 |
if sql_query == "":
|
100 |
table_state['status'] = 'ERROR'
|
101 |
continue
|
|
|
102 |
|
103 |
table_state['sql_query'] = sql_query
|
104 |
results = execute_sql_query(db_drias_path, sql_query)
|
|
|
38 |
plots: list[str]
|
39 |
plot_states: dict[str, PlotState]
|
40 |
|
41 |
+
def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State:
|
42 |
"""Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
|
43 |
|
44 |
Args:
|
|
|
87 |
'status': 'OK'
|
88 |
}
|
89 |
table_state['params'] = {
|
90 |
+
'model': model
|
91 |
}
|
92 |
for param_name in plot['params']:
|
93 |
param = find_param(state, param_name, table, db_drias_path)
|
|
|
99 |
if sql_query == "":
|
100 |
table_state['status'] = 'ERROR'
|
101 |
continue
|
102 |
+
print(sql_query)
|
103 |
|
104 |
table_state['sql_query'] = sql_query
|
105 |
results = execute_sql_query(db_drias_path, sql_query)
|
style.css
CHANGED
@@ -487,7 +487,6 @@ a {
|
|
487 |
height: calc(100vh - 190px) !important;
|
488 |
overflow-y: scroll !important;
|
489 |
}
|
490 |
-
div#tab-vanna,
|
491 |
div#sources-figures,
|
492 |
div#graphs-container,
|
493 |
div#tab-citations {
|
@@ -607,14 +606,25 @@ a {
|
|
607 |
}
|
608 |
|
609 |
#vanna-display {
|
610 |
-
max-height:
|
611 |
/* overflow-y: scroll; */
|
612 |
}
|
613 |
#sql-query{
|
614 |
-
max-height:
|
615 |
overflow-y:scroll;
|
616 |
}
|
617 |
-
|
618 |
-
|
|
|
|
|
|
|
|
|
619 |
overflow-y:scroll;
|
|
|
|
|
|
|
620 |
}
|
|
|
|
|
|
|
|
|
|
487 |
height: calc(100vh - 190px) !important;
|
488 |
overflow-y: scroll !important;
|
489 |
}
|
|
|
490 |
div#sources-figures,
|
491 |
div#graphs-container,
|
492 |
div#tab-citations {
|
|
|
606 |
}
|
607 |
|
608 |
#vanna-display {
|
609 |
+
max-height: 200px;
|
610 |
/* overflow-y: scroll; */
|
611 |
}
|
612 |
#sql-query{
|
613 |
+
max-height: 300px;
|
614 |
overflow-y:scroll;
|
615 |
}
|
616 |
+
|
617 |
+
#sql-query span{
|
618 |
+
display: none;
|
619 |
+
}
|
620 |
+
div#tab-vanna{
|
621 |
+
max-height: 100vh;
|
622 |
overflow-y:scroll;
|
623 |
+
}
|
624 |
+
#vanna-plot{
|
625 |
+
max-height:500px
|
626 |
}
|
627 |
+
|
628 |
+
#drias-model{
|
629 |
+
max-width: 25%;
|
630 |
+
}
|