timeki and armanddemasson committed
Commit 47a68b0 · Parent: 8c7a7fe

feature/drias_parallelization (#25)


- log to huggingface (f9c4c84a71d320c6db05ee099b9e35492ba7b184)
- Merged in feat/logs_on_huggingface (pull request #5) (261632833e7d939c321f208edfd6576e68947b4b)
- feat: added multithreading to run sql queries in talk to drias (705ccece7775c65a9c7b73b091cdddbc4246f2e7)
- chore: remove prints in talk to drias workflow (a967134f90c70d87bb3d786f2feb81f2e56fdb9f)
- Merged in feat/improve_drias_exeuction_time (pull request #6) (7c38528636ac19c1efa240a122e523ed0c34706a)
- fix import (05b8df9c9b74926459da70797b6852ff07a4d838)
- Merge branch 'main' into dev (8fb231c8beabf8a6406f05cf4cac564c5d81c7ce)
- Merged in dev (pull request #7) (6b9f71b1cf216eef0fd7973412f675d8633a5f4a)
- fix import (b35df2a8160723e43f74a040aa94983069066213)
- Merge branch 'main' of https://bitbucket.org/ekimetrics/climate_qa (f96cfd0715ec2b1ed7a78775ea7f8722f5793d8f)


Co-authored-by: Armand Demasson <[email protected]>

climateqa/chat.py CHANGED
@@ -12,15 +12,11 @@ from .handle_stream_events import (
     convert_to_docs_to_html,
     stream_answer,
    handle_retrieved_owid_graphs,
-    serialize_docs,
 )
-
-# Function to log data on Azure
-def log_on_azure(file, logs, share_client):
-    logs = json.dumps(logs)
-    file_client = share_client.get_file_client(file)
-    file_client.upload_file(logs)
-
+from .logging import (
+    log_interaction_to_huggingface
+)
+
 # Chat functions
 def start_chat(query, history, search_only):
     history = history + [ChatMessage(role="user", content=query)]
@@ -32,28 +28,6 @@ def start_chat(query, history, search_only):
 def finish_chat():
     return gr.update(interactive=True, value="")
 
-def log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id):
-    try:
-        # Log interaction to Azure if not in local environment
-        if os.getenv("GRADIO_ENV") != "local":
-            timestamp = str(datetime.now().timestamp())
-            prompt = history[1]["content"]
-            logs = {
-                "user_id": str(user_id),
-                "prompt": prompt,
-                "query": prompt,
-                "question": output_query,
-                "sources": sources,
-                "docs": serialize_docs(docs),
-                "answer": history[-1].content,
-                "time": timestamp,
-            }
-            log_on_azure(f"{timestamp}.json", logs, share_client)
-    except Exception as e:
-        print(f"Error logging on Azure Blob Storage: {e}")
-        error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
-        raise gr.Error(error_msg)
-
 def handle_numerical_data(event):
     if event["name"] == "retrieve_drias_data" and event["event"] == "on_chain_end":
         numerical_data = event["data"]["output"]["drias_data"]
@@ -61,27 +35,6 @@ def handle_numerical_data(event):
         return numerical_data, sql_query
     return None, None
 
-def log_drias_interaction_to_azure(query, sql_query, data, share_client, user_id):
-    try:
-        # Log interaction to Azure if not in local environment
-        if os.getenv("GRADIO_ENV") != "local":
-            timestamp = str(datetime.now().timestamp())
-            logs = {
-                "user_id": str(user_id),
-                "query": query,
-                "sql_query": sql_query,
-                # "data": data.to_dict() if data is not None else None,
-                "time": timestamp,
-            }
-            log_on_azure(f"drias_{timestamp}.json", logs, share_client)
-            print(f"Logged Drias interaction to Azure Blob Storage: {logs}")
-        else:
-            print("share_client or user_id is None, or GRADIO_ENV is local")
-    except Exception as e:
-        print(f"Error logging Drias interaction on Azure Blob Storage: {e}")
-        error_msg = f"Drias Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
-        raise gr.Error(error_msg)
-
 # Main chat function
 async def chat_stream(
     agent : CompiledStateGraph,
@@ -235,9 +188,7 @@ async def chat_stream(
         print(f"Event {event} has failed")
         raise gr.Error(str(e))
 
-
-
     # Call the function to log interaction
-    log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
+    log_interaction_to_huggingface(history, output_query, sources, docs, share_client, user_id)
 
     yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
climateqa/engine/talk_to_data/main.py CHANGED
@@ -1,5 +1,6 @@
-from climateqa.engine.talk_to_data.workflow import drias_workflow
+from climateqa.engine.talk_to_data.talk_to_drias import drias_workflow
 from climateqa.engine.llm import get_llm
+from climateqa.logging import log_drias_interaction_to_huggingface
 import ast
 
 llm = get_llm(provider="openai")
@@ -37,7 +38,7 @@ def ask_llm_column_names(sql_query: str, llm) -> list[str]:
     columns_list = ast.literal_eval(columns.strip("```python\n").strip())
     return columns_list
 
-async def ask_drias(query: str, index_state: int = 0) -> tuple:
+async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tuple:
     """Main function to process a DRIAS query and return results.
 
     This function orchestrates the DRIAS workflow, processing a user query to generate
@@ -85,6 +86,8 @@ async def ask_drias(query: str, index_state: int = 0) -> tuple:
     sql_query = sql_queries[index_state]
     dataframe = result_dataframes[index_state]
     figure = figures[index_state](dataframe)
+
+    log_drias_interaction_to_huggingface(query, sql_query, user_id)
 
     return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
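For reference, a minimal call-site sketch for the new `ask_drias` signature. The query string and user id are placeholders; actually running it requires the DRIAS database and an OpenAI key to be configured, and `user_id` is optional, feeding only the logging call shown in the diff above.

```python
import asyncio
from climateqa.engine.talk_to_data.main import ask_drias

async def demo():
    # Placeholder inputs, for illustration only.
    result = await ask_drias("How will temperatures evolve in Paris?", user_id="demo-user")
    sql_query, dataframe, figure, *rest = result
    print(sql_query)

asyncio.run(demo())
```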
climateqa/engine/talk_to_data/sql_query.py CHANGED
@@ -22,9 +22,10 @@ async def execute_sql_query(sql_query: str) -> pd.DataFrame:
     """
     def _execute_query():
         # Execute the query
-        results = duckdb.sql(sql_query)
+        con = duckdb.connect()
+        results = con.sql(sql_query).fetchdf()
         # return fetched data
-        return results.fetchdf()
+        return results
 
     # Run the query in a thread pool to avoid blocking
     loop = asyncio.get_event_loop()
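A note on the `duckdb.connect()` change above: the module-level `duckdb.sql()` runs every query on a single shared default connection, which is risky once `_execute_query` runs concurrently from the thread pool; opening a short-lived connection per call avoids that contention. A self-contained sketch of the resulting pattern (the `SELECT 1` query is made up for illustration):

```python
import asyncio
import duckdb
import pandas as pd

async def execute_sql_query(sql_query: str) -> pd.DataFrame:
    """Run a DuckDB query without blocking the event loop."""
    def _execute_query():
        # One connection per call: duckdb.sql() would share a single
        # default connection across threads.
        con = duckdb.connect()
        return con.sql(sql_query).fetchdf()

    # Offload to the default thread pool so concurrent queries can overlap.
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, _execute_query)

async def main():
    dfs = await asyncio.gather(*[execute_sql_query("SELECT 1 AS x") for _ in range(3)])
    print(len(dfs), "queries completed")

asyncio.run(main())
```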
climateqa/engine/talk_to_data/{workflow.py → talk_to_drias.py} RENAMED
@@ -1,10 +1,12 @@
 import os
 
 from typing import Any, Callable, TypedDict, Optional
+from numpy import sort
 import pandas as pd
-
+import asyncio
 from plotly.graph_objects import Figure
 from climateqa.engine.llm import get_llm
+from climateqa.engine.talk_to_data import sql_query
 from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
 from climateqa.engine.talk_to_data.plot import PLOTS, Plot
 from climateqa.engine.talk_to_data.sql_query import execute_sql_query
@@ -17,6 +19,7 @@ from climateqa.engine.talk_to_data.utils import (
     detect_relevant_tables,
 )
 
+
 ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
 
 class TableState(TypedDict):
@@ -61,101 +64,6 @@ class State(TypedDict):
     plot_states: dict[str, PlotState]
     error: Optional[str]
 
-async def drias_workflow(user_input: str) -> State:
-    """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
-
-    Args:
-        user_input (str): initial user input
-
-    Returns:
-        State: Final state with all the results
-    """
-    state: State = {
-        'user_input': user_input,
-        'plots': [],
-        'plot_states': {}
-    }
-
-    llm = get_llm(provider="openai")
-
-    plots = await find_relevant_plots(state, llm)
-    state['plots'] = plots
-
-    if not state['plots']:
-        state['error'] = 'There is no plot to answer to the question'
-        return state
-
-    have_relevant_table = False
-    have_sql_query = False
-    have_dataframe = False
-    for plot_name in state['plots']:
-
-        plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
-        if plot is None:
-            continue
-
-        plot_state: PlotState = {
-            'plot_name': plot_name,
-            'tables': [],
-            'table_states': {}
-        }
-
-        plot_state['plot_name'] = plot_name
-
-        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
-        if len(relevant_tables) > 0 :
-            have_relevant_table = True
-
-        plot_state['tables'] = relevant_tables
-
-        params = {}
-        for param_name in plot['params']:
-            param = await find_param(state, param_name, relevant_tables[0])
-            if param:
-                params.update(param)
-
-        for n, table in enumerate(plot_state['tables']):
-            if n > 2:
-                break
-
-            table_state: TableState = {
-                'table_name': table,
-                'params': params,
-                'status': 'OK'
-            }
-
-            table_state["params"]['indicator_column'] = find_indicator_column(table)
-
-            sql_query = plot['sql_query'](table, table_state['params'])
-
-            if sql_query == "":
-                table_state['status'] = 'ERROR'
-                continue
-            else :
-                have_sql_query = True
-
-            table_state['sql_query'] = sql_query
-            df = await execute_sql_query(sql_query)
-
-            if len(df) > 0:
-                have_dataframe = True
-
-            figure = plot['plot_function'](table_state['params'])
-            table_state['dataframe'] = df
-            table_state['figure'] = figure
-            plot_state['table_states'][table] = table_state
-
-        state['plot_states'][plot_name] = plot_state
-
-    if not have_relevant_table:
-        state['error'] = "There is no relevant table in the our database to answer your question"
-    elif not have_sql_query:
-        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
-    elif not have_dataframe:
-        state['error'] = "There is no data in our table that can answer to your question"
-
-    return state
-
 async def find_relevant_plots(state: State, llm) -> list[str]:
     print("---- Find relevant plots ----")
     relevant_plots = await detect_relevant_plots(state['user_input'], llm)
@@ -238,6 +146,128 @@ def find_indicator_column(table: str) -> str:
     return INDICATOR_COLUMNS_PER_TABLE[table]
 
 
+async def process_table(
+    table: str,
+    params: dict[str, Any],
+    plot: Plot,
+) -> TableState:
+    """Processes a table to extract relevant data and generate visualizations.
+
+    This function retrieves the SQL query for the specified table, executes it,
+    and generates a visualization based on the results.
+
+    Args:
+        table (str): The name of the table to process
+        params (dict[str, Any]): Parameters used for querying the table
+        plot (Plot): The plot object containing SQL query and visualization function
+
+    Returns:
+        TableState: The state of the processed table
+    """
+    table_state: TableState = {
+        'table_name': table,
+        'params': params.copy(),
+        'status': 'OK',
+        'dataframe': None,
+        'sql_query': None,
+        'figure': None
+    }
+
+    table_state['params']['indicator_column'] = find_indicator_column(table)
+    sql_query = plot['sql_query'](table, table_state['params'])
+
+    if sql_query == "":
+        table_state['status'] = 'ERROR'
+        return table_state
+    table_state['sql_query'] = sql_query
+    df = await execute_sql_query(sql_query)
+
+    table_state['dataframe'] = df
+    table_state['figure'] = plot['plot_function'](table_state['params'])
+
+    return table_state
+
+async def drias_workflow(user_input: str) -> State:
+    """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
+
+    Args:
+        user_input (str): initial user input
+
+    Returns:
+        State: Final state with all the results
+    """
+    state: State = {
+        'user_input': user_input,
+        'plots': [],
+        'plot_states': {},
+        'error': ''
+    }
+
+    llm = get_llm(provider="openai")
+
+    plots = await find_relevant_plots(state, llm)
+
+    state['plots'] = plots
+
+    if len(state['plots']) < 1:
+        state['error'] = 'There is no plot to answer to the question'
+        return state
+
+    have_relevant_table = False
+    have_sql_query = False
+    have_dataframe = False
+
+    for plot_name in state['plots']:
+
+        plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
+        if plot is None:
+            continue
+
+        plot_state: PlotState = {
+            'plot_name': plot_name,
+            'tables': [],
+            'table_states': {}
+        }
+
+        plot_state['plot_name'] = plot_name
+
+        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
+
+        if len(relevant_tables) > 0 :
+            have_relevant_table = True
+
+        plot_state['tables'] = relevant_tables
+
+        params = {}
+        for param_name in plot['params']:
+            param = await find_param(state, param_name, relevant_tables[0])
+            if param:
+                params.update(param)
+
+        tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
+        results = await asyncio.gather(*tasks)
+
+        # Store results back in plot_state
+        have_dataframe = False
+        have_sql_query = False
+        for table_state in results:
+            if table_state['sql_query']:
+                have_sql_query = True
+            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
+                have_dataframe = True
+            plot_state['table_states'][table_state['table_name']] = table_state
+
+        state['plot_states'][plot_name] = plot_state
+
+    if not have_relevant_table:
+        state['error'] = "There is no relevant table in our database to answer your question"
+    elif not have_sql_query:
+        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
+    elif not have_dataframe:
+        state['error'] = "There is no data in our table that can answer to your question"
+
+    return state
+
 # def make_write_query_node():
 
 #     def write_query(state):
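The parallelization itself is the `asyncio.gather` over `process_table` calls above: each table's SQL query is offloaded to the thread pool, so up to three tables are processed concurrently instead of one after another. A stripped-down sketch of that fan-out pattern (table names and sleep times are illustrative stand-ins, not the real workflow):

```python
import asyncio

async def process_table(table: str) -> dict:
    # Stand-in for the real process_table: simulate I/O-bound work
    # (SQL execution) that can overlap across tables.
    await asyncio.sleep(0.1)
    return {"table_name": table, "status": "OK"}

async def demo():
    tables = ["temperature", "precipitation", "wind", "humidity"]
    # As in drias_workflow: cap at three tables, run them concurrently.
    results = await asyncio.gather(*[process_table(t) for t in tables[:3]])
    for table_state in results:   # ~0.1s total instead of ~0.3s sequentially
        print(table_state)

asyncio.run(demo())
```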
climateqa/handle_stream_events.py CHANGED
@@ -1,7 +1,7 @@
 from langchain_core.runnables.schema import StreamEvent
 from gradio import ChatMessage
 from climateqa.engine.chains.prompts import audience_prompts
-from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
+from front.utils import make_html_source,parse_output_llm_with_sources
 import numpy as np
 
 def init_audience(audience :str) -> str:
climateqa/logging.py ADDED
@@ -0,0 +1,194 @@
+import os
+from datetime import datetime
+import json
+from huggingface_hub import HfApi
+import gradio as gr
+import csv
+
+def serialize_docs(docs:list)->list:
+    new_docs = []
+    for doc in docs:
+        new_doc = {}
+        new_doc["page_content"] = doc.page_content
+        new_doc["metadata"] = doc.metadata
+        new_docs.append(new_doc)
+    return new_docs
+
+## AZURE LOGGING - DEPRECATED
+
+# def log_on_azure(file, logs, share_client):
+#     """Log data to Azure Blob Storage.
+
+#     Args:
+#         file (str): Name of the file to store logs
+#         logs (dict): Log data to store
+#         share_client: Azure share client instance
+#     """
+#     logs = json.dumps(logs)
+#     file_client = share_client.get_file_client(file)
+#     file_client.upload_file(logs)
+
+
+# def log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id):
+#     """Log chat interaction to Azure and Hugging Face.
+
+#     Args:
+#         history (list): Chat message history
+#         output_query (str): Processed query
+#         sources (list): Knowledge base sources used
+#         docs (list): Retrieved documents
+#         share_client: Azure share client instance
+#         user_id (str): User identifier
+#     """
+#     try:
+#         # Log interaction to Azure if not in local environment
+#         if os.getenv("GRADIO_ENV") != "local":
+#             timestamp = str(datetime.now().timestamp())
+#             prompt = history[1]["content"]
+#             logs = {
+#                 "user_id": str(user_id),
+#                 "prompt": prompt,
+#                 "query": prompt,
+#                 "question": output_query,
+#                 "sources": sources,
+#                 "docs": serialize_docs(docs),
+#                 "answer": history[-1].content,
+#                 "time": timestamp,
+#             }
+#             # Log to Azure
+#             log_on_azure(f"{timestamp}.json", logs, share_client)
+#     except Exception as e:
+#         print(f"Error logging on Azure Blob Storage: {e}")
+#         error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
+#         raise gr.Error(error_msg)
+
+# def log_drias_interaction_to_azure(query, sql_query, data, share_client, user_id):
+#     """Log Drias data interaction to Azure and Hugging Face.
+
+#     Args:
+#         query (str): User query
+#         sql_query (str): SQL query used
+#         data: Retrieved data
+#         share_client: Azure share client instance
+#         user_id (str): User identifier
+#     """
+#     try:
+#         # Log interaction to Azure if not in local environment
+#         if os.getenv("GRADIO_ENV") != "local":
+#             timestamp = str(datetime.now().timestamp())
+#             logs = {
+#                 "user_id": str(user_id),
+#                 "query": query,
+#                 "sql_query": sql_query,
+#                 "time": timestamp,
+#             }
+#             log_on_azure(f"drias_{timestamp}.json", logs, share_client)
+#             print(f"Logged Drias interaction to Azure Blob Storage: {logs}")
+#         else:
+#             print("share_client or user_id is None, or GRADIO_ENV is local")
+#     except Exception as e:
+#         print(f"Error logging Drias interaction on Azure Blob Storage: {e}")
+#         error_msg = f"Drias Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
+#         raise gr.Error(error_msg)
+
+## HUGGING FACE LOGGING
+
+def log_on_huggingface(log_filename, logs):
+    """Log data to Hugging Face dataset repository.
+
+    Args:
+        log_filename (str): Name of the file to store logs
+        logs (dict): Log data to store
+    """
+    try:
+        # Get Hugging Face token from environment
+        hf_token = os.getenv("HF_LOGS_TOKEN")
+        if not hf_token:
+            print("HF_LOGS_TOKEN not found in environment variables")
+            return
+
+        # Get repository name from environment or use default
+        repo_id = os.getenv("HF_DATASET_REPO", "timeki/climateqa_logs")
+
+        # Initialize HfApi
+        api = HfApi(token=hf_token)
+
+        # Add timestamp to the log data
+        logs["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+
+        # Convert logs to JSON string
+        logs_json = json.dumps(logs)
+
+        # Upload directly from memory
+        api.upload_file(
+            path_or_fileobj=logs_json.encode('utf-8'),
+            path_in_repo=log_filename,
+            repo_id=repo_id,
+            repo_type="dataset"
+        )
+
+    except Exception as e:
+        print(f"Error logging to Hugging Face: {e}")
+
+
+def log_interaction_to_huggingface(history, output_query, sources, docs, share_client, user_id):
+    """Log chat interaction to Hugging Face.
+
+    Args:
+        history (list): Chat message history
+        output_query (str): Processed query
+        sources (list): Knowledge base sources used
+        docs (list): Retrieved documents
+        share_client: Azure share client instance (unused in this function)
+        user_id (str): User identifier
+    """
+    try:
+        # Log interaction if not in local environment
+        if os.getenv("GRADIO_ENV") != "local":
+            timestamp = str(datetime.now().timestamp())
+            prompt = history[1]["content"]
+            logs = {
+                "user_id": str(user_id),
+                "prompt": prompt,
+                "query": prompt,
+                "question": output_query,
+                "sources": sources,
+                "docs": serialize_docs(docs),
+                "answer": history[-1].content,
+                "time": timestamp,
+            }
+            # Log to Hugging Face
+            log_on_huggingface(f"chat/{timestamp}.json", logs)
+    except Exception as e:
+        print(f"Error logging to Hugging Face: {e}")
+        error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
+        raise gr.Error(error_msg)
+
+def log_drias_interaction_to_huggingface(query, sql_query, user_id):
+    """Log Drias data interaction to Hugging Face.
+
+    Args:
+        query (str): User query
+        sql_query (str): SQL query used
+        user_id (str): User identifier
+    """
+    try:
+        if os.getenv("GRADIO_ENV") != "local":
+            timestamp = str(datetime.now().timestamp())
+            logs = {
+                "user_id": str(user_id),
+                "query": query,
+                "sql_query": sql_query,
+                "time": timestamp,
+            }
+            log_on_huggingface(f"drias/drias_{timestamp}.json", logs)
+            print(f"Logged Drias interaction to Hugging Face: {logs}")
+        else:
+            print("share_client or user_id is None, or GRADIO_ENV is local")
+    except Exception as e:
+        print(f"Error logging Drias interaction to Hugging Face: {e}")
+        error_msg = f"Drias Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
+        raise gr.Error(error_msg)
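For reference, a hedged usage sketch for the new logger. The token and payload below are placeholders; in the deployed Space they come from the `HF_LOGS_TOKEN` secret and the optional `HF_DATASET_REPO` variable, and failures are swallowed with a printed message rather than raised.

```python
import os
from climateqa.logging import log_on_huggingface

# Placeholder token for illustration; never hard-code a real one.
os.environ["HF_LOGS_TOKEN"] = "hf_xxx"

# Uploads the dict as JSON to chat/example.json in the logs dataset repo.
log_on_huggingface("chat/example.json", {"user_id": "demo", "query": "example question"})
```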
data/drias/drias.db DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1e29ba55d0122dc034b76113941769b44214355d4528bcc5b3d8f71f3c50bf59
-size 280621056
front/tabs/chat_interface.py CHANGED
@@ -39,7 +39,7 @@ What do you want to learn ?
 # """
 
 init_prompt_poc = """
-Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports, the Paris Climate Action Plan (PCAET), the Biodiversity Plan 2018-2024, and the Acclimaterra reports from the Nouvelle-Aquitaine Region**.
+Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports, the Paris Climate Action Plan (PCAET), the Paris Biodiversity Plan 2018-2024, and the Acclimaterra reports from the Nouvelle-Aquitaine Region**.
 
 ❓ How to use
 - **Language**: You can ask me your questions in any language.
front/tabs/tab_drias.py CHANGED
@@ -5,8 +5,6 @@ import pandas as pd
 
 from climateqa.engine.talk_to_data.main import ask_drias
 from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
-from climateqa.chat import log_drias_interaction_to_azure
-
 
 class DriasUIElements(TypedDict):
     tab: gr.Tab
@@ -28,8 +26,8 @@ class DriasUIElements(TypedDict):
     next_button: gr.Button
 
 
-async def ask_drias_query(query: str, index_state: int):
-    result = await ask_drias(query, index_state)
+async def ask_drias_query(query: str, index_state: int, user_id: str):
+    result = await ask_drias(query, index_state, user_id)
     return result
 
 
@@ -196,19 +194,7 @@ def create_drias_ui() -> DriasUIElements:
         next_button=next_button
     )
 
-def log_drias_to_azure(query: str, sql_query: str, data, share_client, user_id):
-    """Log Drias interaction to Azure storage."""
-    print("log_drias_to_azure")
-    if share_client is not None and user_id is not None:
-        log_drias_interaction_to_azure(
-            query=query,
-            sql_query=sql_query,
-            data=data,
-            share_client=share_client,
-            user_id=user_id
-        )
-    else:
-        print("share_client or user_id is None")
+
 
 def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
     """Set up all event handlers for the DRIAS tab."""
@@ -218,10 +204,7 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
     plots_state = gr.State([])
     index_state = gr.State(0)
     table_names_list = gr.State([])
-
-    def log_drias_interaction(query: str, sql_query: str, data: pd.DataFrame):
-        log_drias_to_azure(query, sql_query, data, share_client, user_id)
-
+    user_id = gr.State(user_id)
 
     # Handle example selection
     ui_elements["examples_hidden"].change(
@@ -230,7 +213,7 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
         outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
     ).then(
         ask_drias_query,
-        inputs=[ui_elements["examples_hidden"], index_state],
+        inputs=[ui_elements["examples_hidden"], index_state, user_id],
         outputs=[
            ui_elements["drias_sql_query"],
            ui_elements["drias_table"],
@@ -242,10 +225,6 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
            table_names_list,
            ui_elements["result_text"],
         ],
-    ).then(
-        log_drias_interaction,
-        inputs=[ui_elements["examples_hidden"], ui_elements["drias_sql_query"], ui_elements["drias_table"]],
-        outputs=[],
     ).then(
         show_results,
         inputs=[sql_queries_state, dataframes_state, plots_state],
@@ -276,7 +255,7 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
        outputs=[ui_elements["details_accordion"]]
     ).then(
        ask_drias_query,
-        inputs=[ui_elements["drias_direct_question"], index_state],
+        inputs=[ui_elements["drias_direct_question"], index_state, user_id],
        outputs=[
            ui_elements["drias_sql_query"],
            ui_elements["drias_table"],
@@ -288,10 +267,6 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
            table_names_list,
            ui_elements["result_text"],
        ],
-    ).then(
-        log_drias_interaction,
-        inputs=[ui_elements["drias_direct_question"], ui_elements["drias_sql_query"], ui_elements["drias_table"]],
-        outputs=[],
    ).then(
        show_results,
        inputs=[sql_queries_state, dataframes_state, plots_state],
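The key wiring above is `user_id = gr.State(user_id)`: wrapping the plain Python value in `gr.State` lets it be passed as an event input next to visible components. A minimal, self-contained Gradio sketch of the same pattern (the handler and labels are invented for illustration):

```python
import gradio as gr

def answer(query: str, user_id: str) -> str:
    # The session-scoped user_id arrives alongside the visible input.
    return f"[{user_id}] you asked: {query}"

with gr.Blocks() as demo:
    user_id = gr.State("anonymous")   # same pattern as gr.State(user_id)
    question = gr.Textbox(label="Question")
    output = gr.Textbox(label="Answer")
    question.submit(answer, inputs=[question, user_id], outputs=[output])

demo.launch()
```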
front/utils.py CHANGED
@@ -13,17 +13,6 @@ def make_pairs(lst:list)->list:
     return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
 
 
-def serialize_docs(docs:list)->list:
-    new_docs = []
-    for doc in docs:
-        new_doc = {}
-        new_doc["page_content"] = doc.page_content
-        new_doc["metadata"] = doc.metadata
-        new_docs.append(new_doc)
-    return new_docs
-
-
-
 def parse_output_llm_with_sources(output:str)->str:
     # Split the content into a list of text and "[Doc X]" references
     content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
requirements.txt CHANGED
@@ -8,6 +8,7 @@ langgraph==0.2.70
 pinecone-client==4.1.0
 sentence-transformers==2.6.0
 huggingface-hub==0.25.2
+datasets==3.5.0
 pyalex==0.13
 networkx==3.2.1
 pyvis==0.3.2