Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,149 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
-
import snowflake.connector
|
3 |
-
import replicate
|
4 |
-
import re
|
5 |
-
import pandas as pd
|
6 |
import streamlit as st
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
"
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
)
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
try:
|
56 |
-
cur = conn.cursor()
|
57 |
-
table_names_query = f"""
|
58 |
-
SELECT TABLE_NAME
|
59 |
-
FROM {DATABASE}.INFORMATION_SCHEMA.TABLES
|
60 |
-
WHERE TABLE_SCHEMA = '{SCHEMA}'
|
61 |
-
"""
|
62 |
-
cur.execute(table_names_query)
|
63 |
-
table_names = cur.fetchall()
|
64 |
-
|
65 |
-
for table_name in table_names:
|
66 |
-
table_name = table_name[0]
|
67 |
-
ddl_query = f"""
|
68 |
-
SELECT table_name, column_name, data_type, is_nullable
|
69 |
-
FROM {DATABASE}.INFORMATION_SCHEMA.COLUMNS
|
70 |
-
WHERE table_name = '{table_name}' AND table_schema = '{SCHEMA}'
|
71 |
-
"""
|
72 |
-
cur.execute(ddl_query)
|
73 |
-
ddl_result = cur.fetchall()
|
74 |
-
if ddl_result:
|
75 |
-
ddl = "\n".join(
|
76 |
-
[f"{row[0]}, {row[1]}, {row[2]}, {row[3]}" for row in ddl_result])
|
77 |
-
else:
|
78 |
-
ddl = f"No DDL found for table {table_name}"
|
79 |
-
ddl_data.append(f"-- DDL for table {table_name} --\n{ddl}\n\n")
|
80 |
-
|
81 |
-
with open('sample_ddl.txt', 'w') as file:
|
82 |
-
file.writelines(ddl_data)
|
83 |
-
|
84 |
-
st.success("DDLs written to sample_ddl.txt")
|
85 |
-
finally:
|
86 |
-
cur.close()
|
87 |
-
conn.close()
|
88 |
-
|
89 |
-
|
90 |
-
def generate_sql_query(sample_message):
    """Generate a SQL query for ``sample_message`` and save it to disk.

    The prompt is the contents of ``sample_ddl.txt`` followed by the user's
    message and a fixed instruction. The text returned by
    ``interact_with_replicate`` is written to ``generated_query.sql`` and a
    Streamlit success notice is shown.

    Args:
        sample_message: The user's request, as text.
    """
    instruction = "Read the all provided ddl statements and work on the statement. if any logical questions asked,find relattion between tables based on primarykey and foreign key difinition use snowflake supported functions to provide snowflake sql construct. validate functions used on coulmn data types,amiguity in joins before present also add as many description columns as possible. Respond only with the SQL query without any explanations or contextual details. if the ask is to find or provide then consider as a request for writing sql construct"

    with open('sample_ddl.txt', 'r') as ddl_file:
        schema_text = ddl_file.read()

    # Prompt layout: schema text, then the user message, a space, the instruction.
    prompt = "".join((schema_text, sample_message, " ", instruction))
    reply = interact_with_replicate(prompt)

    with open('generated_query.sql', 'w') as sql_file:
        sql_file.write(reply)

    st.success("Response from Replicate has been written to 'generated_query.sql'")
|
103 |
-
|
104 |
-
|
105 |
-
def execute_generated_sql():
    """Run the SQL stored in ``generated_query.sql`` and display the outcome.

    The SQL text is first shown in a Streamlit text area for inspection, then
    executed on a connection obtained from ``get_snowflake_connection()``.
    When the cursor reports a result set, the rows are rendered as a
    DataFrame; otherwise an informational message is written.

    The cursor and the connection are always closed, including when opening
    the cursor or executing the statement raises.
    """
    with open('generated_query.sql', 'r') as file:
        generated_sql = file.read()

    # Show the generated SQL so the user can inspect what is about to run.
    st.text_area("Generated SQL query", generated_sql, height=200)

    # NOTE(review): the statement is model-generated text executed verbatim;
    # consider restricting the connection's role to read-only access.
    conn = get_snowflake_connection()
    try:
        cur = conn.cursor()
        # Nested try: if conn.cursor() itself fails there is no cursor to
        # close, and the connection must still be released by the outer finally.
        try:
            cur.execute(generated_sql)

            if cur.description is not None:
                result = cur.fetchall()
                columns = [desc[0] for desc in cur.description]
                df = pd.DataFrame(result, columns=columns)
                st.write("Result from executed SQL query:")
                st.dataframe(df)
            else:
                st.write(
                    "The executed SQL did not return any results or is not a SELECT query.")
        finally:
            cur.close()
    finally:
        conn.close()
