sushintern01 commited on
Commit
5d30640
Β·
verified Β·
1 Parent(s): 2f0c1d1

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +71 -12
  2. app.py +295 -0
  3. chart_example_1.png +0 -0
  4. requirements.txt +8 -0
README.md CHANGED
@@ -1,12 +1,71 @@
1
- ---
2
- title: ChartQA
3
- emoji: πŸŒ–
4
- colorFrom: indigo
5
- colorTo: green
6
- sdk: streamlit
7
- sdk_version: 1.44.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChartQA
2
+
3
+ # Chart Q&A Application
4
+
5
+ ## Overview
6
+ This Chart Q&A application allows users to analyze and extract information from chart images using the PaliGemma model. Users can upload chart images, ask questions about the charts, and extract structured data for further analysis.
7
+
8
+ ## Features
9
+ - Upload chart images (PNG, JPG, JPEG)
10
+ - Load a sample chart for demonstration
11
+ - Ask natural language questions about chart content
12
+ - Extract data points from charts into a structured format
13
+ - Download extracted data as CSV
14
+ - Chain-of-Thought reasoning for improved analysis
15
+ - Question history tracking
16
+
17
+ ## Requirements
18
+ - Python 3.8+
19
+ - Dependencies listed in `requirements.txt`
20
+
21
+ ## Installation
22
+
23
+ 1. Clone this repository:
24
+ ```bash
25
+ git clone https://github.com/sushantgai/ChartQA.git
26
+ cd ChartQA
27
+ ```
28
+
29
+ 2. Create a virtual environment:
30
+ ```bash
31
+ python -m venv venv
32
+ source venv/bin/activate # On Windows: venv\Scripts\activate
33
+ ```
34
+
35
+ 3. Install the required packages:
36
+ ```bash
37
+ pip install -r requirements.txt
38
+ ```
39
+
40
+ ## Usage
41
+
42
+ 1. Run the Streamlit application:
43
+ ```bash
44
+ streamlit run app.py
45
+ ```
46
+
47
+ 2. Access the application in your web browser at http://localhost:8501
48
+
49
+ 3. Usage steps:
50
+ - Click "Load Model" in the sidebar to initialize the PaliGemma model
51
+ - Upload a chart image or load the sample chart
52
+ - Ask questions about the chart in the text input field
53
+ - Click "Extract Data Points" to convert the chart into tabular data
54
+ - Download the extracted data as CSV if needed
55
+
56
+ ## Model Information
57
+
58
+ This application uses a fine-tuned version of the PaliGemma model specifically trained for chart understanding:
59
+ - Model: ahmed-masry/chartgemma
60
+ - The model can analyze various types of charts including bar charts, line charts, pie charts, and more
61
+
62
+ ## Notes
63
+ - The first load of the model may take some time depending on your hardware
64
+ - GPU acceleration is automatically used if available, otherwise CPU is used
65
+ - Chain-of-Thought reasoning can be toggled on/off in the sidebar
66
+ - For best results, use clear images of charts with readable text and labels
67
+
68
+
69
+ ## Acknowledgements
70
+ - This application uses the PaliGemma model fine-tuned for chart analysis
71
+ - Based on the transformers library from Hugging Face
app.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ import io
5
+ import requests
6
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
7
+ import matplotlib.pyplot as plt
8
+ import os
9
+ import pandas as pd
10
+ import re
11
+ import base64
12
+
13
+ # Set page config
14
+ st.set_page_config(
15
+ page_title="Chart Q&A ",
16
+ page_icon="πŸ“Š",
17
+ layout="wide"
18
+ )
19
+
20
+ # Initialize session state variables
21
+ if 'paligemma_model' not in st.session_state:
22
+ st.session_state.paligemma_model = None
23
+ if 'paligemma_processor' not in st.session_state:
24
+ st.session_state.paligemma_processor = None
25
+ if 'device' not in st.session_state:
26
+ st.session_state.device = None
27
+ if 'current_image' not in st.session_state:
28
+ st.session_state.current_image = None
29
+ if 'chat_history' not in st.session_state:
30
+ st.session_state.chat_history = []
31
+ if 'extracted_data' not in st.session_state:
32
+ st.session_state.extracted_data = None
33
+
34
+ # Initialize PaliGemma Model
35
+ @st.cache_resource
36
+ def load_paligemma_model():
37
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ model = PaliGemmaForConditionalGeneration.from_pretrained(
39
+ "ahmed-masry/chartgemma",
40
+ torch_dtype=torch.float16
41
+ )
42
+ processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
43
+ model = model.to(device)
44
+ return model, processor, device
45
+
46
+ # Function to download sample chart
47
+ def download_sample_chart(url, filename):
48
+ try:
49
+ if not os.path.exists(filename):
50
+ response = requests.get(url)
51
+ if response.status_code == 200:
52
+ with open(filename, 'wb') as f:
53
+ f.write(response.content)
54
+ return True
55
+ else:
56
+ st.error(f"Failed to download sample chart: {response.status_code}")
57
+ return False
58
+ return True
59
+ except Exception as e:
60
+ st.error(f"Error downloading sample chart: {str(e)}")
61
+ return False
62
+
63
+ # Function to clean model output from print statements and other artifacts
64
+ def clean_model_output(text):
65
+ # Check if the entire response is a print statement and extract its content
66
+ print_match = re.search(r'^print\(["\'](.+?)["\']\)$', text.strip())
67
+ if print_match:
68
+ return print_match.group(1)
69
+
70
+ # Remove all print statements
71
+ text = re.sub(r'print\(.+?\)', '', text, flags=re.DOTALL)
72
+
73
+ # Remove Python code formatting artifacts
74
+ text = re.sub(r'```python|```', '', text)
75
+
76
+ return text.strip()
77
+
78
+ # Function to analyze chart with PaliGemma
79
+ def analyze_chart_with_paligemma(model, processor, device, image, query, use_cot=False):
80
+ try:
81
+ # Add program of thought prefix if CoT is enabled
82
+ if use_cot and not query.startswith("program of thought:"):
83
+ modified_query = f"program of thought: {query}"
84
+ else:
85
+ modified_query = query
86
+
87
+ inputs = processor(text=modified_query, images=image, return_tensors="pt")
88
+ prompt_length = inputs['input_ids'].shape[1]
89
+ inputs = {k: v.to(device) for k, v in inputs.items()}
90
+
91
+ # Generate with progress bar
92
+ progress_bar = st.progress(0)
93
+
94
+ with torch.no_grad():
95
+ generate_ids = model.generate(
96
+ **inputs,
97
+ num_beams=4,
98
+ max_new_tokens=512,
99
+ output_scores=True,
100
+ return_dict_in_generate=True
101
+ )
102
+
103
+ progress_bar.progress(100)
104
+
105
+ output_text = processor.batch_decode(
106
+ generate_ids.sequences[:, prompt_length:],
107
+ skip_special_tokens=True,
108
+ clean_up_tokenization_spaces=False
109
+ )[0]
110
+
111
+ # Clean output from print statements and other artifacts
112
+ output_text = clean_model_output(output_text)
113
+
114
+ return output_text
115
+ except Exception as e:
116
+ st.error(f"Error analyzing chart : {str(e)}")
117
+ return f"Error: {str(e)}"
118
+
119
+ # Function to extract data points from chart
120
+ def extract_data_points(model, processor, device, image):
121
+ try:
122
+ # Special query to extract data points
123
+ extraction_query = "program of thought: Extract all data points from this chart. List each category or series and all its corresponding values in a structured format."
124
+
125
+ with st.spinner("Extracting data points from chart..."):
126
+ result = analyze_chart_with_paligemma(model, processor, device, image, extraction_query)
127
+
128
+ # Parse the result into a DataFrame
129
+ df = parse_chart_data(result)
130
+ return df
131
+ except Exception as e:
132
+ st.error(f"Error extracting data points: {str(e)}")
133
+ return None
134
+
135
+ # Function to parse chart data from model response
136
+ def parse_chart_data(text):
137
+ try:
138
+ # Clean the text from print statements first
139
+ text = clean_model_output(text)
140
+
141
+ data = {}
142
+ lines = text.split('\n')
143
+ current_category = None
144
+
145
+ for line in lines:
146
+ if not line.strip():
147
+ continue
148
+
149
+ if ':' in line and not re.search(r'\d+\.\d+', line):
150
+ current_category = line.split(':')[0].strip()
151
+ data[current_category] = []
152
+ elif current_category and (re.search(r'\d+', line) or ',' in line):
153
+ value_match = re.findall(r'[-+]?\d*\.\d+|\d+', line)
154
+ if value_match:
155
+ data[current_category].extend(value_match)
156
+
157
+ if not data:
158
+ table_pattern = r'(\w+(?:\s\w+)*)\s*[:|]\s*((?:\d+(?:\.\d+)?(?:\s*,\s*\d+(?:\.\d+)?)*)|(?:\d+(?:\.\d+)?))'
159
+ matches = re.findall(table_pattern, text)
160
+ for category, values in matches:
161
+ category = category.strip()
162
+ if category not in data:
163
+ data[category] = []
164
+ if ',' in values:
165
+ values = [v.strip() for v in values.split(',')]
166
+ else:
167
+ values = [values.strip()]
168
+ data[category].extend(values)
169
+
170
+ df = pd.DataFrame(data)
171
+
172
+ if df.empty:
173
+ df = pd.DataFrame({'Extracted_Text': [text]})
174
+
175
+ return df
176
+ except Exception as e:
177
+ st.error(f"Error parsing chart data: {str(e)}")
178
+ return pd.DataFrame({'Raw_Text': [text]})
179
+
180
+ # Function to create a download link for dataframe
181
+ def get_csv_download_link(df, filename="chart_data.csv"):
182
+ csv = df.to_csv(index=False)
183
+ b64 = base64.b64encode(csv.encode()).decode()
184
+ href = f'<a href="data:file/csv;base64,{b64}" download="{filename}">Download CSV File</a>'
185
+ return href
186
+
187
+ # Main UI
188
+ st.title("πŸ“Š Chart Analysis ")
189
+
190
+
191
+ # Sidebar for model loading and options
192
+ with st.sidebar:
193
+ st.header("Model Setup")
194
+
195
+ if st.button("Load Model"):
196
+ with st.spinner("Loading model... This may take a moment"):
197
+ model, processor, device = load_paligemma_model()
198
+ st.session_state.paligemma_model = model
199
+ st.session_state.paligemma_processor = processor
200
+ st.session_state.device = device
201
+ st.success(f"βœ… Model loaded successfully on {device}!")
202
+
203
+ st.header("Options")
204
+ use_cot = st.checkbox("Enable Chain-of-Thought reasoning", value=True,
205
+ help="Adds 'program of thought:' prefix to prompts for better reasoning")
206
+
207
+ st.header("Sample Charts")
208
+ if st.button("Load Sample Chart"):
209
+ sample_url = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png"
210
+ sample_filename = "chart_example_1.png"
211
+ if download_sample_chart(sample_url, sample_filename):
212
+ st.session_state.current_image = Image.open(sample_filename).convert('RGB')
213
+ st.success("Sample chart loaded!")
214
+
215
+ # Main content area
216
+ col1, col2 = st.columns([3, 2])
217
+
218
+ with col1:
219
+ st.header("Upload Chart")
220
+ uploaded_file = st.file_uploader("Choose a chart image", type=["png", "jpg", "jpeg"])
221
+
222
+ if uploaded_file is not None:
223
+ try:
224
+ image = Image.open(uploaded_file).convert('RGB')
225
+ st.session_state.current_image = image
226
+ # Reset extracted data when new image is uploaded
227
+ st.session_state.extracted_data = None
228
+ except Exception as e:
229
+ st.error(f"Error opening image: {str(e)}")
230
+
231
+ # Display current image
232
+ if st.session_state.current_image is not None:
233
+ st.image(st.session_state.current_image, caption="Current Chart", use_column_width=True)
234
+
235
+ # Add extract data points button
236
+ if st.session_state.paligemma_model is not None:
237
+ if st.button("Extract Data Points from Chart"):
238
+ df = extract_data_points(
239
+ st.session_state.paligemma_model,
240
+ st.session_state.paligemma_processor,
241
+ st.session_state.device,
242
+ st.session_state.current_image
243
+ )
244
+ if df is not None:
245
+ st.session_state.extracted_data = df
246
+ st.success("Data points extracted successfully!")
247
+
248
+ with col2:
249
+ st.header("Ask Questions")
250
+
251
+ if st.session_state.paligemma_model is None:
252
+ st.warning("Please load the model first from the sidebar.")
253
+ elif st.session_state.current_image is None:
254
+ st.warning("Please upload a chart image or load a sample chart.")
255
+ else:
256
+ # Query input
257
+ query = st.text_input("Ask a question about the chart:",
258
+ placeholder="E.g., What is the highest value in the chart?")
259
+
260
+ if query:
261
+ if st.button("Analyze Chart"):
262
+ with st.spinner("Analyzing chart "):
263
+ answer = analyze_chart_with_paligemma(
264
+ st.session_state.paligemma_model,
265
+ st.session_state.paligemma_processor,
266
+ st.session_state.device,
267
+ st.session_state.current_image,
268
+ query,
269
+ use_cot
270
+ )
271
+
272
+ # Add to chat history
273
+ st.session_state.chat_history.append({
274
+ "question": query,
275
+ "answer": answer
276
+ })
277
+
278
+ # Display answer
279
+ st.subheader("Answer")
280
+ st.write(answer)
281
+
282
+ # Display extracted data if available
283
+ if st.session_state.extracted_data is not None:
284
+ st.header("Extracted Data Points")
285
+ st.dataframe(st.session_state.extracted_data)
286
+
287
+ # Download button for CSV
288
+ st.markdown(get_csv_download_link(st.session_state.extracted_data), unsafe_allow_html=True)
289
+
290
+ # Display chat history
291
+ if st.session_state.chat_history:
292
+ st.header("Question History")
293
+ for i, qa in enumerate(reversed(st.session_state.chat_history)):
294
+ with st.expander(f"Q: {qa['question']}", expanded=(i==0)):
295
+ st.markdown(f"**A:** {qa['answer']}")
chart_example_1.png ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.31.0
2
+ torch==2.2.0
3
+ Pillow==10.2.0
4
+ requests==2.31.0
5
+ transformers==4.38.0
6
+ matplotlib==3.8.2
7
+ pandas==2.2.0
8
+ base64==1.0.0