# NOTE(review): the lines below are page-scrape residue (Hugging Face Spaces
# page header, commit hashes, and a line-number gutter), not source code.
# Commented out so the module parses; safe to delete.
# Spaces: / Runtime error / File size: 13,486 Bytes
# e2a80d5 d0b5bfb ... (commit hash gutter)
# 1 2 3 ... 366 (line-number gutter)
from flask import Flask, render_template, request, jsonify, url_for, session, send_file, Response
import torch
from PIL import Image
import pandas as pd
import re
import os
import base64
import json
import traceback
from io import BytesIO
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
# Flask application setup.
app = Flask(__name__)
# Random per-process secret: required for sessions, but note all sessions
# are invalidated whenever the server restarts.
app.secret_key = os.urandom(24)
UPLOAD_FOLDER = 'static/uploads/'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}  # raster image uploads only
# exist_ok=True avoids the check-then-create race of the previous
# `if not os.path.exists(...): os.makedirs(...)` pair.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Load PaliGemma model and processor (load once)
import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
def load_paligemma_model(model_name="ahmed-masry/chartgemma"):
    """Load the ChartGemma (PaliGemma) model and processor from the Hugging Face Hub.

    Args:
        model_name: Hub repository id to load. Defaults to the ChartGemma
            checkpoint this app was built around; parameterized so other
            PaliGemma-family checkpoints can be swapped in.

    Returns:
        Tuple of (model, processor, device). The model is moved to CUDA when
        available, otherwise CPU.

    Raises:
        Exception: any download/initialisation failure is logged and
            re-raised so startup aborts loudly instead of running modelless.
    """
    try:
        print("Loading PaliGemma model from Hugging Face...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        # float16 halves memory; intended for GPU inference but still loads on CPU.
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
        )
        processor = AutoProcessor.from_pretrained(model_name)
        model = model.to(device)
        print("Model loaded successfully")
        return model, processor, device
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
# Load the model exactly once at startup and stash it on the application
# object so request handlers can reach it without re-loading.
with app.app_context():
    (
        app.paligemma_model,
        app.paligemma_processor,
        app.device,
    ) = load_paligemma_model()