|
130 |
-
|
131 |
-
|
132 |
-
def main():
    """Render the page: a prompt box plus "Generate SQL" and "Execute SQL" buttons.

    Pressing "Generate SQL" passes the entered text to ``generate_sql_query``;
    pressing "Execute SQL" calls ``execute_generated_sql``.
    """
    st.title("Snowflake and Replicate Integration")
    st.header("Generate SQL Query and Execute")

    question = st.text_area(
        "Enter your message for generating SQL query",
        height=100,
    )

    generate_clicked = st.button("Generate SQL")
    if generate_clicked:
        generate_sql_query(question)
        st.success("SQL Query generated successfully. You can now execute it.")

    execute_clicked = st.button("Execute SQL")
    if execute_clicked:
        execute_generated_sql()


if __name__ == "__main__":
    # Collect the table DDLs before the UI starts.
    fetch_ddl_for_all_tables()
    main()
|
|
|
1 |
+
# Questions:
|
2 |
+
# ==========
|
3 |
+
# Show me the total number of entries in the first table
|
4 |
+
# Select top 10 customers from Canada with highest sum of C_ACCTBAL value, in descending order
|
5 |
+
# Show me the total of Customers per Nation, in ascending order
|
6 |
+
# Show me a query that lists totals for extended price, discounted extended price, discounted extended price plus tax, average quantity, average extended price, and average discount. These aggregates are grouped by RETURNFLAG and LINESTATUS, and listed in ascending order of RETURNFLAG and LINESTATUS. A count of the number of line items in each group is included
|
7 |
+
|
8 |
import os
|
|
|
|
|
|
|
|
|
9 |
import streamlit as st
|
10 |
+
import pandas
|
11 |
+
from snowflake.snowpark import Session
|
12 |
+
from langchain_community.utilities import SQLDatabase
|
13 |
+
from langchain.chains import create_sql_query_chain
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
from langchain_community.llms import Replicate
|
16 |
+
from langchain_core.prompts import PromptTemplate
|
17 |
+
|
18 |
+
st.set_page_config(page_title="Snowflake Arctic", page_icon="🤖")
|
19 |
+
@st.cache_resource(show_spinner="Connecting...")
def getSession():
    """Build the Snowpark session, the LangChain SQL database and the query chain.

    Connection parameters are read from the ``connections_snowflake`` section
    of Streamlit secrets; the Replicate token is read from the
    ``REPLICATE_API_TOKEN`` secret and exported to the environment.

    Returns:
        A ``(session, db, chain)`` tuple: the Snowpark ``Session``, the
        ``SQLDatabase`` created from the Snowflake URL, and the chain returned
        by ``create_sql_query_chain``.
    """
    # Local import keeps this block self-contained with respect to the
    # file's import section.
    from urllib.parse import quote

    section = st.secrets["connections_snowflake"]
    pars = {
        key: section[key]
        for key in (
            "account", "user", "password", "database",
            "schema", "warehouse", "role",
        )
    }
    session = Session.builder.configs(pars).create()

    def _enc(value):
        # Percent-encode every reserved character so that values containing
        # '@', ':', '/', '#', '?' or '&' cannot corrupt the URL structure.
        return quote(str(value), safe="")

    url = (
        f"snowflake://{_enc(pars['user'])}:{_enc(pars['password'])}"
        f"@{_enc(pars['account'])}"
        f"/{_enc(pars['database'])}/{_enc(pars['schema'])}"
        f"?warehouse={_enc(pars['warehouse'])}&role={_enc(pars['role'])}"
    )
    db = SQLDatabase.from_uri(url)

    os.environ['REPLICATE_API_TOKEN'] = st.secrets["REPLICATE_API_TOKEN"]
    llm = Replicate(
        model="snowflake/snowflake-arctic-instruct",
        model_kwargs={"temperature": 0.75, "top_p": 1},
    )
    chain = create_sql_query_chain(llm, db)
    return session, db, chain
|
42 |
+
|
43 |
+
st.title("❄️ Snowflake Arctic with Replicate")
st.write("Returns and runs queries from questions in natural language.")

session, db, chain = getSession()

user_query = st.sidebar.text_area(
    "Ask a question:",
    value="Show me the total number of entries in the first table",
)

# Strip surrounding whitespace first: otherwise a reply ending in ";\n" keeps
# its semicolon, because rstrip(';') only removes characters at the very end.
sql = chain.invoke({"question": user_query}).strip().rstrip(';').rstrip()

# st.tabs takes a single sequence of labels, not one positional argument
# per tab.
tabQuery, tabData, tabLog = st.tabs(["Query", "Data", "Log"])
tabQuery.code(sql, language="sql")             # the generated statement
tabData.dataframe(session.sql(sql))            # its result, run through Snowpark
tabLog.code(db.table_info, language="sql")     # table info exposed by SQLDatabase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|