# demo_app / app.py
# Kashif17's picture
# Create app.py
# a190708 verified
import streamlit as st
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Load the model and tokenizer
@st.cache_resource
def load_model():
    """Build the GPT-2 text-generation pipeline used by the app.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the tokenizer and model
    are loaded once and the same pipeline object is returned on later reruns
    instead of being rebuilt each time.

    Returns:
        A ``transformers`` ``'text-generation'`` pipeline built from the
        pretrained ``"gpt2"`` tokenizer and ``GPT2LMHeadModel``.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    return pipeline('text-generation', model=model, tokenizer=tokenizer)
# Module-level pipeline instance; generate_sql_query() reads this global.
text_generator = load_model()
def generate_sql_query(prompt):
    """Generate text for ``prompt`` with the module-level ``text_generator``.

    NOTE(review): ``text_generator`` is built from the stock ``"gpt2"``
    checkpoint, so the result is presumably a free-form continuation of the
    prompt rather than guaranteed SQL -- confirm before relying on it.

    Args:
        prompt: Text passed to the generation pipeline.

    Returns:
        The ``'generated_text'`` field of the first returned sequence, with
        leading and trailing whitespace removed.
    """
    # Ask for exactly one sequence, capped with max_length=100.
    candidates = text_generator(prompt, num_return_sequences=1, max_length=100)
    first_candidate = candidates[0]
    generated_text = first_candidate['generated_text']
    return generated_text.strip()
def execute_sql_query(sql_query, data):
    """Load the table that the generated query is meant to run against.

    No SQL engine is involved: ``sql_query`` is accepted so the call site
    keeps its shape, but it is not executed and the whole table is returned.

    Previously this function ignored ``data`` and always read the hard-coded
    ``'Movies.csv'``, so the file the user uploaded was never used.

    Args:
        sql_query: Generated query text. Currently unused.
        data: The data source. One of:
            * a pandas DataFrame -- returned unchanged;
            * a path or file-like object accepted by ``pandas.read_csv``
              (for example the object returned by ``st.file_uploader``);
            * ``None`` -- falls back to the legacy ``'Movies.csv'`` path.

    Returns:
        A pandas DataFrame holding the contents of the data source.
    """
    if isinstance(data, pd.DataFrame):
        return data

    if data is None:
        # Keep the old default so callers that passed nothing useful still work.
        return pd.read_csv('Movies.csv')

    # A file-like object may already have been consumed by an earlier read;
    # rewind it so a repeated call does not see an empty stream.
    if hasattr(data, 'seek'):
        data.seek(0)
    return pd.read_csv(data)
def analyze_and_visualize(data, prompt):
    """Pick a chart type for ``data`` and draw it on a new matplotlib figure.

    Chart selection, in priority order:
      * line    -- ``prompt`` contains "trend" (case-insensitive) or ``data``
                   has a column named ``'time'``;
      * bar     -- ``data`` has at least one object-dtype column; the first
                   such column is the x axis and the first numeric column the
                   y axis. If there is no numeric column, a count plot of the
                   categorical column is drawn instead (this case used to
                   raise ``IndexError``);
      * scatter -- everything else.

    Args:
        data: pandas DataFrame to plot.
        prompt: The user's query text; only checked for the word "trend".

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the
        chart. The figure is not shown here; the caller renders it.
    """
    categorical_columns = data.select_dtypes(include=['object']).columns

    if 'trend' in prompt.lower() or 'time' in data.columns:
        vis_type = 'line'
    elif len(categorical_columns) > 0:
        vis_type = 'bar'
    else:
        vis_type = 'scatter'

    plt.figure(figsize=(10, 6))
    if vis_type == 'line':
        sns.lineplot(data=data)
        plt.title('Trend Analysis')
    elif vis_type == 'bar':
        categorical_column = categorical_columns[0]
        numerical_columns = data.select_dtypes(include=['number']).columns
        if len(numerical_columns) > 0:
            sns.barplot(x=categorical_column, y=numerical_columns[0], data=data)
        else:
            # Text-only table: there is nothing to aggregate, so show how
            # often each category occurs rather than failing on columns[0].
            sns.countplot(x=categorical_column, data=data)
        plt.title('Categorical Data Analysis')
    else:
        sns.scatterplot(data=data)
        plt.title('Scatter Plot')

    # Shared cosmetics applied to every chart type.
    plt.xlabel('X-axis Label')
    plt.ylabel('Y-axis Label')
    plt.xticks(rotation=45)
    plt.tight_layout()
    return plt  # Return the plot object instead of showing it
# Streamlit app
def main():
    """Render the Streamlit page and run the query -> data -> chart flow.

    The page shows a text input for the query, a CSV uploader and a
    "Generate and Visualize" button. When the button is pressed the app
    generates text from the query, loads the uploaded CSV, displays it and
    draws a chart for it.

    Pressing the button without an uploaded file or with an empty query used
    to do nothing at all; it now shows a warning explaining what is missing.
    """
    st.title("Conversational BI Bot")
    prompt = st.text_input("Enter your query:", "")
    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    # Nothing to do until the user presses the button on this rerun.
    if not st.button("Generate and Visualize"):
        return

    if uploaded_file is None:
        st.warning("Please upload a CSV file first.")
        return
    if not prompt:
        st.warning("Please enter a query.")
        return

    with st.spinner('Generating SQL query...'):
        sql_query = generate_sql_query(prompt)
    st.write(f"Generated SQL Query: `{sql_query}`")

    with st.spinner('Executing SQL query...'):
        data = execute_sql_query(sql_query, uploaded_file)
    st.write("Query Results:", data)

    with st.spinner('Generating Visualization...'):
        plot = analyze_and_visualize(data, prompt)
    st.pyplot(plot)  # Display the plot using st.pyplot()
# Start the app only when this file is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()