Spaces:
Running
Running
Upload 4 files
Browse files- README.md +71 -12
- app.py +295 -0
- chart_example_1.png +0 -0
- requirements.txt +8 -0
README.md
CHANGED
@@ -1,12 +1,71 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChartQA
|
2 |
+
|
3 |
+
# Chart Q&A Application
|
4 |
+
|
5 |
+
## Overview
|
6 |
+
This Chart Q&A application allows users to analyze and extract information from chart images using the PaliGemma model. Users can upload chart images, ask questions about the charts, and extract structured data for further analysis.
|
7 |
+
|
8 |
+
## Features
|
9 |
+
- Upload chart images (PNG, JPG, JPEG)
|
10 |
+
- Load a sample chart for demonstration
|
11 |
+
- Ask natural language questions about chart content
|
12 |
+
- Extract data points from charts into a structured format
|
13 |
+
- Download extracted data as CSV
|
14 |
+
- Chain-of-Thought reasoning for improved analysis
|
15 |
+
- Question history tracking
|
16 |
+
|
17 |
+
## Requirements
|
18 |
+
- Python 3.8+
|
19 |
+
- Dependencies listed in `requirements.txt`
|
20 |
+
|
21 |
+
## Installation
|
22 |
+
|
23 |
+
1. Clone this repository:
|
24 |
+
```bash
|
25 |
+
git clone https://github.com/sushantgai/ChartQA.git
|
26 |
+
cd ChartQA
|
27 |
+
```
|
28 |
+
|
29 |
+
2. Create a virtual environment:
|
30 |
+
```bash
|
31 |
+
python -m venv venv
|
32 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
33 |
+
```
|
34 |
+
|
35 |
+
3. Install the required packages:
|
36 |
+
```bash
|
37 |
+
pip install -r requirements.txt
|
38 |
+
```
|
39 |
+
|
40 |
+
## Usage
|
41 |
+
|
42 |
+
1. Run the Streamlit application:
|
43 |
+
```bash
|
44 |
+
streamlit run app.py
|
45 |
+
```
|
46 |
+
|
47 |
+
2. Access the application in your web browser at http://localhost:8501
|
48 |
+
|
49 |
+
3. Usage steps:
|
50 |
+
- Click "Load Model" in the sidebar to initialize the PaliGemma model
|
51 |
+
- Upload a chart image or load the sample chart
|
52 |
+
- Ask questions about the chart in the text input field
|
53 |
+
- Click "Extract Data Points" to convert the chart into tabular data
|
54 |
+
- Download the extracted data as CSV if needed
|
55 |
+
|
56 |
+
## Model Information
|
57 |
+
|
58 |
+
This application uses a fine-tuned version of the PaliGemma model specifically trained for chart understanding:
|
59 |
+
- Model: ahmed-masry/chartgemma
|
60 |
+
- The model can analyze various types of charts including bar charts, line charts, pie charts, and more
|
61 |
+
|
62 |
+
## Notes
|
63 |
+
- The first load of the model may take some time depending on your hardware
|
64 |
+
- GPU acceleration is automatically used if available, otherwise CPU is used
|
65 |
+
- Chain-of-Thought reasoning can be toggled on/off in the sidebar
|
66 |
+
- For best results, use clear images of charts with readable text and labels
|
67 |
+
|
68 |
+
|
69 |
+
## Acknowledgements
|
70 |
+
- This application uses the PaliGemma model fine-tuned for chart analysis
|
71 |
+
- Based on the transformers library from Hugging Face
|
app.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import io
|
5 |
+
import requests
|
6 |
+
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import os
|
9 |
+
import pandas as pd
|
10 |
+
import re
|
11 |
+
import base64
|
12 |
+
|
13 |
+
# Set page config
|
14 |
+
st.set_page_config(
|
15 |
+
page_title="Chart Q&A ",
|
16 |
+
page_icon="π",
|
17 |
+
layout="wide"
|
18 |
+
)
|
19 |
+
|
20 |
+
# Initialize session state variables
|
21 |
+
if 'paligemma_model' not in st.session_state:
|
22 |
+
st.session_state.paligemma_model = None
|
23 |
+
if 'paligemma_processor' not in st.session_state:
|
24 |
+
st.session_state.paligemma_processor = None
|
25 |
+
if 'device' not in st.session_state:
|
26 |
+
st.session_state.device = None
|
27 |
+
if 'current_image' not in st.session_state:
|
28 |
+
st.session_state.current_image = None
|
29 |
+
if 'chat_history' not in st.session_state:
|
30 |
+
st.session_state.chat_history = []
|
31 |
+
if 'extracted_data' not in st.session_state:
|
32 |
+
st.session_state.extracted_data = None
|
33 |
+
|
34 |
+
# Initialize PaliGemma Model
|
35 |
+
@st.cache_resource
|
36 |
+
def load_paligemma_model():
|
37 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
38 |
+
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
39 |
+
"ahmed-masry/chartgemma",
|
40 |
+
torch_dtype=torch.float16
|
41 |
+
)
|
42 |
+
processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
|
43 |
+
model = model.to(device)
|
44 |
+
return model, processor, device
|
45 |
+
|
46 |
+
# Function to download sample chart
|
47 |
+
def download_sample_chart(url, filename):
|
48 |
+
try:
|
49 |
+
if not os.path.exists(filename):
|
50 |
+
response = requests.get(url)
|
51 |
+
if response.status_code == 200:
|
52 |
+
with open(filename, 'wb') as f:
|
53 |
+
f.write(response.content)
|
54 |
+
return True
|
55 |
+
else:
|
56 |
+
st.error(f"Failed to download sample chart: {response.status_code}")
|
57 |
+
return False
|
58 |
+
return True
|
59 |
+
except Exception as e:
|
60 |
+
st.error(f"Error downloading sample chart: {str(e)}")
|
61 |
+
return False
|
62 |
+
|
63 |
+
# Function to clean model output from print statements and other artifacts
|
64 |
+
def clean_model_output(text):
|
65 |
+
# Check if the entire response is a print statement and extract its content
|
66 |
+
print_match = re.search(r'^print\(["\'](.+?)["\']\)$', text.strip())
|
67 |
+
if print_match:
|
68 |
+
return print_match.group(1)
|
69 |
+
|
70 |
+
# Remove all print statements
|
71 |
+
text = re.sub(r'print\(.+?\)', '', text, flags=re.DOTALL)
|
72 |
+
|
73 |
+
# Remove Python code formatting artifacts
|
74 |
+
text = re.sub(r'```python|```', '', text)
|
75 |
+
|
76 |
+
return text.strip()
|
77 |
+
|
78 |
+
# Function to analyze chart with PaliGemma
|
79 |
+
def analyze_chart_with_paligemma(model, processor, device, image, query, use_cot=False):
|
80 |
+
try:
|
81 |
+
# Add program of thought prefix if CoT is enabled
|
82 |
+
if use_cot and not query.startswith("program of thought:"):
|
83 |
+
modified_query = f"program of thought: {query}"
|
84 |
+
else:
|
85 |
+
modified_query = query
|
86 |
+
|
87 |
+
inputs = processor(text=modified_query, images=image, return_tensors="pt")
|
88 |
+
prompt_length = inputs['input_ids'].shape[1]
|
89 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
90 |
+
|
91 |
+
# Generate with progress bar
|
92 |
+
progress_bar = st.progress(0)
|
93 |
+
|
94 |
+
with torch.no_grad():
|
95 |
+
generate_ids = model.generate(
|
96 |
+
**inputs,
|
97 |
+
num_beams=4,
|
98 |
+
max_new_tokens=512,
|
99 |
+
output_scores=True,
|
100 |
+
return_dict_in_generate=True
|
101 |
+
)
|
102 |
+
|
103 |
+
progress_bar.progress(100)
|
104 |
+
|
105 |
+
output_text = processor.batch_decode(
|
106 |
+
generate_ids.sequences[:, prompt_length:],
|
107 |
+
skip_special_tokens=True,
|
108 |
+
clean_up_tokenization_spaces=False
|
109 |
+
)[0]
|
110 |
+
|
111 |
+
# Clean output from print statements and other artifacts
|
112 |
+
output_text = clean_model_output(output_text)
|
113 |
+
|
114 |
+
return output_text
|
115 |
+
except Exception as e:
|
116 |
+
st.error(f"Error analyzing chart : {str(e)}")
|
117 |
+
return f"Error: {str(e)}"
|
118 |
+
|
119 |
+
# Function to extract data points from chart
|
120 |
+
def extract_data_points(model, processor, device, image):
|
121 |
+
try:
|
122 |
+
# Special query to extract data points
|
123 |
+
extraction_query = "program of thought: Extract all data points from this chart. List each category or series and all its corresponding values in a structured format."
|
124 |
+
|
125 |
+
with st.spinner("Extracting data points from chart..."):
|
126 |
+
result = analyze_chart_with_paligemma(model, processor, device, image, extraction_query)
|
127 |
+
|
128 |
+
# Parse the result into a DataFrame
|
129 |
+
df = parse_chart_data(result)
|
130 |
+
return df
|
131 |
+
except Exception as e:
|
132 |
+
st.error(f"Error extracting data points: {str(e)}")
|
133 |
+
return None
|
134 |
+
|
135 |
+
# Function to parse chart data from model response
|
136 |
+
def parse_chart_data(text):
|
137 |
+
try:
|
138 |
+
# Clean the text from print statements first
|
139 |
+
text = clean_model_output(text)
|
140 |
+
|
141 |
+
data = {}
|
142 |
+
lines = text.split('\n')
|
143 |
+
current_category = None
|
144 |
+
|
145 |
+
for line in lines:
|
146 |
+
if not line.strip():
|
147 |
+
continue
|
148 |
+
|
149 |
+
if ':' in line and not re.search(r'\d+\.\d+', line):
|
150 |
+
current_category = line.split(':')[0].strip()
|
151 |
+
data[current_category] = []
|
152 |
+
elif current_category and (re.search(r'\d+', line) or ',' in line):
|
153 |
+
value_match = re.findall(r'[-+]?\d*\.\d+|\d+', line)
|
154 |
+
if value_match:
|
155 |
+
data[current_category].extend(value_match)
|
156 |
+
|
157 |
+
if not data:
|
158 |
+
table_pattern = r'(\w+(?:\s\w+)*)\s*[:|]\s*((?:\d+(?:\.\d+)?(?:\s*,\s*\d+(?:\.\d+)?)*)|(?:\d+(?:\.\d+)?))'
|
159 |
+
matches = re.findall(table_pattern, text)
|
160 |
+
for category, values in matches:
|
161 |
+
category = category.strip()
|
162 |
+
if category not in data:
|
163 |
+
data[category] = []
|
164 |
+
if ',' in values:
|
165 |
+
values = [v.strip() for v in values.split(',')]
|
166 |
+
else:
|
167 |
+
values = [values.strip()]
|
168 |
+
data[category].extend(values)
|
169 |
+
|
170 |
+
df = pd.DataFrame(data)
|
171 |
+
|
172 |
+
if df.empty:
|
173 |
+
df = pd.DataFrame({'Extracted_Text': [text]})
|
174 |
+
|
175 |
+
return df
|
176 |
+
except Exception as e:
|
177 |
+
st.error(f"Error parsing chart data: {str(e)}")
|
178 |
+
return pd.DataFrame({'Raw_Text': [text]})
|
179 |
+
|
180 |
+
# Function to create a download link for dataframe
|
181 |
+
def get_csv_download_link(df, filename="chart_data.csv"):
|
182 |
+
csv = df.to_csv(index=False)
|
183 |
+
b64 = base64.b64encode(csv.encode()).decode()
|
184 |
+
href = f'<a href="data:file/csv;base64,{b64}" download="{filename}">Download CSV File</a>'
|
185 |
+
return href
|
186 |
+
|
187 |
+
# Main UI
|
188 |
+
st.title("π Chart Analysis ")
|
189 |
+
|
190 |
+
|
191 |
+
# Sidebar for model loading and options
|
192 |
+
with st.sidebar:
|
193 |
+
st.header("Model Setup")
|
194 |
+
|
195 |
+
if st.button("Load Model"):
|
196 |
+
with st.spinner("Loading model... This may take a moment"):
|
197 |
+
model, processor, device = load_paligemma_model()
|
198 |
+
st.session_state.paligemma_model = model
|
199 |
+
st.session_state.paligemma_processor = processor
|
200 |
+
st.session_state.device = device
|
201 |
+
st.success(f"β
Model loaded successfully on {device}!")
|
202 |
+
|
203 |
+
st.header("Options")
|
204 |
+
use_cot = st.checkbox("Enable Chain-of-Thought reasoning", value=True,
|
205 |
+
help="Adds 'program of thought:' prefix to prompts for better reasoning")
|
206 |
+
|
207 |
+
st.header("Sample Charts")
|
208 |
+
if st.button("Load Sample Chart"):
|
209 |
+
sample_url = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png"
|
210 |
+
sample_filename = "chart_example_1.png"
|
211 |
+
if download_sample_chart(sample_url, sample_filename):
|
212 |
+
st.session_state.current_image = Image.open(sample_filename).convert('RGB')
|
213 |
+
st.success("Sample chart loaded!")
|
214 |
+
|
215 |
+
# Main content area
|
216 |
+
col1, col2 = st.columns([3, 2])
|
217 |
+
|
218 |
+
with col1:
|
219 |
+
st.header("Upload Chart")
|
220 |
+
uploaded_file = st.file_uploader("Choose a chart image", type=["png", "jpg", "jpeg"])
|
221 |
+
|
222 |
+
if uploaded_file is not None:
|
223 |
+
try:
|
224 |
+
image = Image.open(uploaded_file).convert('RGB')
|
225 |
+
st.session_state.current_image = image
|
226 |
+
# Reset extracted data when new image is uploaded
|
227 |
+
st.session_state.extracted_data = None
|
228 |
+
except Exception as e:
|
229 |
+
st.error(f"Error opening image: {str(e)}")
|
230 |
+
|
231 |
+
# Display current image
|
232 |
+
if st.session_state.current_image is not None:
|
233 |
+
st.image(st.session_state.current_image, caption="Current Chart", use_column_width=True)
|
234 |
+
|
235 |
+
# Add extract data points button
|
236 |
+
if st.session_state.paligemma_model is not None:
|
237 |
+
if st.button("Extract Data Points from Chart"):
|
238 |
+
df = extract_data_points(
|
239 |
+
st.session_state.paligemma_model,
|
240 |
+
st.session_state.paligemma_processor,
|
241 |
+
st.session_state.device,
|
242 |
+
st.session_state.current_image
|
243 |
+
)
|
244 |
+
if df is not None:
|
245 |
+
st.session_state.extracted_data = df
|
246 |
+
st.success("Data points extracted successfully!")
|
247 |
+
|
248 |
+
with col2:
|
249 |
+
st.header("Ask Questions")
|
250 |
+
|
251 |
+
if st.session_state.paligemma_model is None:
|
252 |
+
st.warning("Please load the model first from the sidebar.")
|
253 |
+
elif st.session_state.current_image is None:
|
254 |
+
st.warning("Please upload a chart image or load a sample chart.")
|
255 |
+
else:
|
256 |
+
# Query input
|
257 |
+
query = st.text_input("Ask a question about the chart:",
|
258 |
+
placeholder="E.g., What is the highest value in the chart?")
|
259 |
+
|
260 |
+
if query:
|
261 |
+
if st.button("Analyze Chart"):
|
262 |
+
with st.spinner("Analyzing chart "):
|
263 |
+
answer = analyze_chart_with_paligemma(
|
264 |
+
st.session_state.paligemma_model,
|
265 |
+
st.session_state.paligemma_processor,
|
266 |
+
st.session_state.device,
|
267 |
+
st.session_state.current_image,
|
268 |
+
query,
|
269 |
+
use_cot
|
270 |
+
)
|
271 |
+
|
272 |
+
# Add to chat history
|
273 |
+
st.session_state.chat_history.append({
|
274 |
+
"question": query,
|
275 |
+
"answer": answer
|
276 |
+
})
|
277 |
+
|
278 |
+
# Display answer
|
279 |
+
st.subheader("Answer")
|
280 |
+
st.write(answer)
|
281 |
+
|
282 |
+
# Display extracted data if available
|
283 |
+
if st.session_state.extracted_data is not None:
|
284 |
+
st.header("Extracted Data Points")
|
285 |
+
st.dataframe(st.session_state.extracted_data)
|
286 |
+
|
287 |
+
# Download button for CSV
|
288 |
+
st.markdown(get_csv_download_link(st.session_state.extracted_data), unsafe_allow_html=True)
|
289 |
+
|
290 |
+
# Display chat history
|
291 |
+
if st.session_state.chat_history:
|
292 |
+
st.header("Question History")
|
293 |
+
for i, qa in enumerate(reversed(st.session_state.chat_history)):
|
294 |
+
with st.expander(f"Q: {qa['question']}", expanded=(i==0)):
|
295 |
+
st.markdown(f"**A:** {qa['answer']}")
|
chart_example_1.png
ADDED
![]() |
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.31.0
|
2 |
+
torch==2.2.0
|
3 |
+
Pillow==10.2.0
|
4 |
+
requests==2.31.0
|
5 |
+
transformers==4.38.0
|
6 |
+
matplotlib==3.8.2
|
7 |
+
pandas==2.2.0
|
8 |
+
base64==1.0.0
|