Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from st_aggrid import AgGrid | |
import pandas as pd | |
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer | |
# Set the page layout for Streamlit
st.set_page_config(layout="wide")
# CSS styling
# ... (keep your existing CSS code)
# NOTE(review): no CSS is defined in this file, so the `font` class used further
# down currently has no styling.


@st.cache_resource
def _load_models():
    """Load the TAPAS table-QA pipeline and the T5 tokenizer/model.

    Streamlit re-executes the whole script on every widget interaction, so
    loading models at module level re-instantiated the large TAPAS checkpoint
    and T5 on each rerun. ``st.cache_resource`` keeps one shared instance.

    Returns:
        A ``(table_qa_pipeline, t5_tokenizer, t5_model)`` tuple.
    """
    # Initialize TAPAS pipeline
    table_qa = pipeline(task="table-question-answering",
                        model="google/tapas-large-finetuned-wtq",
                        device="cpu")
    # Initialize T5 tokenizer and model for text generation
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    return table_qa, tokenizer, model


# Module-level names used by the question-answering section.
tqa, t5_tokenizer, t5_model = _load_models()
# Title and Introduction
st.title("HERTOG-AI Table Question Answering and Data Analysis App")
# The feature list describes what the answering code actually does: aggregations
# run over the cells TAPAS selects for the question, not over the whole table.
st.markdown("""
This app allows you to upload a table (CSV or Excel) and ask questions about the data.
Based on your question, it finds the relevant cells with the **TAPAS** model, applies some additional data processing, and rephrases the result with a **T5** model.

### Available Features:

Aggregations are applied to the cells TAPAS selects for your question, not to the entire table:

- **Average**: for questions mentioning "average", the mean of the selected numeric values.
- **Sum**: for questions mentioning "sum", the total of the selected numeric values.
- **Max**: for questions mentioning "max", the largest selected numeric value.
- **Min**: for questions mentioning "min", the smallest selected numeric value.
- **Count**: for questions mentioning "count", the number of selected values.

You can upload your data and ask questions like "What is the average of column X?" or "What is the sum of column Y?". The app will automatically process the data and give you the relevant answer.
""")
# File uploader in the sidebar
file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])

# Table handed to TAPAS. It stays None until an upload has been read, so the
# question-answering section can report "no data" instead of raising NameError.
df = None

# File processing and preview
if file_name is None:
    st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
else:
    try:
        # Compare extensions case-insensitively so "DATA.CSV" is treated like "data.csv".
        upload_name = file_name.name.lower()
        if upload_name.endswith('.csv'):
            # NOTE(review): delimiter is fixed to ';' and encoding to ISO-8859-1;
            # comma-separated or UTF-8 files may load incorrectly — confirm input format.
            df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1')
        elif upload_name.endswith('.xlsx'):
            df = pd.read_excel(file_name, engine='openpyxl')  # Use openpyxl to read .xlsx files
        else:
            st.error("Unsupported file type")
            df = None
        if df is not None:
            # Convert text (object) columns to numbers where every entry parses.
            # Columns with any non-numeric entry are left unchanged; this mirrors
            # the deprecated errors='ignore' behaviour of pd.to_numeric.
            object_columns = df.select_dtypes(include=['object']).columns
            for col in object_columns:
                try:
                    df[col] = pd.to_numeric(df[col])
                except (ValueError, TypeError):
                    pass
            # Keep the typed table for display; TAPAS is given an all-string table.
            # Stringify before rendering so a display failure cannot leave `df` untyped.
            df_numeric = df
            df = df_numeric.astype(str)
            st.write("Original Data:")
            st.write(df_numeric)
            # Display the first 5 rows in an editable grid. grid_response is not
            # used afterwards, so edits here do not change `df`.
            grid_response = AgGrid(
                df.head(5),
                columns_auto_size_mode='FIT_CONTENTS',
                editable=True,
                height=300,
                width='100%',
            )
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")
# User input for the question
question = st.text_input('Type your question')


def _mentions(text, *keywords):
    """Return True if any of *keywords* occurs in *text* as a whole word (case-insensitive).

    Whole-word matching avoids false positives that plain substring tests
    produce, e.g. "min" inside "determine" or "count" inside "country".
    """
    import re

    pattern = r"\b(?:" + "|".join(re.escape(word) for word in keywords) + r")\b"
    return re.search(pattern, text, flags=re.IGNORECASE) is not None


