Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
|
3 |
+
import pandas as pd
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import seaborn as sns
|
6 |
+
|
7 |
+
# Load the model and tokenizer
|
8 |
+
def load_model():
|
9 |
+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
10 |
+
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
11 |
+
return pipeline('text-generation', model=model, tokenizer=tokenizer)
|
12 |
+
|
13 |
+
text_generator = load_model()
|
14 |
+
|
15 |
+
def generate_sql_query(prompt):
|
16 |
+
response = text_generator(prompt, max_length=100, num_return_sequences=1)
|
17 |
+
return response[0]['generated_text'].strip()
|
18 |
+
|
19 |
+
def execute_sql_query(sql_query, data):
|
20 |
+
# In this case, we'll directly use the pandas DataFrame as our data source
|
21 |
+
result = pd.read_csv('Movies.csv')
|
22 |
+
return result
|
23 |
+
|
24 |
+
def analyze_and_visualize(data, prompt):
|
25 |
+
if 'trend' in prompt.lower() or 'time' in data.columns:
|
26 |
+
vis_type = 'line'
|
27 |
+
elif data.select_dtypes(include=['object']).shape[1] > 0:
|
28 |
+
vis_type = 'bar'
|
29 |
+
else:
|
30 |
+
vis_type = 'scatter'
|
31 |
+
|
32 |
+
plt.figure(figsize=(10, 6))
|
33 |
+
if vis_type == 'line':
|
34 |
+
sns.lineplot(data=data)
|
35 |
+
plt.title('Trend Analysis')
|
36 |
+
elif vis_type == 'bar':
|
37 |
+
categorical_column = data.select_dtypes(include=['object']).columns[0]
|
38 |
+
numerical_column = data.select_dtypes(include=['number']).columns[0]
|
39 |
+
sns.barplot(x=categorical_column, y=numerical_column, data=data)
|
40 |
+
plt.title('Categorical Data Analysis')
|
41 |
+
else:
|
42 |
+
sns.scatterplot(data=data)
|
43 |
+
plt.title('Scatter Plot')
|
44 |
+
|
45 |
+
plt.xlabel('X-axis Label')
|
46 |
+
plt.ylabel('Y-axis Label')
|
47 |
+
plt.xticks(rotation=45)
|
48 |
+
plt.tight_layout()
|
49 |
+
return plt # Return the plot object instead of showing it
|
50 |
+
|
51 |
+
# Streamlit app
|
52 |
+
def main():
|
53 |
+
st.title("Conversational BI Bot")
|
54 |
+
prompt = st.text_input("Enter your query:", "")
|
55 |
+
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
|
56 |
+
|
57 |
+
if st.button("Generate and Visualize") and uploaded_file is not None:
|
58 |
+
if prompt:
|
59 |
+
with st.spinner('Generating SQL query...'):
|
60 |
+
sql_query = generate_sql_query(prompt)
|
61 |
+
st.write(f"Generated SQL Query: `{sql_query}`")
|
62 |
+
|
63 |
+
with st.spinner('Executing SQL query...'):
|
64 |
+
data = execute_sql_query(sql_query, uploaded_file)
|
65 |
+
st.write("Query Results:", data)
|
66 |
+
|
67 |
+
with st.spinner('Generating Visualization...'):
|
68 |
+
plot = analyze_and_visualize(data, prompt)
|
69 |
+
st.pyplot(plot) # Display the plot using st.pyplot()
|
70 |
+
|
71 |
+
if __name__ == "__main__":
|
72 |
+
main()
|