timeki commited on
Commit
eb90d11
·
1 Parent(s): a7802db

Add step by step notebooks for drias

Browse files
sandbox/talk_to_data/20250306 - CQA - Drias.ipynb ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Import the function in main.py"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import sys\n",
17
+ "import os\n",
18
+ "sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))\n",
19
+ "\n",
20
+ "%load_ext autoreload\n",
21
+ "%autoreload 2\n",
22
+ "\n",
23
+ "from climateqa.engine.talk_to_data.main import ask_vanna\n"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "metadata": {},
29
+ "source": [
30
+ "## Create a human query"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "query = \"Comment vont évoluer les températures à marseille ?\""
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "metadata": {},
45
+ "source": [
46
+ "## Call the function ask vanna, it gives an output of a the sql query and the dataframe of the result (tuple)"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "sql_query, df, fig = ask_vanna(query)\n",
56
+ "print(df.head())\n",
57
+ "fig.show()"
58
+ ]
59
+ }
60
+ ],
61
+ "metadata": {
62
+ "kernelspec": {
63
+ "display_name": "climateqa",
64
+ "language": "python",
65
+ "name": "python3"
66
+ },
67
+ "language_info": {
68
+ "codemirror_mode": {
69
+ "name": "ipython",
70
+ "version": 3
71
+ },
72
+ "file_extension": ".py",
73
+ "mimetype": "text/x-python",
74
+ "name": "python",
75
+ "nbconvert_exporter": "python",
76
+ "pygments_lexer": "ipython3",
77
+ "version": "3.11.9"
78
+ }
79
+ },
80
+ "nbformat": 4,
81
+ "nbformat_minor": 2
82
+ }
sandbox/talk_to_data/20250306 - CQA - Step_by_step_vanna.ipynb ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import sys\n",
10
+ "import os\n",
11
+ "sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))\n",
12
+ "\n",
13
+ "%load_ext autoreload\n",
14
+ "%autoreload 2\n",
15
+ "\n",
16
+ "from climateqa.engine.talk_to_data.main import ask_vanna\n",
17
+ "\n",
18
+ "import sqlite3\n",
19
+ "import os\n",
20
+ "import pandas as pd"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "markdown",
25
+ "metadata": {},
26
+ "source": [
27
+ "# Imports"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "from climateqa.engine.talk_to_data.myVanna import MyVanna\n",
37
+ "from climateqa.engine.talk_to_data.utils import loc2coords, detect_location_with_openai, detectTable, nearestNeighbourSQL, detect_relevant_tables, replace_coordonates#,nearestNeighbourPostgres\n",
38
+ "\n",
39
+ "from climateqa.engine.llm import get_llm"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "metadata": {},
45
+ "source": [
46
+ "# Vanna Ask\n"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "from dotenv import load_dotenv\n",
56
+ "\n",
57
+ "load_dotenv()\n",
58
+ "\n",
59
+ "llm = get_llm(provider=\"openai\")\n",
60
+ "\n",
61
+ "OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')\n",
62
+ "PC_API_KEY = os.getenv('VANNA_PINECONE_API_KEY')\n",
63
+ "INDEX_NAME = os.getenv('VANNA_INDEX_NAME')\n",
64
+ "VANNA_MODEL = os.getenv('VANNA_MODEL')\n",
65
+ "\n",
66
+ "ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))\n",
67
+ "\n",
68
+ "#Vanna object\n",
69
+ "vn = MyVanna(config = {\"temperature\": 0, \"api_key\": OPENAI_API_KEY, 'model': VANNA_MODEL, 'pc_api_key': PC_API_KEY, 'index_name': INDEX_NAME, \"top_k\" : 4})\n",
70
+ "\n",
71
+ "db_vanna_path = ROOT_PATH + \"/data/drias/drias.db\"\n",
72
+ "vn.connect_to_sqlite(db_vanna_path)\n"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "metadata": {},
78
+ "source": [
79
+ "# User query"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "query = \"Quelle sera la température à Marseille sur les prochaines années ?\""
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {},
94
+ "source": [
95
+ "## Detect location"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "location = detect_location_with_openai(OPENAI_API_KEY, query)\n",
105
+ "print(location)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "## Convert location to longitude, latitude coordonate"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "coords = loc2coords(location)\n",
122
+ "user_input = query.lower().replace(location.lower(), f\"lat, long : {coords}\")\n",
123
+ "print(user_input)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "markdown",
128
+ "metadata": {},
129
+ "source": [
130
+ "# Find closest coordonates and replace lat,lon\n"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "relevant_tables = detect_relevant_tables(user_input, llm) \n",
140
+ "coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]\n",
141
+ "user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)\n",
142
+ "print(user_input_with_coords)"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "metadata": {},
148
+ "source": [
149
+ "# Ask Vanna with correct coordonates"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "user_input_with_coords"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)\n",
168
+ "print(result_dataframe.head())"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "result_dataframe"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "figure"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": []
195
+ }
196
+ ],
197
+ "metadata": {
198
+ "kernelspec": {
199
+ "display_name": "climateqa",
200
+ "language": "python",
201
+ "name": "python3"
202
+ },
203
+ "language_info": {
204
+ "codemirror_mode": {
205
+ "name": "ipython",
206
+ "version": 3
207
+ },
208
+ "file_extension": ".py",
209
+ "mimetype": "text/x-python",
210
+ "name": "python",
211
+ "nbconvert_exporter": "python",
212
+ "pygments_lexer": "ipython3",
213
+ "version": "3.11.9"
214
+ }
215
+ },
216
+ "nbformat": 4,
217
+ "nbformat_minor": 2
218
+ }