Knightmovies committed on
Commit
97aecbf
·
verified ·
1 Parent(s): 7e2ff95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -34
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import streamlit as st
2
  import cv2
3
  import numpy as np
@@ -5,7 +8,6 @@ from PIL import Image
5
  import torch
6
  from transformers import TableTransformerForObjectDetection, DetrImageProcessor
7
  import pytesseract
8
- from scipy.spatial import distance as dist
9
 
10
  # ==============================================================================
11
  # UI config
@@ -18,7 +20,7 @@ st.markdown("""
18
  """, unsafe_allow_html=True)
19
 
20
  # ==============================================================================
21
- # Load model
22
  # ==============================================================================
23
  @st.cache_resource
24
  def load_model():
@@ -31,10 +33,11 @@ def load_model():
31
  model.eval()
32
  return model, proc
33
 
34
- model, processor = load_model()
 
35
 
36
  # ==============================================================================
37
- # Helpers
38
  # ==============================================================================
39
  def order_points(pts):
40
  xSorted = pts[np.argsort(pts[:, 0]), :]
@@ -63,33 +66,27 @@ def touches_border(cnt, w, h, m=12):
63
  return sides >= 3
64
 
65
  def find_page_quad(image, min_area_ratio=0.85):
66
- """Return 4-point page quad only if it looks like the outer page."""
67
  h, w = image.shape[:2]
68
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
69
  gray = cv2.GaussianBlur(gray, (5,5), 0)
70
  edges = cv2.Canny(gray, 50, 150)
71
  edges = cv2.dilate(edges, np.ones((3,3), np.uint8), 1)
72
-
73
  cnts, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
74
  if not cnts:
75
  return None
76
-
77
- best = None; best_area = 0
78
- img_area = w * h
79
  for c in cnts:
80
  peri = cv2.arcLength(c, True)
81
- approx = cv2.approxPolyDP(c, 0.02 * peri, True)
82
- if len(approx) != 4:
83
  continue
84
  area = cv2.contourArea(approx)
85
- if area > best_area and touches_border(approx, w, h) and (area / img_area) >= min_area_ratio:
86
- best_area = area
87
- best = approx.reshape(4,2)
88
-
89
- return best # None if not a real page quad
90
 
91
  def correct_orientation(image):
92
- """Rotate according to Tesseract OSD (CW angle), fallback heuristic."""
93
  try:
94
  osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
95
  rotation = int(osd.get("rotate", 0))
@@ -98,6 +95,7 @@ def correct_orientation(image):
98
  return cv2.rotate(image, rot_map[rotation])
99
  return image
100
  except Exception:
 
101
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
102
  thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
103
  rots = {0: thr, 90: cv2.rotate(thr, cv2.ROTATE_90_CLOCKWISE),
@@ -116,12 +114,10 @@ def correct_orientation(image):
116
  return image
117
 
118
  def deskew_hough(image):
119
- """Use dominant Hough-line angle for small residual tilt (no cropping)."""
120
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
121
  edges = cv2.Canny(gray, 50, 150, apertureSize=3)
122
  lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=200)
123
  if lines is None:
124
- # fallback: minAreaRect on ink pixels
125
  thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
126
  coords = np.column_stack(np.where(thr == 0))
127
  if len(coords) < 100:
@@ -143,27 +139,117 @@ def deskew_hough(image):
143
  return cv2.warpAffine(image, M, (w,h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
144
 
145
  def preprocess(image):
146
- """Full pipeline that NEVER crops to the inner table."""
147
- # 1) fix upside-down / sideways
148
  oriented = correct_orientation(image)
149
-
150
- # 2) warp only if we really found the OUTER page; otherwise keep full image
151
  quad = find_page_quad(oriented, min_area_ratio=0.85)
152
  if quad is not None:
153
  warped = four_point_warp(oriented, quad)
154
  else:
155
- warped = oriented # keep full page (prevents accidental crop to table)
156
-
157
- # 3) small deskew using Hough median angle so grid aligns to axes
158
  return deskew_hough(warped)
159
 
160
- def extract_and_draw_table_structure(image_bgr):
161
- """Run TableTransformer and draw table/table row/table column boxes."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  image_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
163
  inputs = processor(images=image_pil, return_tensors="pt")
164
  with torch.inference_mode():
165
  outputs = model(**inputs)
166
-
167
  h, w = image_bgr.shape[:2]
168
  target_sizes = torch.tensor([[h, w]], dtype=torch.float32)
169
  results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
@@ -191,6 +277,12 @@ if "stage" not in st.session_state:
191
  with st.sidebar:
192
  st.title("🤖 Document AI Toolkit")
193
  st.markdown("---")
 
 
 
 
 
 
194
  if st.button("🔄 Start Over", use_container_width=True):
195
  for k in list(st.session_state.keys()): del st.session_state[k]
196
  st.rerun()
@@ -206,17 +298,22 @@ with st.sidebar:
206
  elif st.session_state.stage == "processing":
207
  st.header("Step 2: Pre-process")
208
  if st.button("▶️ Start Pre-processing", use_container_width=True, type="primary"):
209
- with st.spinner("Correcting orientation • detecting true page • deskewing…"):
210
  st.session_state.processed_image = preprocess(st.session_state.original_image)
211
  st.session_state.stage = "analysis"; st.rerun()
212
 
213
  elif st.session_state.stage == "analysis":
214
  st.header("Step 3: Analyze Table")
215
  if st.button("📊 Find Table Structure", use_container_width=True, type="primary"):
216
- with st.spinner("Running Table Transformer…"):
217
- st.session_state.annotated_image = extract_and_draw_table_structure(
218
- st.session_state.processed_image
219
- )
 
 
 
 
 
220
  st.session_state.stage = "done"; st.rerun()
221
 
222
  st.title("Document Processing Workflow")
 
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
+
4
  import streamlit as st
5
  import cv2
6
  import numpy as np
 
8
  import torch
9
  from transformers import TableTransformerForObjectDetection, DetrImageProcessor
10
  import pytesseract
 
11
 
12
  # ==============================================================================
13
  # UI config
 
20
  """, unsafe_allow_html=True)
21
 
22
  # ==============================================================================
23
+ # Load model (only if you use the Transformer option)
24
  # ==============================================================================
25
  @st.cache_resource
26
  def load_model():
 
33
  model.eval()
34
  return model, proc
35
 
36
+ # Lazy-init so the app runs even if you only use line-based
37
+ model, processor = None, None
38
 
39
  # ==============================================================================
40
+ # Page-safe preprocessing (no inner-table cropping) + robust deskew
41
  # ==============================================================================
42
  def order_points(pts):
43
  xSorted = pts[np.argsort(pts[:, 0]), :]
 
66
  return sides >= 3
67
 
68
  def find_page_quad(image, min_area_ratio=0.85):
 
69
  h, w = image.shape[:2]
70
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
71
  gray = cv2.GaussianBlur(gray, (5,5), 0)
72
  edges = cv2.Canny(gray, 50, 150)
73
  edges = cv2.dilate(edges, np.ones((3,3), np.uint8), 1)
 
74
  cnts, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
75
  if not cnts:
76
  return None
77
+ best, best_area = None, 0
78
+ img_area = w*h
 
79
  for c in cnts:
80
  peri = cv2.arcLength(c, True)
81
+ approx = cv2.approxPolyDP(c, 0.02*peri, True)
82
+ if len(approx) != 4:
83
  continue
84
  area = cv2.contourArea(approx)
85
+ if area > best_area and touches_border(approx, w, h) and (area/img_area) >= min_area_ratio:
86
+ best_area, best = area, approx.reshape(4,2)
87
+ return best
 
 
88
 
89
  def correct_orientation(image):
 
90
  try:
91
  osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
92
  rotation = int(osd.get("rotate", 0))
 
95
  return cv2.rotate(image, rot_map[rotation])
96
  return image
97
  except Exception:
98
+ # Fallback heuristic
99
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
100
  thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
101
  rots = {0: thr, 90: cv2.rotate(thr, cv2.ROTATE_90_CLOCKWISE),
 
114
  return image
115
 
116
  def deskew_hough(image):
 
117
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
118
  edges = cv2.Canny(gray, 50, 150, apertureSize=3)
119
  lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=200)
120
  if lines is None:
 
121
  thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
122
  coords = np.column_stack(np.where(thr == 0))
123
  if len(coords) < 100:
 
139
  return cv2.warpAffine(image, M, (w,h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
140
 
141
  def preprocess(image):
 
 
142
  oriented = correct_orientation(image)
 
 
143
  quad = find_page_quad(oriented, min_area_ratio=0.85)
144
  if quad is not None:
145
  warped = four_point_warp(oriented, quad)
146
  else:
147
+ warped = oriented
 
 
148
  return deskew_hough(warped)
149
 
150
+ # ==============================================================================
151
+ # LINE-BASED table structure (precise on ruled tables)
152
+ # ==============================================================================
153
def merge_close_positions(positions, tol=8):
    """Collapse runs of nearby coordinates into one representative each.

    Values are sorted, then grouped greedily: a value joins the current
    cluster when it lies within ``tol`` of the cluster's most recent member
    (so a chain of small steps stays in one cluster). Each cluster is
    replaced by the rounded mean of its members.

    Args:
        positions: Any iterable of numbers (list, tuple, NumPy array,
            generator) or ``None``.
        tol: Maximum gap between consecutive sorted values that still
            counts as the same cluster.

    Returns:
        A sorted list of ``int`` cluster centres; empty for empty input.
    """
    if positions is None:
        return []
    # Sort before the emptiness test: truth-testing a NumPy array raises,
    # whereas the sorted() result is always a plain list.
    ordered = sorted(positions)
    if not ordered:
        return []

    merged = []
    cluster = [ordered[0]]
    for pos in ordered[1:]:
        if abs(pos - cluster[-1]) <= tol:
            cluster.append(pos)
        else:
            merged.append(int(round(np.mean(cluster))))
            cluster = [pos]
    merged.append(int(round(np.mean(cluster))))
    return merged
166
+
167
def detect_grid_lines(image_bgr, min_span_ratio=0.65, max_thickness_ratio=0.04):
    """Return x and y grid line positions (pixels) of a ruled table.

    The image is binarised with an inverted adaptive mean threshold, then
    eroded and dilated with tall/wide 1-pixel kernels so that only long
    vertical or horizontal strokes survive. Each surviving component whose
    bounding box is long and thin enough contributes its centre coordinate,
    and near-duplicate coordinates are merged with ``merge_close_positions``.

    Args:
        image_bgr: BGR image array; a 2-D (already grayscale) array is also
            accepted.
        min_span_ratio: A vertical line must cover at least this fraction of
            the image height (horizontal: of the width). Lower it when the
            table fills only part of the page.
        max_thickness_ratio: A vertical line's bounding box may be at most
            this fraction of the image width (horizontal: of the height).

    Returns:
        ``(xs, ys)`` lists of merged line positions, or ``([], [])`` when
        fewer than two lines were found in either direction.
    """
    if image_bgr.ndim == 2:
        gray = image_bgr
    else:
        gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                   cv2.THRESH_BINARY_INV, 15, 7)
    h, w = gray.shape

    def extract_lines(kernel_size):
        # Erode then dilate with the same kernel: strokes shorter than the
        # kernel vanish, longer ones are restored to their original extent.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        eroded = cv2.erode(binary, kernel, iterations=1)
        return cv2.dilate(eroded, kernel, iterations=1)

    def line_centres(mask, vertical):
        # Collect the centre coordinate of every long, thin component.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        centres = []
        for contour in contours:
            x, y, bw, bh = cv2.boundingRect(contour)
            if vertical:
                if bh >= min_span_ratio * h and bw <= max_thickness_ratio * w:
                    centres.append(x + bw // 2)
            elif bw >= min_span_ratio * w and bh <= max_thickness_ratio * h:
                centres.append(y + bh // 2)
        return centres

    def adaptive_tol(vals):
        # 8 px when there are too few lines to estimate spacing; otherwise
        # 30% of the median spacing, clamped to the range 6..20 px.
        if len(vals) < 3:
            return 8
        median_gap = np.median(np.diff(sorted(vals)))
        return int(max(6, min(20, 0.3 * median_gap)))

    # Kernel lengths scale with the image, with a 10 px floor.
    xs = line_centres(extract_lines((1, max(10, h // 40))), vertical=True)
    ys = line_centres(extract_lines((max(10, w // 40), 1)), vertical=False)

    # Merge duplicates / double rules.
    xs = merge_close_positions(xs, tol=adaptive_tol(xs))
    ys = merge_close_positions(ys, tol=adaptive_tol(ys))

    # A grid needs at least two lines in each direction.
    if len(xs) < 2 or len(ys) < 2:
        return [], []
    return xs, ys
221
+
222
def draw_grid(image_bgr, xs, ys):
    """Draw a table grid onto a copy of ``image_bgr``.

    Args:
        image_bgr: Image to annotate; it is copied, never modified.
        xs: Pixel x-positions of the vertical grid lines.
        ys: Pixel y-positions of the horizontal grid lines.

    Returns:
        The annotated copy. When ``xs`` or ``ys`` is empty there is no grid
        to draw and an unmodified copy is returned.
    """
    img = image_bgr.copy()
    if len(xs) == 0 or len(ys) == 0:
        return img

    # Use min/max so the outer box is right even if the lists are unsorted.
    left, right = min(xs), max(xs)
    top, bottom = min(ys), max(ys)

    # Outer table box
    cv2.rectangle(img, (left, top), (right, bottom), (255, 0, 255), 2)
    # Horizontal lines (green)
    for y in ys:
        cv2.line(img, (left, y), (right, y), (0, 255, 0), 2)
    # Vertical lines (blue)
    for x in xs:
        cv2.line(img, (x, top), (x, bottom), (255, 0, 0), 2)
    return img
233
+
234
def line_based_table_structure(image_bgr):
    """Annotate ``image_bgr`` with the grid found by ``detect_grid_lines``.

    Returns the image with the grid drawn on it, or a plain copy of the
    input when no usable grid was detected.
    """
    grid_xs, grid_ys = detect_grid_lines(image_bgr)
    if grid_xs and grid_ys:
        return draw_grid(image_bgr, grid_xs, grid_ys)
    # Nothing reliable found; hand back an untouched copy.
    return image_bgr.copy()
240
+
241
+ # ==============================================================================
242
+ # Transformer (TATR) structure (kept for comparison)
243
+ # ==============================================================================
244
+ def transformer_table_structure(image_bgr):
245
+ global model, processor
246
+ if model is None or processor is None:
247
+ model, processor = load_model()
248
+
249
  image_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
250
  inputs = processor(images=image_pil, return_tensors="pt")
251
  with torch.inference_mode():
252
  outputs = model(**inputs)
 
253
  h, w = image_bgr.shape[:2]
254
  target_sizes = torch.tensor([[h, w]], dtype=torch.float32)
255
  results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
 
277
  with st.sidebar:
278
  st.title("🤖 Document AI Toolkit")
279
  st.markdown("---")
280
+
281
+ method = st.radio(
282
+ "Structure method",
283
+ ["Line-based (ruled tables) – Recommended", "Transformer (TATR)"],
284
+ )
285
+
286
  if st.button("🔄 Start Over", use_container_width=True):
287
  for k in list(st.session_state.keys()): del st.session_state[k]
288
  st.rerun()
 
298
  elif st.session_state.stage == "processing":
299
  st.header("Step 2: Pre-process")
300
  if st.button("▶️ Start Pre-processing", use_container_width=True, type="primary"):
301
+ with st.spinner("Correcting orientation • detecting page • deskewing…"):
302
  st.session_state.processed_image = preprocess(st.session_state.original_image)
303
  st.session_state.stage = "analysis"; st.rerun()
304
 
305
  elif st.session_state.stage == "analysis":
306
  st.header("Step 3: Analyze Table")
307
  if st.button("📊 Find Table Structure", use_container_width=True, type="primary"):
308
+ with st.spinner("Detecting grid…"):
309
+ if method.startswith("Line-based"):
310
+ st.session_state.annotated_image = line_based_table_structure(
311
+ st.session_state.processed_image
312
+ )
313
+ else:
314
+ st.session_state.annotated_image = transformer_table_structure(
315
+ st.session_state.processed_image
316
+ )
317
  st.session_state.stage = "done"; st.rerun()
318
 
319
  st.title("Document Processing Workflow")