armanddemasson commited on
Commit
e6e652c
·
1 Parent(s): 1eae86b

fix: fixed bugs and errors

Browse files
climateqa/engine/talk_to_data/plot.py CHANGED
@@ -241,7 +241,6 @@ def plot_distribution_of_indicator_for_given_year(
241
  yaxis_title="Frequency",
242
  plot_bgcolor="rgba(0, 0, 0, 0)",
243
  showlegend=False,
244
- pan=False
245
  )
246
 
247
  return fig
@@ -314,8 +313,8 @@ def plot_map_of_france_of_indicator_for_given_year(
314
  mapbox_style="open-street-map", # Use OpenStreetMap
315
  mapbox_zoom=3,
316
  mapbox_center={"lat": 46.6, "lon": 2.0},
317
- coloraxis_colorbar=dict(title=f"{indicator_label} {'(Model Average)' if model == 'ALL' else '(Model : ' + model + ')'}"), # Add legend
318
- title=f"{indicator_label} in {year} in France ", # Title
319
  )
320
  return fig
321
 
 
241
  yaxis_title="Frequency",
242
  plot_bgcolor="rgba(0, 0, 0, 0)",
243
  showlegend=False,
 
244
  )
245
 
246
  return fig
 
313
  mapbox_style="open-street-map", # Use OpenStreetMap
314
  mapbox_zoom=3,
315
  mapbox_center={"lat": 46.6, "lon": 2.0},
316
+ coloraxis_colorbar=dict(title=f"{indicator_label}"), # Add legend
317
+ title=f"{indicator_label} in {year} in France {'(Model Average)' if model == 'ALL' else '(Model : ' + model + ')'} " # Title
318
  )
319
  return fig
320
 
climateqa/engine/talk_to_data/sql_query.py CHANGED
@@ -66,9 +66,9 @@ def indicator_per_year_at_location_query(
66
  return ""
67
 
68
  if model == 'ALL':
69
- sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nand longitude={longitude} \nOrder by Year"
70
  else:
71
- sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nand longitude={longitude} \nand model='{model}' \nOrder by Year"
72
 
73
  return sql_query
74
 
@@ -98,5 +98,5 @@ def indicator_for_given_year_query(
98
  if model == 'ALL':
99
  sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
100
  else:
101
- sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}\nand model = '{model}'"
102
  return sql_query
 
66
  return ""
67
 
68
  if model == 'ALL':
69
+ sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
70
  else:
71
+ sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nAnd model = '{model}' \nOrder by Year"
72
 
73
  return sql_query
74
 
 
98
  if model == 'ALL':
99
  sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
100
  else:
101
+ sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}\nAnd model = '{model}'"
102
  return sql_query
climateqa/engine/talk_to_data/utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import re
 
2
 
3
  from sympy import use
4
  from geopy.geocoders import Nominatim
@@ -6,6 +7,7 @@ import sqlite3
6
  import ast
7
  from climateqa.engine.llm import get_llm
8
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
 
9
 
10
 
11
  def detect_location_with_openai(sentence):
@@ -27,28 +29,31 @@ def detect_location_with_openai(sentence):
27
  return location_list[0]
28
  else:
29
  return ""
30
-
 
 
 
 
 
31
  def detect_year_with_openai(sentence: str):
32
  """
33
  Detects years in a sentence using OpenAI's API via LangChain.
34
  """
35
  llm = get_llm()
36
 
37
- prompt = f"""
38
  Extract all years mentioned in the following sentence.
39
  Return the result as a Python list. If no year are mentioned, return an empty list.
40
 
41
  Sentence: "{sentence}"
42
  """
43
 
44
- response = llm.invoke(prompt)
45
- if response is None:
46
- return None
47
- response_split = response.content.strip("```python\n").split('=')
48
- years_list = []
49
- if len(response_split) > 1:
50
- years_list = ast.literal_eval(response_split[1])
51
- if years_list and len(years_list) > 0:
52
  return years_list[0]
53
  else:
54
  return None
 
1
  import re
2
+ from typing import Annotated, TypedDict
3
 
4
  from sympy import use
5
  from geopy.geocoders import Nominatim
 
7
  import ast
8
  from climateqa.engine.llm import get_llm
9
  from climateqa.engine.talk_to_data.plot import PLOTS, Plot
10
+ from langchain_core.prompts import ChatPromptTemplate
11
 
12
 
13
  def detect_location_with_openai(sentence):
 
29
  return location_list[0]
30
  else:
31
  return ""
32
+
33
+ class ArrayOutput(TypedDict):
34
+ """Generated SQL query."""
35
+
36
+ array: Annotated[str, ..., "Syntactically valid python array."]
37
+
38
  def detect_year_with_openai(sentence: str):
39
  """
40
  Detects years in a sentence using OpenAI's API via LangChain.
41
  """
42
  llm = get_llm()
43
 
44
+ prompt = """
45
  Extract all years mentioned in the following sentence.
46
  Return the result as a Python list. If no year are mentioned, return an empty list.
47
 
48
  Sentence: "{sentence}"
49
  """
50
 
51
+ prompt = ChatPromptTemplate.from_template(prompt)
52
+ structured_llm = llm.with_structured_output(ArrayOutput)
53
+ chain = prompt | structured_llm
54
+ response: ArrayOutput = chain.invoke({"sentence": sentence})
55
+ years_list = eval(response['array'])
56
+ if len(years_list) > 0:
 
 
57
  return years_list[0]
58
  else:
59
  return None
climateqa/engine/talk_to_data/workflow.py CHANGED
@@ -99,7 +99,6 @@ def drias_workflow(db_drias_path: str, user_input: str, model: str) -> State:
99
  if sql_query == "":
100
  table_state['status'] = 'ERROR'
101
  continue
102
- print(sql_query)
103
 
104
  table_state['sql_query'] = sql_query
105
  results = execute_sql_query(db_drias_path, sql_query)
 
99
  if sql_query == "":
100
  table_state['status'] = 'ERROR'
101
  continue
 
102
 
103
  table_state['sql_query'] = sql_query
104
  results = execute_sql_query(db_drias_path, sql_query)