File size: 2,645 Bytes
a190708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import streamlit as st
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Load the model and tokenizer
@st.cache_resource
def load_model():
    """Build the GPT-2 text-generation pipeline.

    Streamlit re-executes the script on every widget interaction, and this
    function is called at module level, so the result is cached with
    ``st.cache_resource`` to avoid re-loading the GPT-2 weights on each rerun.

    Returns:
        The object returned by ``pipeline('text-generation', ...)`` built
        from the pretrained ``"gpt2"`` tokenizer and language-model head.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    return pipeline('text-generation', model=model, tokenizer=tokenizer)

# Module-level pipeline instance, created when the script runs;
# generate_sql_query() calls it through this global name.
text_generator = load_model()

def generate_sql_query(prompt):
    """Generate text for *prompt* using the module-level ``text_generator``.

    Args:
        prompt: Text handed to the text-generation pipeline.

    Returns:
        The ``'generated_text'`` field of the single requested sequence,
        with leading and trailing whitespace removed.
    """
    sequences = text_generator(prompt, num_return_sequences=1, max_length=100)
    first_sequence = sequences[0]
    generated_text = first_sequence['generated_text']
    return generated_text.strip()

def execute_sql_query(sql_query, data):
    """Load the table for a generated query from the supplied data source.

    The query text itself is not executed; the function only materialises
    the data source as a pandas DataFrame.

    Args:
        sql_query: Generated query text. Currently unused; kept so the
            signature callers rely on does not change.
        data: Where the rows come from. A ``pandas.DataFrame`` is returned
            unchanged; a path or file-like object (for example an uploaded
            CSV) is parsed with ``pd.read_csv``; ``None`` falls back to the
            ``'Movies.csv'`` file in the working directory.

    Returns:
        A ``pandas.DataFrame`` with the loaded rows.
    """
    if isinstance(data, pd.DataFrame):
        return data
    if data is None:
        # Backward-compatible default: the file that used to be hard-coded.
        return pd.read_csv('Movies.csv')
    if hasattr(data, 'seek'):
        # A file-like object may already have been read once; rewind so the
        # parser sees the whole content instead of an empty stream.
        data.seek(0)
    return pd.read_csv(data)

def analyze_and_visualize(data, prompt):
    """Pick a chart type for *data* and draw it on a new matplotlib figure.

    Chart selection, in priority order:
      * line    -- the prompt contains "trend" (case-insensitive) or the
                   data has a column named ``time``;
      * bar     -- the data has at least one ``object``-dtype column;
      * scatter -- everything else.

    Args:
        data: pandas DataFrame to plot.
        prompt: The user's query text; ``None`` is treated as empty.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the
        chart. The figure is not shown here.
    """
    prompt_text = (prompt or '').lower()
    categorical_columns = data.select_dtypes(include=['object']).columns
    numerical_columns = data.select_dtypes(include=['number']).columns

    if 'trend' in prompt_text or 'time' in data.columns:
        vis_type = 'line'
    elif len(categorical_columns) > 0:
        vis_type = 'bar'
    else:
        vis_type = 'scatter'

    plt.figure(figsize=(10, 6))
    if vis_type == 'line':
        sns.lineplot(data=data)
        plt.title('Trend Analysis')
    elif vis_type == 'bar':
        categorical_column = categorical_columns[0]
        if len(numerical_columns) > 0:
            sns.barplot(x=categorical_column, y=numerical_columns[0], data=data)
        else:
            # No numeric column to put on the y-axis: plot how often each
            # category occurs instead of failing on an empty column index.
            sns.countplot(x=categorical_column, data=data)
        plt.title('Categorical Data Analysis')
    else:
        sns.scatterplot(data=data)
        plt.title('Scatter Plot')

    # NOTE: these generic labels are applied after plotting, so they replace
    # whatever axis labels the seaborn call set.
    plt.xlabel('X-axis Label')
    plt.ylabel('Y-axis Label')
    plt.xticks(rotation=45)
    plt.tight_layout()
    return plt  # Return the plot object instead of showing it

# Streamlit app
def main():
    """Render the Streamlit page and run the query-to-chart flow.

    The page shows a text input for the query, a CSV uploader and a button.
    When the button is pressed with both inputs present, the prompt is passed
    to ``generate_sql_query``, the result and the uploaded file to
    ``execute_sql_query``, and the resulting data to
    ``analyze_and_visualize``; each intermediate result is written to the page.
    If the file or the query is missing, a warning is shown instead of
    silently doing nothing.
    """
    st.title("Conversational BI Bot")
    prompt = st.text_input("Enter your query:", "")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    if not st.button("Generate and Visualize"):
        return
    if uploaded_file is None:
        st.warning("Please upload a CSV file before generating.")
        return
    if not prompt:
        st.warning("Please enter a query before generating.")
        return

    # One spinner per step, so the status message always matches the work
    # that is actually running.
    with st.spinner('Generating SQL query...'):
        sql_query = generate_sql_query(prompt)
    st.write(f"Generated SQL Query: `{sql_query}`")

    with st.spinner('Executing SQL query...'):
        data = execute_sql_query(sql_query, uploaded_file)
    st.write("Query Results:", data)

    with st.spinner('Generating Visualization...'):
        plot = analyze_and_visualize(data, prompt)
    st.pyplot(plot)  # Display the plot using st.pyplot()

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()