hfenlume commited on
Commit
c187aed
·
verified ·
1 Parent(s): 695536e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -147
app.py CHANGED
@@ -1,149 +1,56 @@
 
 
 
 
 
 
 
1
  import os
2
- import snowflake.connector
3
- import replicate
4
- import re
5
- import pandas as pd
6
  import streamlit as st
7
-
8
- # Snowflake connection parameters
9
- ACCOUNT = "EZ97576.ap-southeast-1"
10
- USER = "sureshsnowflake"
11
- PASSWORD = "Slavia@123"
12
- WAREHOUSE = "COMPUTE_WH"
13
- DATABASE = "SNOWFLAKE_SAMPLE_DATA"
14
- SCHEMA = "TPCDS_SF100TCL"
15
-
16
- # Replicate API key
17
- os.environ['REPLICATE_API_TOKEN'] = 'r8_E7Rn49bbi2O33bztSMYLKyqvWmo68mZ1Tg8M0'
18
-
19
-
20
- def interact_with_replicate(prompt):
21
- response = ""
22
- for event in replicate.stream(
23
- "snowflake/snowflake-arctic-instruct",
24
- input={
25
- "prompt": prompt,
26
- "max_new_tokens": 250
27
- },
28
- ):
29
- response += str(event)
30
-
31
- # Use regular expressions to extract SQL statements
32
- sql_statements = re.findall(
33
- r"(SELECT.*?;|INSERT.*?;|UPDATE.*?;|DELETE.*?;|CREATE.*?;|ALTER.*?;|DROP.*?;)", response, re.DOTALL | re.IGNORECASE)
34
-
35
- # Join the SQL statements into a single string
36
- return "\n".join(sql_statements)
37
-
38
-
39
- def get_snowflake_connection():
40
- return snowflake.connector.connect(
41
- user=USER,
42
- password=PASSWORD,
43
- account=ACCOUNT,
44
- warehouse=WAREHOUSE,
45
- database=DATABASE,
46
- schema=SCHEMA
47
- )
48
-
49
-
50
- def fetch_ddl_for_all_tables():
51
- conn = get_snowflake_connection()
52
- # Add headers
53
- ddl_data = ["table_name, column_name, data_type, is_nullable\n"]
54
-
55
- try:
56
- cur = conn.cursor()
57
- table_names_query = f"""
58
- SELECT TABLE_NAME
59
- FROM {DATABASE}.INFORMATION_SCHEMA.TABLES
60
- WHERE TABLE_SCHEMA = '{SCHEMA}'
61
- """
62
- cur.execute(table_names_query)
63
- table_names = cur.fetchall()
64
-
65
- for table_name in table_names:
66
- table_name = table_name[0]
67
- ddl_query = f"""
68
- SELECT table_name, column_name, data_type, is_nullable
69
- FROM {DATABASE}.INFORMATION_SCHEMA.COLUMNS
70
- WHERE table_name = '{table_name}' AND table_schema = '{SCHEMA}'
71
- """
72
- cur.execute(ddl_query)
73
- ddl_result = cur.fetchall()
74
- if ddl_result:
75
- ddl = "\n".join(
76
- [f"{row[0]}, {row[1]}, {row[2]}, {row[3]}" for row in ddl_result])
77
- else:
78
- ddl = f"No DDL found for table {table_name}"
79
- ddl_data.append(f"-- DDL for table {table_name} --\n{ddl}\n\n")
80
-
81
- with open('sample_ddl.txt', 'w') as file:
82
- file.writelines(ddl_data)
83
-
84
- st.success("DDLs written to sample_ddl.txt")
85
- finally:
86
- cur.close()
87
- conn.close()
88
-
89
-
90
- def generate_sql_query(sample_message):
91
- with open('sample_ddl.txt', 'r') as file:
92
- ddl_commands = file.read()
93
-
94
- instruction = "Read the all provided ddl statements and work on the statement. if any logical questions asked,find relattion between tables based on primarykey and foreign key difinition use snowflake supported functions to provide snowflake sql construct. validate functions used on coulmn data types,amiguity in joins before present also add as many description columns as possible. Respond only with the SQL query without any explanations or contextual details. if the ask is to find or provide then consider as a request for writing sql construct"
95
- combined_input = ddl_commands + sample_message + " " + instruction
96
-
97
- response = interact_with_replicate(combined_input)
98
-
99
- with open('generated_query.sql', 'w') as file:
100
- file.write(response)
101
-
102
- st.success("Response from Replicate has been written to 'generated_query.sql'")
103
-
104
-
105
- def execute_generated_sql():
106
- with open('generated_query.sql', 'r') as file:
107
- generated_sql = file.read()
108
-
109
- # Print the generated SQL for inspection
110
- st.text_area("Generated SQL query", generated_sql, height=200)
111
-
112
- conn = get_snowflake_connection()
113
-
114
- try:
115
- cur = conn.cursor()
116
- cur.execute(generated_sql)
117
-
118
- if cur.description is not None:
119
- result = cur.fetchall()
120
- columns = [desc[0] for desc in cur.description]
121
- df = pd.DataFrame(result, columns=columns)
122
- st.write("Result from executed SQL query:")
123
- st.dataframe(df)
124
- else:
125
- st.write(
126
- "The executed SQL did not return any results or is not a SELECT query.")
127
- finally:
128
- cur.close()
129
- conn.close()
130
-
131
-
132
- def main():
133
- st.title("Snowflake and Replicate Integration")
134
-
135
- st.header("Generate SQL Query and Execute")
136
- sample_message = st.text_area(
137
- "Enter your message for generating SQL query", height=100)
138
- if st.button("Generate SQL"):
139
- generate_sql_query(sample_message)
140
- st.success("SQL Query generated successfully. You can now execute it.")
141
-
142
- if st.button("Execute SQL"):
143
- execute_generated_sql()
144
-
145
-
146
- if __name__ == "__main__":
147
- # Fetch DDLs for all tables automatically before starting the app
148
- fetch_ddl_for_all_tables()
149
- main()
 