# Helper: gatekeeper for upload file types.
def allowed_file(filename):
    """Return True when *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
# Strip code-generation artifacts (print wrappers, markdown fences) from model text.
def clean_model_output(text):
    """Normalise raw model output into plain answer text.

    Handles three artifact patterns: a response that is one big
    ``print("...")`` call (unwrapped to its argument), embedded print
    statements (removed), and markdown code fences (removed).
    Empty/None input yields an empty string.
    """
    if not text:
        print("Warning: Empty text passed to clean_model_output")
        return ""
    # Whole response is a single print("...") — return just its argument.
    wrapped = re.search(r'^print\(["\'](.+?)["\']\)$', text.strip())
    if wrapped:
        return wrapped.group(1)
    # Otherwise drop every embedded print statement...
    without_prints = re.sub(r'print\(.+?\)', '', text, flags=re.DOTALL)
    # ...and any ```python / ``` fence markers.
    without_fences = re.sub(r'```python|```', '', without_prints)
    return without_fences.strip()
# Analyze chart function
def analyze_chart_with_paligemma(image, query, use_cot=False):
    """Run a chart question through the ChartGemma model and return the answer.

    Args:
        image: PIL image of the chart.
        query: natural-language question about the chart.
        use_cot: when True, prefix the query with "program of thought:" to
            trigger the model's chain-of-thought mode.

    Returns:
        The cleaned answer string, or an "Error..." string on failure
        (errors are reported in-band rather than raised).
    """
    try:
        print(f"Starting analysis with query: {query}")
        print(f"Use CoT: {use_cot}")
        model = app.paligemma_model
        processor = app.paligemma_processor
        device = app.device

        # Add the program-of-thought prefix unless it is already present.
        modified_query = query
        if use_cot and not query.startswith("program of thought:"):
            modified_query = f"program of thought: {query}"
        print(f"Modified query: {modified_query}")

        # Tokenise prompt + image and move tensors to the model's device.
        try:
            print("Processing inputs...")
            inputs = processor(text=modified_query, images=image, return_tensors="pt")
            print(f"Input keys: {inputs.keys()}")
            # Remember the prompt length so it can be sliced off the generation.
            prompt_length = inputs['input_ids'].shape[1]
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        except Exception as e:
            print(f"Error processing inputs: {str(e)}")
            traceback.print_exc()
            return f"Error processing inputs: {str(e)}"

        # Generate and decode the answer.
        try:
            print("Generating output...")
            with torch.no_grad():
                generation = model.generate(
                    **inputs,
                    num_beams=4,
                    max_new_tokens=512,
                    output_scores=True,
                    return_dict_in_generate=True,
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            output_text = processor.batch_decode(
                generation.sequences[:, prompt_length:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            print(f"Raw output text: {output_text}")
            cleaned_output = clean_model_output(output_text)
            print(f"Cleaned output text: {cleaned_output}")
            return cleaned_output
        except Exception as e:
            print(f"Error generating output: {str(e)}")
            traceback.print_exc()
            return f"Error generating output: {str(e)}"
    except Exception as e:
        print(f"Error in analyze_chart_with_paligemma: {str(e)}")
        traceback.print_exc()
        return f"Error: {str(e)}"
# Extract data points function - updated to match Streamlit version
def extract_data_points(image):
    """Ask the model to enumerate every data point in *image* and return a DataFrame.

    On failure a one-column DataFrame named 'Error' carrying the exception
    message is returned instead of raising.
    """
    print("Starting data extraction...")
    try:
        # Fixed program-of-thought prompt shared with the Streamlit version.
        extraction_query = (
            "program of thought: Extract all data points from this chart. "
            "List each category or series and all its corresponding values "
            "in a structured format."
        )
        print(f"Using extraction query: {extraction_query}")
        model_answer = analyze_chart_with_paligemma(image, extraction_query, use_cot=True)
        print(f"Extraction result: {model_answer}")
        # Turn the free-text answer into tabular data.
        return parse_chart_data(model_answer)
    except Exception as e:
        print(f"Error extracting data points: {str(e)}")
        traceback.print_exc()
        return pd.DataFrame({'Error': [str(e)]})
# Parse chart data function - completely revamped to match Streamlit's implementation
def parse_chart_data(text):
    """Parse the model's free-text extraction output into a pandas DataFrame.

    Two heuristics are tried in order: (1) "Category:" header lines followed
    by lines of numbers, (2) "name: v1, v2" table rows. Numeric strings are
    converted to float. Falls back to a single-cell DataFrame holding the raw
    text when nothing structured is found, or a 'Raw_Text' frame on error.
    """
    try:
        # Clean the text from print statements first
        text = clean_model_output(text)
        print(f"Parsing cleaned text: {text}")
        data = {}
        lines = text.split('\n')
        current_category = None
        # First pass: look for category headers and collect the numbers below them.
        for line in lines:
            if not line.strip():
                continue
            if ':' in line and not re.search(r'\d+\.\d+', line):
                current_category = line.split(':')[0].strip()
                data[current_category] = []
            elif current_category and (re.search(r'\d+', line) or ',' in line):
                value_match = re.findall(r'[-+]?\d*\.\d+|\d+', line)
                if value_match:
                    data[current_category].extend(value_match)
        # Second pass: if no categories found, try "name: v1, v2" table rows.
        if not data:
            table_pattern = r'(\w+(?:\s\w+)*)\s*[:|]\s*((?:\d+(?:\.\d+)?(?:\s*,\s*\d+(?:\.\d+)?)*)|(?:\d+(?:\.\d+)?))'
            matches = re.findall(table_pattern, text)
            for category, values in matches:
                category = category.strip()
                if category not in data:
                    data[category] = []
                if ',' in values:
                    values = [v.strip() for v in values.split(',')]
                else:
                    values = [values.strip()]
                data[category].extend(values)
        # Convert all values to float where possible.
        for key in data:
            data[key] = [float(val) if re.match(r'^[-+]?\d*\.?\d+$', val) else val
                        for val in data[key]]
        if data:
            # BUG FIX: categories with unequal value counts previously made
            # pd.DataFrame raise ValueError ("arrays must all be same length"),
            # and the except below then discarded ALL parsed data. Pad shorter
            # columns with None so ragged extractions survive.
            width = max(len(vals) for vals in data.values())
            for key in data:
                data[key] = data[key] + [None] * (width - len(data[key]))
            df = pd.DataFrame(data)
            print(f"Successfully parsed data: {df.head()}")
        else:
            df = pd.DataFrame({'Extracted_Text': [text]})
            print("Could not extract structured data, returning raw text")
        return df
    except Exception as e:
        print(f"Error parsing chart data: {str(e)}")
        traceback.print_exc()
        return pd.DataFrame({'Raw_Text': [text]})
@app.route('/')
def index():
    """Render the landing page, restoring any previously uploaded image URL."""
    previously_uploaded = session.get('image_url')
    return render_template('index.html', image_url=previously_uploaded)
@app.route('/upload', methods=['POST'])
def upload_image():
    """Accept an uploaded chart image, save it, and remember it in the session.

    Returns:
        JSON with the static URL of the saved image on success; an error
        payload with a 400 (bad request) or 500 (unexpected failure) status
        otherwise.
    """
    try:
        if 'image' not in request.files:
            return jsonify({"error": "No file uploaded"}), 400
        file = request.files['image']
        if file.filename == '':
            return jsonify({"error": "No selected file"}), 400
        if not allowed_file(file.filename):
            return jsonify({"error": "Invalid file type"}), 400
        # SECURITY FIX: strip directory components so names like
        # "../../evil.png" cannot escape the upload folder.
        filename = os.path.basename(file.filename)
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)
        # BUG FIX: the URL and log line previously contained the literal text
        # "(unknown)" where the filename should have been interpolated.
        session['image_url'] = url_for('static', filename=f'uploads/{filename}')
        session['image_filename'] = filename
        print(f"Image uploaded: {filename}")
        return jsonify({"image_url": session['image_url']})
    except Exception as e:
        print(f"Error in upload_image: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/analyze', methods=['POST'])
def analyze_chart():
    """Answer a user question about the previously uploaded chart image.

    Expects form fields 'query' and optionally 'use_cot' ("true" enables
    program-of-thought prompting). Returns JSON {"answer": ...} or an error
    payload.
    """
    try:
        query = request.form['query']
        use_cot = request.form.get('use_cot') == 'true'
        image_filename = session.get('image_filename')
        if not image_filename:
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)
        if not os.path.exists(image_path):
            return jsonify({"error": "Image not found. Please upload again."}), 400
        image = Image.open(image_path).convert('RGB')
        answer = analyze_chart_with_paligemma(image, query, use_cot)
        return jsonify({"answer": answer})
    except Exception as e:
        print(f"Error in analyze_chart: {str(e)}")
        traceback.print_exc()
        # CONSISTENCY FIX: return 500 like the other handlers
        # (this previously returned an implicit 200 on error).
        return jsonify({"error": str(e)}), 500
@app.route('/extract', methods=['POST'])
def extract_data():
    """Extract the chart's underlying data table and return it as base64 CSV.

    Returns:
        JSON {"csv_data": <base64 CSV>} on success; an error payload with a
        400 or 500 status otherwise.
    """
    try:
        image_filename = session.get('image_filename')
        if not image_filename:
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)
        if not os.path.exists(image_path):
            return jsonify({"error": "Image not found. Please upload again."}), 400
        image = Image.open(image_path).convert('RGB')
        df = extract_data_points(image)
        # An empty frame means the model produced nothing parseable.
        if df.empty:
            return jsonify({"error": "Could not extract data from the image"}), 400
        csv_data = df.to_csv(index=False)
        print(f"CSV data generated: {csv_data[:100]}...")  # Print first 100 chars
        # Base64-encode so the CSV travels safely inside the JSON payload.
        csv_base64 = base64.b64encode(csv_data.encode()).decode('utf-8')
        return jsonify({"csv_data": csv_base64})
    except Exception as e:
        print(f"Error in extract_data: {str(e)}")
        traceback.print_exc()
        # CONSISTENCY FIX: return 500 like the other handlers
        # (this previously returned an implicit 200 on error).
        return jsonify({"error": str(e)}), 500
@app.route('/download_csv')
def download_csv():
    """Re-run data extraction on the session image and stream the result as a CSV download."""
    try:
        print("Download CSV route called")
        image_filename = session.get('image_filename')
        if not image_filename:
            print("No image in session")
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)
        print(f"Looking for image at: {image_path}")
        if not os.path.exists(image_path):
            print("Image file not found")
            return jsonify({"error": "Image not found. Please upload again."}), 400
        print("Loading image")
        chart_image = Image.open(image_path).convert('RGB')
        print("Extracting data points")
        extracted = extract_data_points(chart_image)
        print(f"DataFrame: {extracted}")
        # Serialise the DataFrame to CSV entirely in memory.
        buffer = BytesIO()
        extracted.to_csv(buffer, index=False, encoding='utf-8')
        buffer.seek(0)
        # Debug: dump the CSV content, then rewind again.
        csv_text = buffer.getvalue().decode('utf-8')
        print(f"CSV Content: {csv_text}")
        buffer.seek(0)
        print("Preparing response")
        # Attachment headers trigger a file download in the browser.
        download_headers = {
            'Content-Disposition': 'attachment; filename=extracted_data.csv',
            'Content-Type': 'text/csv',
        }
        response = Response(
            buffer.getvalue(),
            mimetype='text/csv',
            headers=download_headers,
        )
        print("Returning CSV response")
        return response
    except Exception as e:
        print(f"Error in download_csv: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
# Create a utility function to match the Streamlit version
def get_csv_download_link(df, filename="chart_data.csv"):
    """Return an HTML anchor that downloads *df* as an inline base64 CSV data-URI.

    Args:
        df: DataFrame to serialise.
        filename: value of the anchor's ``download`` attribute.
    """
    csv = df.to_csv(index=False)
    b64 = base64.b64encode(csv.encode()).decode()
    # BUG FIX: the download attribute previously contained the literal text
    # "(unknown)" where the filename should have been interpolated.
    href = f'<a href="data:file/csv;base64,{b64}" download="{filename}">Download CSV File</a>'
    return href
if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger and
    # auto-reloader — fine for local development, but it must be disabled for
    # any real deployment (it allows arbitrary code execution).
    app.run(debug=True)