Kashif17 committed on
Commit
a190708
·
verified ·
1 Parent(s): 6d3ebd3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+
7
# Load the model and tokenizer
@st.cache_resource  # requires Streamlit >= 1.18
def load_model():
    """Build the GPT-2 text-generation pipeline used by the app.

    Streamlit re-executes the whole script on every widget interaction, so
    an uncached loader would re-read the GPT-2 weights on each rerun.
    ``st.cache_resource`` keeps a single pipeline alive for the server
    process and returns that same object on later calls.

    Returns:
        A ``transformers`` ``text-generation`` pipeline wrapping the
        pretrained ``gpt2`` model and its tokenizer.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    return pipeline('text-generation', model=model, tokenizer=tokenizer)


# Module-level handle read by generate_sql_query(). This line still runs on
# every script rerun, but the cached loader makes repeat calls inexpensive.
text_generator = load_model()
14
+
15
def generate_sql_query(prompt):
    """Run *prompt* through the module-level ``text_generator``.

    Args:
        prompt: Text handed to the generator as the seed.

    Returns:
        The ``generated_text`` field of the first returned sequence, with
        leading and trailing whitespace removed.
    """
    outputs = text_generator(
        prompt,
        max_length=100,
        num_return_sequences=1,
    )
    first_sequence = outputs[0]
    generated = first_sequence['generated_text']
    return generated.strip()
18
+
19
def execute_sql_query(sql_query, data):
    """Load the tabular data the rest of the app works on.

    No SQL engine is involved: ``sql_query`` is accepted for interface
    compatibility but is not executed. The function simply materialises
    ``data`` as a pandas DataFrame.

    Args:
        sql_query: The generated query text (currently unused).
        data: The data source. May be a DataFrame (returned as-is), a path
            or file-like object containing CSV (e.g. the object returned by
            ``st.file_uploader``), or ``None`` to fall back to the bundled
            ``Movies.csv``.

    Returns:
        A pandas DataFrame with the loaded rows.
    """
    if data is None:
        # Preserve the previous behaviour when no source is supplied.
        return pd.read_csv('Movies.csv')

    if isinstance(data, pd.DataFrame):
        return data

    # A file-like object may already have been consumed by an earlier read
    # (for instance on a previous script rerun); rewind before parsing.
    if hasattr(data, "seek"):
        data.seek(0)
    return pd.read_csv(data)
23
+
24
def analyze_and_visualize(data, prompt):
    """Choose a chart type for *data* and draw it on a new matplotlib figure.

    Chart selection:
      * line    - the prompt contains "trend" (case-insensitive) or the
        frame has a column named ``time``;
      * bar     - otherwise, when at least one object-dtype column exists.
        The first object column is plotted on x against the first numeric
        column on y; with no numeric column, category counts are plotted;
      * scatter - otherwise.

    Args:
        data: pandas DataFrame to plot.
        prompt: The user's free-text query; only inspected for "trend".

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the
        chart. The caller passes it straight to ``st.pyplot``.
    """
    categorical_columns = data.select_dtypes(include=['object']).columns
    numerical_columns = data.select_dtypes(include=['number']).columns

    if 'trend' in prompt.lower() or 'time' in data.columns:
        vis_type = 'line'
    elif len(categorical_columns) > 0:
        vis_type = 'bar'
    else:
        vis_type = 'scatter'

    plt.figure(figsize=(10, 6))

    # Wide-form line/scatter plots draw the frame's values against its index.
    x_label, y_label = 'Index', 'Value'

    if vis_type == 'line':
        sns.lineplot(data=data)
        plt.title('Trend Analysis')
    elif vis_type == 'bar':
        categorical_column = categorical_columns[0]
        x_label = str(categorical_column)
        if len(numerical_columns) > 0:
            numerical_column = numerical_columns[0]
            sns.barplot(x=categorical_column, y=numerical_column, data=data)
            y_label = str(numerical_column)
        else:
            # With no numeric column, indexing columns[0] would raise
            # IndexError; show how often each category occurs instead.
            sns.countplot(x=categorical_column, data=data)
            y_label = 'Count'
        plt.title('Categorical Data Analysis')
    else:
        sns.scatterplot(data=data)
        plt.title('Scatter Plot')

    # Label the axes with what is actually plotted rather than placeholders.
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.xticks(rotation=45)
    plt.tight_layout()
    return plt  # Return the plot object instead of showing it
50
+
51
# Streamlit app
def main():
    """Render the page and run the generate -> load -> visualize flow.

    The flow runs only when the button is clicked. If the CSV upload or the
    query text is missing, a warning is shown and nothing else happens,
    so the click never fails silently.
    """
    st.title("Conversational BI Bot")
    prompt = st.text_input("Enter your query:", "")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    if not st.button("Generate and Visualize"):
        return

    # Tell the user what is missing instead of ignoring the click.
    if uploaded_file is None:
        st.warning("Please upload a CSV file before generating.")
        return
    if not prompt:
        st.warning("Please enter a query before generating.")
        return

    with st.spinner('Generating SQL query...'):
        sql_query = generate_sql_query(prompt)
    st.write(f"Generated SQL Query: `{sql_query}`")

    with st.spinner('Executing SQL query...'):
        data = execute_sql_query(sql_query, uploaded_file)
    st.write("Query Results:", data)

    with st.spinner('Generating Visualization...'):
        plot = analyze_and_visualize(data, prompt)
    st.pyplot(plot)  # Display the plot using st.pyplot()


if __name__ == "__main__":
    main()