1
+ # Questions:
2
+ # ==========
3
+ # Show me the total number of entries in the first table
4
+ # Select top 10 customers from Canada with highest sum of C_ACCTBAL value, in descending order
5
+ # Show me the total of Customers per Nation, in ascending order
6
+ # Show me a query that lists totals for extended price, discounted extended price, discounted extended price plus tax, average quantity, average extended price, and average discount. These aggregates are grouped by RETURNFLAG and LINESTATUS, and listed in ascending order of RETURNFLAG and LINESTATUS. A count of the number of line items in each group is included
7
+
8
  import os
 
 
 
 
9
  import streamlit as st
10
+ import pandas
11
+ from snowflake.snowpark import Session
12
+ from langchain_community.utilities import SQLDatabase
13
+ from langchain.chains import create_sql_query_chain
14
+ from transformers import AutoTokenizer
15
+ from langchain_community.llms import Replicate
16
+ from langchain_core.prompts import PromptTemplate
17
+
18
+ st.set_page_config(page_title="Snowflake Arctic", page_icon="🤖")
19
+ @st.cache_resource(show_spinner="Connecting...")
20
+ def getSession():
21
+ section = st.secrets[f"connections_snowflake"]
22
+ pars = {
23
+ "account": section["account"],
24
+ "user": section["user"],
25
+ "password": section["password"],
26
+ "database": section["database"],
27
+ "schema": section["schema"],
28
+ "warehouse": section["warehouse"],
29
+ "role": section["role"]
30
+ }
31
+ session = Session.builder.configs(pars).create()
32
+
33
+ url = (f"snowflake://{pars['user']}:{pars['password']}@{pars['account']}"
34
+ + f"/{pars['database']}/{pars['schema']}"
35
+ + f"?warehouse={pars['warehouse']}&role={pars['role']}")
36
+ db = SQLDatabase.from_uri(url)
37
+
38
+ os.environ['REPLICATE_API_TOKEN'] = st.secrets["REPLICATE_API_TOKEN"]
39
+ llm = Replicate(model="snowflake/snowflake-arctic-instruct", model_kwargs={"temperature": 0.75, "top_p": 1},)
40
+ chain = create_sql_query_chain(llm, db)
41
+ return session, db, chain
42
+
43
+ st.title("❄️ Snowflake Arctic with Replicate")
44
+ st.write("Returns and runs queries from questions in natural language.")
45
+
46
+ session, db, chain = getSession()
47
+
48
+ #user_query = st.chat_input("Type your message here...")
49
+
50
+ user_query = st.sidebar.text_area("Ask a question:", value="Show me the total number of entries in the first table")
51
+ sql = chain.invoke({"question": user_query}).rstrip(';')
52
+
53
+ tabQuery, tabData, tabLog = st.tabs("Query", "Data", "Log")
54
+ tabQuery.code(sql, language="sql")
55
+ tabData.dataframe(session.sql(sql))
56
+ tabLog.code(db.table_info, language="sql")