from flask import Flask, render_template, request, jsonify, url_for, session, send_file, Response
import torch
from PIL import Image
import pandas as pd
import re
import os
import base64
import json
import traceback
from io import BytesIO
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from werkzeug.utils import secure_filename

app = Flask(__name__)
app.secret_key = os.urandom(24)  # Required for session support

UPLOAD_FOLDER = 'static/uploads/'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}  # Allowed upload extensions

if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)
# Load PaliGemma model and processor (load once)
def load_paligemma_model():
    try:
        print("Loading PaliGemma model from Hugging Face...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        # Load model and processor from the Hugging Face Hub
        model_name = "ahmed-masry/chartgemma"  # Update with the correct model name if needed
        # float16 is only reliable on GPU; fall back to float32 on CPU, where
        # half precision is slow and some ops are unsupported
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
        )
        processor = AutoProcessor.from_pretrained(model_name)
        model = model.to(device)
        print("Model loaded successfully")
        return model, processor, device
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

# Store the model in the app context
with app.app_context():
    app.paligemma_model, app.paligemma_processor, app.device = load_paligemma_model()
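# Note: loading the model once at import time (rather than per request) keeps a
# single copy of the weights in memory that every request handler shares.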
# Helper function to check allowed extensions
def allowed_file(filename):
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
# Clean model output function - improved like the Streamlit version
def clean_model_output(text):
    if not text:
        print("Warning: Empty text passed to clean_model_output")
        return ""
    # If the entire response is a single print statement, extract its content
    print_match = re.search(r'^print\(["\'](.+?)["\']\)$', text.strip())
    if print_match:
        return print_match.group(1)
    # Remove any remaining print statements
    text = re.sub(r'print\(.+?\)', '', text, flags=re.DOTALL)
    # Remove Python code-fence artifacts
    text = re.sub(r'```python|```', '', text)
    return text.strip()
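# Illustrative examples (hypothetical inputs):
#   clean_model_output('print("The answer is 42")')   -> 'The answer is 42'
#   clean_model_output('```python\nresult = 42\n```') -> 'result = 42'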
# Analyze chart function
def analyze_chart_with_paligemma(image, query, use_cot=False):
    try:
        print(f"Starting analysis with query: {query}")
        print(f"Use CoT: {use_cot}")
        model = app.paligemma_model
        processor = app.paligemma_processor
        device = app.device
        # Add the program-of-thought prefix if CoT is enabled (matching the Streamlit version)
        if use_cot and not query.startswith("program of thought:"):
            modified_query = f"program of thought: {query}"
        else:
            modified_query = query
        print(f"Modified query: {modified_query}")
        # Process inputs
        try:
            print("Processing inputs...")
            inputs = processor(text=modified_query, images=image, return_tensors="pt")
            print(f"Input keys: {inputs.keys()}")
            prompt_length = inputs['input_ids'].shape[1]  # Used later to strip the prompt from the output
            inputs = {k: v.to(device) for k, v in inputs.items()}
        except Exception as e:
            print(f"Error processing inputs: {str(e)}")
            traceback.print_exc()
            return f"Error processing inputs: {str(e)}"
        # Generate output
        try:
            print("Generating output...")
            with torch.no_grad():
                generate_ids = model.generate(
                    **inputs,
                    num_beams=4,
                    max_new_tokens=512,
                    output_scores=True,
                    return_dict_in_generate=True
                )
            output_text = processor.batch_decode(
                generate_ids.sequences[:, prompt_length:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]
            print(f"Raw output text: {output_text}")
            cleaned_output = clean_model_output(output_text)
            print(f"Cleaned output text: {cleaned_output}")
            return cleaned_output
        except Exception as e:
            print(f"Error generating output: {str(e)}")
            traceback.print_exc()
            return f"Error generating output: {str(e)}"
    except Exception as e:
        print(f"Error in analyze_chart_with_paligemma: {str(e)}")
        traceback.print_exc()
        return f"Error: {str(e)}"
# Extract data points function - updated to match the Streamlit version
def extract_data_points(image):
    print("Starting data extraction...")
    try:
        # Special query to extract data points - same as the Streamlit version
        extraction_query = "program of thought: Extract all data points from this chart. List each category or series and all its corresponding values in a structured format."
        print(f"Using extraction query: {extraction_query}")
        result = analyze_chart_with_paligemma(image, extraction_query, use_cot=True)
        print(f"Extraction result: {result}")
        # Parse the result into a DataFrame using the improved parser
        df = parse_chart_data(result)
        return df
    except Exception as e:
        print(f"Error extracting data points: {str(e)}")
        traceback.print_exc()
        return pd.DataFrame({'Error': [str(e)]})
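# Note: extraction quality depends entirely on the model's free-text reply;
# parse_chart_data below falls back to returning the raw text whenever it
# cannot recognize any category/value structure.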
# Parse chart data function - completely revamped to match Streamlit's implementation
def parse_chart_data(text):
    try:
        # Clean the text of print statements first
        text = clean_model_output(text)
        print(f"Parsing cleaned text: {text}")
        data = {}
        lines = text.split('\n')
        current_category = None
        # First pass: look for category and value pairs
        for line in lines:
            if not line.strip():
                continue
            if ':' in line and not re.search(r'\d+\.\d+', line):
                current_category = line.split(':')[0].strip()
                data[current_category] = []
            elif current_category and (re.search(r'\d+', line) or ',' in line):
                value_match = re.findall(r'[-+]?\d*\.\d+|\d+', line)
                if value_match:
                    data[current_category].extend(value_match)
        # Second pass: if no categories were found, try alternative pattern matching
        if not data:
            table_pattern = r'(\w+(?:\s\w+)*)\s*[:|]\s*((?:\d+(?:\.\d+)?(?:\s*,\s*\d+(?:\.\d+)?)*)|(?:\d+(?:\.\d+)?))'
            matches = re.findall(table_pattern, text)
            for category, values in matches:
                category = category.strip()
                if category not in data:
                    data[category] = []
                if ',' in values:
                    values = [v.strip() for v in values.split(',')]
                else:
                    values = [values.strip()]
                data[category].extend(values)
        # Convert all values to float where possible
        for key in data:
            data[key] = [float(val) if re.match(r'^[-+]?\d*\.?\d+$', val) else val for val in data[key]]
        # Create the DataFrame; building it via from_dict/transpose pads
        # unequal-length series with NaN instead of raising a ValueError
        if data:
            df = pd.DataFrame.from_dict(data, orient='index').transpose()
            print(f"Successfully parsed data: {df.head()}")
        else:
            df = pd.DataFrame({'Extracted_Text': [text]})
            print("Could not extract structured data, returning raw text")
        return df
    except Exception as e:
        print(f"Error parsing chart data: {str(e)}")
        traceback.print_exc()
        return pd.DataFrame({'Raw_Text': [text]})
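# Illustrative walk-through (hypothetical model output):
#   parse_chart_data("Sales:\n10, 20, 30\nCosts:\n5, 15, 25")
# The first pass treats "Sales:" and "Costs:" as categories and the numeric
# lines as their values, yielding a two-column DataFrame of floats.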
# NOTE: the @app.route decorators were missing from this listing; the paths
# below are assumed and must match whatever the templates and JS expect.
@app.route('/')
def index():
    image_url = session.get('image_url', None)
    return render_template('index.html', image_url=image_url)
@app.route('/upload', methods=['POST'])
def upload_image():
    try:
        if 'image' not in request.files:
            return jsonify({"error": "No file uploaded"}), 400
        file = request.files['image']
        if file.filename == '':
            return jsonify({"error": "No selected file"}), 400
        if not allowed_file(file.filename):
            return jsonify({"error": "Invalid file type"}), 400
        # Sanitize the client-supplied filename before building a filesystem path
        filename = secure_filename(file.filename)
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)
        session['image_url'] = url_for('static', filename=f'uploads/{filename}')
        session['image_filename'] = filename
        print(f"Image uploaded: {filename}")
        return jsonify({"image_url": session['image_url']})
    except Exception as e:
        print(f"Error in upload_image: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/analyze', methods=['POST'])
def analyze_chart():
    try:
        query = request.form['query']
        use_cot = request.form.get('use_cot') == 'true'
        image_filename = session.get('image_filename')
        if not image_filename:
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)
        if not os.path.exists(image_path):
            return jsonify({"error": "Image not found. Please upload again."}), 400
        image = Image.open(image_path).convert('RGB')
        answer = analyze_chart_with_paligemma(image, query, use_cot)
        return jsonify({"answer": answer})
    except Exception as e:
        print(f"Error in analyze_chart: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/extract_data', methods=['POST'])
def extract_data():
    try:
        image_filename = session.get('image_filename')
        if not image_filename:
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)
        if not os.path.exists(image_path):
            return jsonify({"error": "Image not found. Please upload again."}), 400
        image = Image.open(image_path).convert('RGB')
        df = extract_data_points(image)
        # Check if the DataFrame is empty
        if df.empty:
            return jsonify({"error": "Could not extract data from the image"}), 400
        # Convert the DataFrame to CSV
        csv_data = df.to_csv(index=False)
        print(f"CSV data generated: {csv_data[:100]}...")  # First 100 chars
        # Encode the CSV as base64 so it can be embedded in the JSON response
        csv_base64 = base64.b64encode(csv_data.encode()).decode('utf-8')
        return jsonify({"csv_data": csv_base64})
    except Exception as e:
        print(f"Error in extract_data: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/download_csv')
def download_csv():
    try:
        print("Download CSV route called")
        image_filename = session.get('image_filename')
        if not image_filename:
            print("No image in session")
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)
        print(f"Looking for image at: {image_path}")
        if not os.path.exists(image_path):
            print("Image file not found")
            return jsonify({"error": "Image not found. Please upload again."}), 400
        print("Loading image")
        image = Image.open(image_path).convert('RGB')
        print("Extracting data points")
        df = extract_data_points(image)
        print(f"DataFrame: {df}")
        # Hold the CSV data in memory
        csv_buffer = BytesIO()
        df.to_csv(csv_buffer, index=False, encoding='utf-8')
        # Debug: print CSV content
        csv_content = csv_buffer.getvalue().decode('utf-8')
        print(f"CSV Content: {csv_content}")
        print("Preparing response")
        # Build the response directly from the CSV bytes
        response = Response(
            csv_buffer.getvalue(),
            mimetype='text/csv',
            headers={'Content-Disposition': 'attachment; filename=extracted_data.csv'}
        )
        print("Returning CSV response")
        return response
    except Exception as e:
        print(f"Error in download_csv: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
# Utility function to match the Streamlit version
def get_csv_download_link(df, filename="chart_data.csv"):
    csv = df.to_csv(index=False)
    b64 = base64.b64encode(csv.encode()).decode()
    href = f'<a href="data:file/csv;base64,{b64}" download="{filename}">Download CSV File</a>'
    return href
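# Illustrative usage (hypothetical, e.g. rendered into a template with |safe):
#   link_html = get_csv_download_link(df, filename="chart_data.csv")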
if __name__ == '__main__':
    app.run(debug=True)