|
import streamlit as st |
|
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
|
|
|
|
@st.cache_resource
def load_model(model_name="gpt2"):
    """Build a text-generation pipeline for *model_name*, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction.
    Without ``st.cache_resource`` the tokenizer and model weights would be
    reloaded from disk on every click. The cache is keyed on *model_name*.

    Args:
        model_name: Model id or local path accepted by
            ``GPT2Tokenizer.from_pretrained`` / ``GPT2LMHeadModel.from_pretrained``.
            Defaults to ``"gpt2"``.

    Returns:
        A ``transformers`` ``'text-generation'`` pipeline.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    return pipeline('text-generation', model=model, tokenizer=tokenizer)


# Shared generator used by generate_sql_query(); cheap on reruns because
# load_model() is cached.
text_generator = load_model()
|
|
|
def generate_sql_query(prompt, max_new_tokens=100, generator=None):
    """Ask the language model to continue *prompt* and return only the new text.

    NOTE: the default model is base GPT-2, a general language model. Its output
    is not guaranteed to be valid SQL, so callers must validate it before use.

    Args:
        prompt: The user's natural-language request. If it is blank or None,
            the model is not called.
        max_new_tokens: Maximum number of tokens to generate *after* the prompt.
            The original code used ``max_length``, which also counts prompt
            tokens, so long prompts left no room for generation.
        generator: Text-generation callable with the ``transformers`` pipeline
            call signature. Defaults to the module-level ``text_generator``.

    Returns:
        The generated continuation, stripped of surrounding whitespace, or
        ``""`` for a blank prompt.
    """
    if not prompt or not prompt.strip():
        return ""

    gen = text_generator if generator is None else generator
    response = gen(
        prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        # Without this the pipeline echoes the prompt at the start of
        # generated_text, and the user's question would be shown as the "query".
        return_full_text=False,
    )
    return response[0]['generated_text'].strip()
|
|
|
def execute_sql_query(sql_query, data, table_name="data"):
    """Run *sql_query* against the uploaded CSV and return the result.

    The CSV in *data* is loaded into a throw-away in-memory SQLite database
    as table *table_name*, and the query runs there. The original code
    ignored both arguments and read a hard-coded ``Movies.csv`` from disk.

    The query comes from a language model, so it only runs when it starts
    with ``SELECT``/``WITH``. Anything else, or a query SQLite rejects
    (syntax error, unknown column, several statements), falls back to the
    whole table, and a warning is logged.

    Args:
        sql_query: SQL text to run. Empty or None returns the whole table.
        data: Anything ``pd.read_csv`` accepts (e.g. the Streamlit upload
            buffer, a path, a text buffer), or an already-loaded DataFrame.
        table_name: Table name the data is registered under in SQLite.

    Returns:
        A DataFrame with the query result, or the full table on fallback.
    """
    import logging
    import sqlite3

    logger = logging.getLogger(__name__)

    if isinstance(data, pd.DataFrame):
        frame = data
    else:
        # File-like uploads can be read more than once across reruns; rewind so
        # a previous read does not leave us at EOF.
        if hasattr(data, "seek"):
            data.seek(0)
        frame = pd.read_csv(data)

    query = (sql_query or "").strip()
    if not query.lower().startswith(("select", "with")):
        if query:
            logger.warning("Not executing non-SELECT query %r; returning full table", query)
        return frame

    conn = sqlite3.connect(":memory:")
    try:
        frame.to_sql(table_name, conn, index=False)
        cursor = conn.execute(query)
        columns = [column[0] for column in cursor.description or ()]
        return pd.DataFrame(cursor.fetchall(), columns=columns)
    except sqlite3.Error as exc:
        logger.warning("SQL execution failed (%s); returning full table", exc)
        return frame
    finally:
        conn.close()
|
|
|
def analyze_and_visualize(data, prompt):
    """Pick a chart type for *data*, draw it on a new figure and return ``plt``.

    Chart choice:
      - line (``'Trend Analysis'``) when the prompt mentions "trend" or
        *data* has a column named ``time``;
      - bar (``'Categorical Data Analysis'``) when *data* has an object
        (text) column. It plots the first numeric column per category, or
        category counts when there is no numeric column;
      - scatter (``'Scatter Plot'``) otherwise.
    Line and scatter plot every numeric column against the row index.

    Args:
        data: DataFrame to plot.
        prompt: The user's request. Only used for the "trend" keyword check.
            None is treated as empty.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as the current
        figure. This return value is kept for callers that do ``st.pyplot(plot)``.
    """
    # Figures created on earlier Streamlit reruns were never closed and piled
    # up in pyplot's global registry. Close them before drawing a new one.
    plt.close('all')
    fig, ax = plt.subplots(figsize=(10, 6))

    categorical_columns = data.select_dtypes(include=['object']).columns
    numerical_columns = data.select_dtypes(include=['number']).columns

    if 'trend' in (prompt or '').lower() or 'time' in data.columns:
        vis_type = 'line'
    elif len(categorical_columns) > 0:
        vis_type = 'bar'
    else:
        vis_type = 'scatter'

    if data.empty:
        _draw_message(ax, 'No data to plot')
    elif vis_type == 'bar':
        x_column = categorical_columns[0]
        if len(numerical_columns) > 0:
            sns.barplot(x=x_column, y=numerical_columns[0], data=data, ax=ax)
        else:
            # An all-text table used to raise IndexError here. Show how often
            # each category occurs instead.
            sns.countplot(x=x_column, data=data, ax=ax)
        ax.set_title('Categorical Data Analysis')
    elif len(numerical_columns) == 0:
        _draw_message(ax, 'No numeric columns to plot')
    else:
        # Wide-form plot: one series per numeric column against the row index.
        plot_fn = sns.lineplot if vis_type == 'line' else sns.scatterplot
        plot_fn(data=data[numerical_columns], ax=ax)
        ax.set_title('Trend Analysis' if vis_type == 'line' else 'Scatter Plot')
        ax.set_xlabel(data.index.name or 'Row')
        ax.set_ylabel('Value')

    # Placeholder axis labels are no longer forced: seaborn's column-name labels
    # (bar/count plots) are kept as-is.
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    return plt


def _draw_message(ax, message):
    """Replace the plot area of *ax* with a centred text *message*."""
    ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
    ax.set_axis_off()
|
|
|
|
|
def main():
    """Streamlit entry point: prompt and CSV in; generated SQL, results and a chart out.

    When "Generate and Visualize" is clicked, the app:
      1. generates SQL from the prompt;
      2. runs it against the uploaded CSV;
      3. shows the result table;
      4. draws a chart chosen by ``analyze_and_visualize``.
    """
    st.title("Conversational BI Bot")

    prompt = st.text_input("Enter your query:", "")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    # st.button must always be rendered. It returns True only on the rerun
    # triggered by the click.
    if not st.button("Generate and Visualize"):
        return
    # Previously a click with missing input silently did nothing.
    if uploaded_file is None:
        st.warning("Please upload a CSV file first.")
        return
    if not prompt.strip():
        st.warning("Please enter a query.")
        return

    with st.spinner('Generating SQL query...'):
        sql_query = generate_sql_query(prompt)
    st.write("Generated SQL Query:")
    # Render model output verbatim. Embedding it in markdown backticks broke
    # whenever the generated text itself contained backticks or newlines.
    st.code(sql_query, language="sql")

    with st.spinner('Executing SQL query...'):
        data = execute_sql_query(sql_query, uploaded_file)
    st.write("Query Results:", data)

    with st.spinner('Generating Visualization...'):
        plot = analyze_and_visualize(data, prompt)
    st.pyplot(plot)


# `streamlit run` executes this file with __name__ == "__main__".
if __name__ == "__main__":
    main()
|
|