"""Streamlit "Conversational BI Bot".

Flow: the user types a prompt and uploads a CSV; GPT-2 generates text from the
prompt (shown as the "SQL query"), the CSV is loaded into a DataFrame, and a
chart type is picked heuristically from the prompt and the column dtypes.
"""

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline

# Fallback data source used when no file/DataFrame is handed to
# execute_sql_query (this was the only source in the original code).
DEFAULT_CSV_PATH = 'Movies.csv'

# Streamlit re-executes this script on every interaction; without caching the
# GPT-2 weights would be reloaded each time. ``st.cache_resource`` is not
# available on older Streamlit releases, so fall back to a no-op decorator.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def load_model():
    """Build a ``text-generation`` pipeline backed by the pretrained "gpt2" model.

    Returns:
        The transformers pipeline object wrapping the GPT-2 model and tokenizer.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    return pipeline('text-generation', model=model, tokenizer=tokenizer)


text_generator = load_model()


def generate_sql_query(prompt):
    """Generate text from *prompt* with the module-level GPT-2 pipeline.

    Args:
        prompt: Text passed to the generator.

    Returns:
        The ``generated_text`` of the first returned sequence, stripped of
        surrounding whitespace.

    NOTE(review): the model loaded above is the stock "gpt2" checkpoint, so
    the output is free-form text and is not validated as SQL anywhere.
    """
    response = text_generator(prompt, max_length=100, num_return_sequences=1)
    return response[0]['generated_text'].strip()


def execute_sql_query(sql_query, data):
    """Load the data set the query is meant to run against.

    Args:
        sql_query: The generated query text. It is currently NOT executed;
            the parameter is kept so callers do not need to change.
        data: Where the rows come from. May be a ``pandas.DataFrame`` (returned
            as-is), a path or file-like object readable by ``pandas.read_csv``
            (e.g. the object returned by ``st.file_uploader``), or ``None`` to
            read ``DEFAULT_CSV_PATH``.

    Returns:
        A ``pandas.DataFrame`` with the loaded rows.
    """
    if data is None:
        return pd.read_csv(DEFAULT_CSV_PATH)
    if isinstance(data, pd.DataFrame):
        return data
    # The same upload object can be handed in again on a later rerun with its
    # read position at the end; rewind so read_csv sees the full content.
    if hasattr(data, "seek"):
        data.seek(0)
    return pd.read_csv(data)


def analyze_and_visualize(data, prompt):
    """Draw a line, bar, or scatter chart of *data* on a new pyplot figure.

    Chart selection:
      * line    - the prompt contains "trend" (case-insensitive) or *data* has
                  a column named ``time``;
      * bar     - otherwise, if *data* has at least one ``object`` column;
      * scatter - otherwise.

    Args:
        data: ``pandas.DataFrame`` to plot.
        prompt: The user's query text, used only for the "trend" check.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the chart.
    """
    categorical_columns = data.select_dtypes(include=['object']).columns
    numerical_columns = data.select_dtypes(include=['number']).columns

    if 'trend' in prompt.lower() or 'time' in data.columns:
        vis_type = 'line'
    elif len(categorical_columns) > 0:
        vis_type = 'bar'
    else:
        vis_type = 'scatter'

    plt.figure(figsize=(10, 6))

    if vis_type == 'line':
        sns.lineplot(data=data)
        plt.title('Trend Analysis')
    elif vis_type == 'bar':
        categorical_column = categorical_columns[0]
        if len(numerical_columns) > 0:
            sns.barplot(x=categorical_column, y=numerical_columns[0], data=data)
        else:
            # No numeric column to aggregate: indexing columns[0] would raise
            # IndexError, so show category frequencies instead.
            sns.countplot(x=categorical_column, data=data)
        plt.title('Categorical Data Analysis')
    else:
        sns.scatterplot(data=data)
        plt.title('Scatter Plot')

    # NOTE(review): these placeholder labels replace whatever axis labels the
    # plotting call set; kept unchanged from the original behavior.
    plt.xlabel('X-axis Label')
    plt.ylabel('Y-axis Label')
    plt.xticks(rotation=45)
    plt.tight_layout()
    return plt  # Return the pyplot module instead of showing the figure


# Streamlit app
def main():
    """Render the UI: collect a prompt and a CSV, then show query, data, chart."""
    st.title("Conversational BI Bot")
    prompt = st.text_input("Enter your query:", "")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    if not st.button("Generate and Visualize"):
        return

    # Tell the user why nothing happens instead of silently ignoring the click.
    if uploaded_file is None:
        st.warning("Please upload a CSV file first.")
        return
    if not prompt:
        st.warning("Please enter a query first.")
        return

    with st.spinner('Generating SQL query...'):
        sql_query = generate_sql_query(prompt)
    st.write(f"Generated SQL Query: `{sql_query}`")

    with st.spinner('Executing SQL query...'):
        data = execute_sql_query(sql_query, uploaded_file)
    st.write("Query Results:", data)

    if data.empty:
        st.warning("The uploaded CSV contains no rows to visualize.")
        return

    with st.spinner('Generating Visualization...'):
        plot = analyze_and_visualize(data, prompt)
        figure = plot.gcf()
        st.pyplot(figure)  # Display the plot using st.pyplot()
        # Figures are otherwise kept alive by pyplot and pile up across reruns.
        plt.close(figure)


if __name__ == "__main__":
    main()