from flask import Flask, render_template, request, jsonify, url_for, session, Response
import torch
from PIL import Image
import pandas as pd
import re
import os
import base64
import traceback
from io import BytesIO
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from werkzeug.utils import secure_filename

app = Flask(__name__)
app.secret_key = os.urandom(24)  # Required for session
UPLOAD_FOLDER = 'static/uploads/'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}  # Permitted image upload extensions

os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Load PaliGemma model and processor (load once at startup)

def load_paligemma_model():
    try:
        print("Loading PaliGemma model from Hugging Face...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Load model and processor from the Hugging Face Hub. Use float16 on
        # GPU; fall back to float32 on CPU, where half precision is poorly
        # supported and can fail during generation.
        model_name = "ahmed-masry/chartgemma"  # Update with the correct model name if needed
        dtype = torch.float16 if device.type == "cuda" else torch.float32
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=dtype
        )
        processor = AutoProcessor.from_pretrained(model_name)
        model = model.to(device)

        print("Model loaded successfully")
        return model, processor, device
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise


# Store the model in the app context
with app.app_context():
    app.paligemma_model, app.paligemma_processor, app.device = load_paligemma_model()
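
# Note: from_pretrained downloads the weights on first run and caches them,
# so initial startup can be slow; later requests reuse the in-memory model.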

# Helper function to check allowed extensions
def allowed_file(filename):
    return '.' in filename and \
           filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

# Clean model output function - mirrors the improved Streamlit version
def clean_model_output(text):
    if not text:
        print("Warning: Empty text passed to clean_model_output")
        return ""
    
    # Check if the entire response is a print statement and extract its content
    print_match = re.search(r'^print\(["\'](.+?)["\']\)$', text.strip())
    if print_match:
        return print_match.group(1)
    
    # Remove all print statements
    text = re.sub(r'print\(.+?\)', '', text, flags=re.DOTALL)
    
    # Remove Python code formatting artifacts
    text = re.sub(r'```python|```', '', text)
    
    return text.strip()
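
# For example (illustrative): clean_model_output('print("The max is 42")')
# returns "The max is 42"; embedded print(...) calls and ``` code fences are
# stripped from longer responses.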

# Analyze chart function
def analyze_chart_with_paligemma(image, query, use_cot=False):
    try:
        print(f"Starting analysis with query: {query}")
        print(f"Use CoT: {use_cot}")
        
        model = app.paligemma_model
        processor = app.paligemma_processor
        device = app.device
        
        # Add program of thought prefix if CoT is enabled (matching Streamlit version)
        if use_cot and not query.startswith("program of thought:"):
            modified_query = f"program of thought: {query}"
        else:
            modified_query = query
            
        print(f"Modified query: {modified_query}")
        
        # Process inputs
        try:
            print("Processing inputs...")
            inputs = processor(text=modified_query, images=image, return_tensors="pt")
            print(f"Input keys: {inputs.keys()}")
            prompt_length = inputs['input_ids'].shape[1]  # Store prompt length for later use
            inputs = {k: v.to(device) for k, v in inputs.items()}
        except Exception as e:
            print(f"Error processing inputs: {str(e)}")
            traceback.print_exc()
            return f"Error processing inputs: {str(e)}"

        # Generate output
        try:
            print("Generating output...")
            with torch.no_grad():
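                # Beam search (num_beams=4) trades speed for output quality;
                # max_new_tokens caps the generated answer at 512 tokens.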
                generate_ids = model.generate(
                    **inputs,
                    num_beams=4,
                    max_new_tokens=512,
                    output_scores=True,
                    return_dict_in_generate=True
                )
            
            output_text = processor.batch_decode(
                generate_ids.sequences[:, prompt_length:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]
            
            print(f"Raw output text: {output_text}")
            cleaned_output = clean_model_output(output_text)
            print(f"Cleaned output text: {cleaned_output}")
            return cleaned_output
        except Exception as e:
            print(f"Error generating output: {str(e)}")
            traceback.print_exc()
            return f"Error generating output: {str(e)}"
            
    except Exception as e:
        print(f"Error in analyze_chart_with_paligemma: {str(e)}")
        traceback.print_exc()
        return f"Error: {str(e)}"

# Extract data points function - updated to match Streamlit version
def extract_data_points(image):
    print("Starting data extraction...")
    try:
        # Special query to extract data points - same as Streamlit
        extraction_query = "program of thought: Extract all data points from this chart. List each category or series and all its corresponding values in a structured format."
        
        print(f"Using extraction query: {extraction_query}")
        result = analyze_chart_with_paligemma(image, extraction_query, use_cot=True)
        print(f"Extraction result: {result}")
        
        # Parse the result into a DataFrame using the improved parser
        df = parse_chart_data(result)
        return df
    except Exception as e:
        print(f"Error extracting data points: {str(e)}")
        traceback.print_exc()
        return pd.DataFrame({'Error': [str(e)]})

# Parse chart data function - completely revamped to match Streamlit's implementation
def parse_chart_data(text):
    try:
        # Clean the text from print statements first
        text = clean_model_output(text)
        print(f"Parsing cleaned text: {text}")

        data = {}
        lines = text.split('\n')
        current_category = None

        # First pass: Look for category and value pairs
        for line in lines:
            if not line.strip():
                continue

            if ':' in line and not re.search(r'\d+\.\d+', line):
                current_category = line.split(':')[0].strip()
                data[current_category] = []
            elif current_category and (re.search(r'\d+', line) or ',' in line):
                value_match = re.findall(r'[-+]?\d*\.\d+|\d+', line)
                if value_match:
                    data[current_category].extend(value_match)

        # Second pass: If no categories found, try alternative pattern matching
        if not data:
            table_pattern = r'(\w+(?:\s\w+)*)\s*[:|]\s*((?:\d+(?:\.\d+)?(?:\s*,\s*\d+(?:\.\d+)?)*)|(?:\d+(?:\.\d+)?))'
            matches = re.findall(table_pattern, text)
            for category, values in matches:
                category = category.strip()
                if category not in data:
                    data[category] = []
                if ',' in values:
                    values = [v.strip() for v in values.split(',')]
                else:
                    values = [values.strip()]
                data[category].extend(values)

        # Convert all values to float where possible
        for key in data:
            data[key] = [float(val) if re.match(r'^[-+]?\d*\.?\d+$', val) else val for val in data[key]]

        # Create DataFrame; wrapping each list in a Series pads unequal
        # column lengths with NaN instead of raising a length-mismatch error
        if data:
            df = pd.DataFrame({k: pd.Series(v) for k, v in data.items()})
            print(f"Successfully parsed data: {df.head()}")
        else:
            df = pd.DataFrame({'Extracted_Text': [text]})
            print("Could not extract structured data, returning raw text")

        return df
    except Exception as e:
        print(f"Error parsing chart data: {str(e)}")
        traceback.print_exc()
        return pd.DataFrame({'Raw_Text': [text]})
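
# Example (illustrative): model output such as "Sales:\n10, 20, 30" parses
# into a one-column DataFrame with Sales = [10.0, 20.0, 30.0].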

@app.route('/')
def index():
    image_url = session.get('image_url', None)
    return render_template('index.html', image_url=image_url)

@app.route('/upload', methods=['POST'])
def upload_image():
    try:
        if 'image' not in request.files:
            return jsonify({"error": "No file uploaded"}), 400

        file = request.files['image']
        if file.filename == '':
            return jsonify({"error": "No selected file"}), 400

        if not allowed_file(file.filename):
            return jsonify({"error": "Invalid file type"}), 400

        # Sanitize the filename to guard against path traversal
        filename = secure_filename(file.filename)
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)

        session['image_url'] = url_for('static', filename=f'uploads/{filename}')
        session['image_filename'] = filename
        print(f"Image uploaded: {filename}")

        return jsonify({"image_url": session['image_url']})

    except Exception as e:
        print(f"Error in upload_image: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
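
# Example client usage (illustrative; host and port depend on your setup).
# The session cookie set by /upload must be replayed on later requests:
#   curl -c cookies.txt -F "image=@chart.png" http://127.0.0.1:5000/upload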

@app.route('/analyze', methods=['POST'])
def analyze_chart():
    try:
        query = request.form.get('query', '').strip()
        use_cot = request.form.get('use_cot') == 'true'
        if not query:
            return jsonify({"error": "No query provided"}), 400

        image_filename = session.get('image_filename')
        
        if not image_filename:
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
            
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)

        if not os.path.exists(image_path):
            return jsonify({"error": "Image not found. Please upload again."}), 400

        image = Image.open(image_path).convert('RGB')
        answer = analyze_chart_with_paligemma(image, query, use_cot)

        return jsonify({"answer": answer})

    except Exception as e:
        print(f"Error in analyze_chart: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)})

@app.route('/extract', methods=['POST'])
def extract_data():
    try:
        image_filename = session.get('image_filename')
        
        if not image_filename:
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
            
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)

        if not os.path.exists(image_path):
            return jsonify({"error": "Image not found. Please upload again."}), 400

        image = Image.open(image_path).convert('RGB')
        df = extract_data_points(image)
        
        # Check if DataFrame is empty or contains only error messages
        if df.empty:
            return jsonify({"error": "Could not extract data from the image"}), 400

        # Convert DataFrame to CSV data
        csv_data = df.to_csv(index=False)
        print(f"CSV data generated: {csv_data[:100]}...")  # Print first 100 chars

        # Encode CSV data to base64
        csv_base64 = base64.b64encode(csv_data.encode()).decode('utf-8')

        return jsonify({"csv_data": csv_base64})

    except Exception as e:
        print(f"Error in extract_data: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)})

@app.route('/download_csv')
def download_csv():
    try:
        print("Download CSV route called")
        image_filename = session.get('image_filename')
        
        if not image_filename:
            print("No image in session")
            return jsonify({"error": "No image found in session. Please upload an image first."}), 400
            
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_filename)
        print(f"Looking for image at: {image_path}")

        if not os.path.exists(image_path):
            print("Image file not found")
            return jsonify({"error": "Image not found. Please upload again."}), 400

        print("Loading image")
        image = Image.open(image_path).convert('RGB')
        print("Extracting data points")
        df = extract_data_points(image)
        
        print(f"DataFrame: {df}")
        
        # Write the CSV into an in-memory buffer (no temporary file needed)
        csv_buffer = BytesIO()
        df.to_csv(csv_buffer, index=False, encoding='utf-8')
        csv_data = csv_buffer.getvalue()

        # Debug: print CSV content
        print(f"CSV Content: {csv_data.decode('utf-8')}")

        print("Preparing response")
        # Build the response directly from the CSV bytes; mimetype already
        # sets the Content-Type header
        response = Response(
            csv_data,
            mimetype='text/csv',
            headers={'Content-Disposition': 'attachment; filename=extracted_data.csv'}
        )
        
        print("Returning CSV response")
        return response

    except Exception as e:
        print(f"Error in download_csv: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
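
# /download_csv streams the CSV as a file attachment (illustrative):
#   curl -b cookies.txt -o extracted_data.csv http://127.0.0.1:5000/download_csv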

# Create a utility function to match the Streamlit version
def get_csv_download_link(df, filename="chart_data.csv"):
    csv = df.to_csv(index=False)
    b64 = base64.b64encode(csv.encode()).decode()
    href = f'<a href="data:file/csv;base64,{b64}" download="{filename}">Download CSV File</a>'
    return href
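
# Example (illustrative): the returned anchor is meant to be embedded in an
# HTML page, e.g. passed to a template as a pre-rendered download link:
#   link_html = get_csv_download_link(df)  # '<a href="data:file/csv;base64,...">...'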

if __name__ == '__main__':
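    # debug=True enables the auto-reloader and interactive debugger; intended
    # for local development only.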
    app.run(debug=True)