def _numeric_values(cells):
    """Return the finite float values among *cells*, skipping non-numeric entries.

    If the table was stringified with ``astype(str)``, missing values arrive as
    the string "nan". ``float("nan")`` succeeds, so NaN and inf are filtered
    explicitly to keep them out of sums, averages and extremes.
    """
    import math

    values = []
    for cell in cells:
        try:
            value = float(cell)
        except (TypeError, ValueError):
            continue  # text cell such as a name or label
        if math.isfinite(value):
            values.append(value)
    return values


def _describe_answer(question, raw_answer, table):
    """Build the plain-English base sentence for a TAPAS result.

    Branches are tried in the order average, sum, max, min, count. Each branch
    fires when the question mentions its keyword or TAPAS reports the matching
    aggregator, and aggregates over the cells TAPAS selected.

    Args:
        question: The user's question text.
        raw_answer: Dict returned by the table-question-answering pipeline
            ('answer', optionally 'aggregator', 'coordinates', 'cells').
        table: The DataFrame passed to the pipeline; used to name columns.

    Returns:
        The sentence that T5 is asked to rephrase.
    """
    answer = raw_answer['answer']
    aggregator = raw_answer.get('aggregator', '')
    coordinates = raw_answer.get('coordinates', [])
    cells = raw_answer.get('cells', [])
    numbers = _numeric_values(cells)

    # TAPAS WTQ checkpoints label aggregations NONE / SUM / AVERAGE / COUNT
    # ('AVERAGE', not 'AVG'). 'MAX' / 'MIN' are kept in case another checkpoint emits them.
    if _mentions(question, 'average', 'averages') or aggregator == 'AVERAGE':
        if numbers:
            avg_value = sum(numbers) / len(numbers)
            return f"The average for '{question}' is {avg_value:.2f}."
        return f"No numeric data found for calculating the average of '{question}'."
    if _mentions(question, 'sum', 'sums') or aggregator == 'SUM':
        if numbers:
            return f"The sum for '{question}' is {sum(numbers):.2f}."
        return f"No numeric data found for calculating the sum of '{question}'."
    if _mentions(question, 'max', 'maximum') or aggregator == 'MAX':
        if numbers:
            return f"The maximum value for '{question}' is {max(numbers):.2f}."
        return f"No numeric data found for finding the maximum value of '{question}'."
    if _mentions(question, 'min', 'minimum') or aggregator == 'MIN':
        if numbers:
            return f"The minimum value for '{question}' is {min(numbers):.2f}."
        return f"No numeric data found for finding the minimum value of '{question}'."
    if _mentions(question, 'count', 'counts') or aggregator == 'COUNT':
        # COUNT questions typically select labels (names, categories), so count
        # every selected cell rather than only the numeric ones.
        return f"The total count of values for '{question}' is {len(cells)}."

    # Other aggregators or no aggregation: report TAPAS's answer and the cells it used.
    base_sentence = f"The answer from TAPAS for '{question}' is {answer}."
    if coordinates and cells:
        rows_info = [f"Row {coordinate[0] + 1}, Column '{table.columns[coordinate[1]]}' with value {cell}"
                     for coordinate, cell in zip(coordinates, cells)]
        rows_description = " and ".join(rows_info)
        base_sentence += f" This includes the following data: {rows_description}."
    return base_sentence


# Process the answer using TAPAS and T5
if st.button('Answer'):
    # The upload section defines `df` only once a file has been read. Look it up
    # defensively so a missing table is reported rather than raising NameError.
    table = globals().get('df')
    if table is None:
        st.warning("Please upload a CSV or Excel file before asking a question.")
    elif not question.strip():
        st.warning("Please type a question first.")
    else:
        with st.spinner():
            try:
                raw_answer = tqa(table=table, query=question, truncation=True)
                st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
                            unsafe_allow_html=True)
                st.success(raw_answer)
                base_sentence = _describe_answer(question, raw_answer, table)
                # Generate a fluent response using the T5 model, rephrasing the base sentence.
                # NOTE(review): t5-small was trained with task prefixes (e.g. "summarize:"),
                # not free-form instructions like this one; output quality may be poor.
                input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
                inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
                summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)
                generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
                # Display the final generated response
                st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
                st.success(generated_text)
            except Exception as e:
                st.warning(f"Error processing question or generating answer: {str(e)}")
                st.warning("Please retype your question and make sure to use the column name and cell value correctly.")