|
import streamlit as st |
|
import numpy as np |
|
from llm import load_llm, response_generator |
|
from sql import csv_to_sqlite, run_sql_query |
|
|
|
|
|
# Model repository id and GGUF weights file handed to `load_llm`; the loaded
# model is later passed to `response_generator` to produce SQL.
repo_id = "Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF"
filename = "qwen2.5-coder-1.5b-instruct-q8_0.gguf"


@st.cache_resource
def _cached_llm(repo: str, weights_file: str):
    """Load the model once and reuse it across Streamlit reruns.

    Streamlit re-executes this entire script on every widget interaction, so
    an uncached ``load_llm`` call would reload the weights each time.

    NOTE(review): if ``load_llm`` already caches internally this wrapper is
    redundant but harmless. ``st.cache_resource`` shares the returned object
    across all sessions — confirm the model object tolerates concurrent use
    if the app is served to several users.
    """
    return load_llm(repo, weights_file)


llm = _cached_llm(repo_id, filename)
|
|
|
st.title("CSV TO SQL")
st.write("To start, Upload your CSV below 👇")

# The chat history is also initialised further down the script, but this
# button handler runs earlier on the rerun triggered by the click, so the list
# must exist before anything is appended to it.
if "messages" not in st.session_state:
    st.session_state.messages = []

if st.button("Example prompt"):
    # Load the bundled sample dataset into a DB and table both named "sales".
    example_csv = "./data/sales.csv"
    st.session_state.csv_file = example_csv
    st.session_state.db_name = "sales"
    st.session_state.table_name = "sales"
    csv_to_sqlite(example_csv, st.session_state.db_name, st.session_state.table_name)

    prompt = "What is the sum, count and average sales?"
    st.session_state.messages.append({"role": "user", "content": prompt})

    response_sql = response_generator(
        db_name=st.session_state.db_name,
        table_name=st.session_state.table_name,
        llm=llm,
        messages=st.session_state.messages,
        question=prompt,
    )
    st.session_state.messages.append({"role": "assistant", "content": response_sql})

    # The SQL is model-generated and may be invalid; record the failure in the
    # history (rendered below) instead of aborting the whole page.
    try:
        result = run_sql_query(db_name=st.session_state.db_name, query=response_sql)
    except Exception as err:
        st.session_state.messages.append(
            {"role": "assistant", "content": f"Could not run the generated SQL: {err}"}
        )
    else:
        st.session_state.messages.append(
            {"role": "assistant", "content": str(result), "result": result}
        )
|
|
|
|
|
with st.expander("Upload CSV"):
    csv_file = st.file_uploader(
        "CSV",
    )
    # Strip surrounding whitespace so that blank-looking names fail validation
    # and accidental padding never ends up in DB/table names.
    db_name = st.text_input("DB Name").strip()
    table_name = st.text_input("Table Name").strip()

    if st.button("Save"):
        if csv_file and db_name and table_name:
            # Convert first: if this raises, session state still points at the
            # previous (valid) database rather than a table that doesn't exist.
            csv_to_sqlite(csv_file, db_name, table_name)
            st.session_state.csv_file = csv_file
            st.session_state.db_name = db_name
            st.session_state.table_name = table_name
            st.write("Saved ✅")
        else:
            st.write("Please enter all values")
|
|
|
|
|
# First run of this browser session: start with an empty conversation.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Streamlit redraws from scratch on every rerun, so replay the whole stored
# conversation: user turns as markdown, assistant turns as code blocks, plus
# a table for any message that carries a query result.
for entry in st.session_state.messages:
    speaker = entry["role"]
    with st.chat_message(speaker):
        if "content" in entry:
            text = entry["content"]
            if speaker != "user":
                st.code(text)
            else:
                st.markdown(text)
        if "result" in entry:
            st.dataframe(entry["result"])
|
|
|
|
|
# Chat is only enabled once a database/table has been selected, either via
# the example button or the upload form.
if prompt := st.chat_input(
    "What is up?",
    disabled=(
        "db_name" not in st.session_state or "table_name" not in st.session_state
    ),
):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        response_sql = response_generator(
            db_name=st.session_state.db_name,
            table_name=st.session_state.table_name,
            llm=llm,
            messages=st.session_state.messages,
            question=prompt,
        )
        st.code(response_sql)
        st.session_state.messages.append({"role": "assistant", "content": response_sql})

        # Model-generated SQL may be invalid; report it instead of crashing.
        try:
            result = run_sql_query(db_name=st.session_state.db_name, query=response_sql)
        except Exception as err:
            error_text = f"Could not run the generated SQL: {err}"
            st.error(error_text)
            st.session_state.messages.append({"role": "assistant", "content": error_text})
        else:
            st.dataframe(result)
            # Persist the result the same way the example flow does, so the
            # history replay keeps showing it after the next rerun.
            st.session_state.messages.append(
                {"role": "assistant", "content": str(result), "result": result}
            )
|
|
|
with st.sidebar:
    st.title("Data Previewer")
    st.write("You can see your CSV file content here")

    # Only preview once a CSV has been loaded into a named DB/table.
    if all(key in st.session_state for key in ("csv_file", "db_name", "table_name")):
        # The table name comes from free-text user input: quote it as an SQLite
        # identifier (doubling embedded quotes) so names with spaces/keywords
        # work and cannot inject extra SQL.
        quoted_table = '"' + st.session_state.table_name.replace('"', '""') + '"'
        result = run_sql_query(
            db_name=st.session_state.db_name,
            query=f"SELECT * FROM {quoted_table}",
        )
        st.dataframe(result)
|
|