Update app.py
app.py CHANGED
@@ -1,3 +1,6 @@
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 import streamlit as st
 import cv2
 import numpy as np
@@ -5,7 +8,6 @@ from PIL import Image
 import torch
 from transformers import TableTransformerForObjectDetection, DetrImageProcessor
 import pytesseract
-from scipy.spatial import distance as dist
 
 # ==============================================================================
 # UI config
@@ -18,7 +20,7 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 # ==============================================================================
-# Load model
+# Load model (only if you use the Transformer option)
 # ==============================================================================
 @st.cache_resource
 def load_model():
@@ -31,10 +33,11 @@ def load_model():
     model.eval()
     return model, proc
 
-model, processor = load_model()
+# Lazy-init so the app runs even if you only use the line-based method
+model, processor = None, None
 
 # ==============================================================================
-#
+# Page-safe preprocessing (no inner-table cropping) + robust deskew
 # ==============================================================================
 def order_points(pts):
     xSorted = pts[np.argsort(pts[:, 0]), :]
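
The body of `load_model()` is untouched and sits mostly outside these hunks; only its tail (`model.eval()` / `return model, proc`) is visible as context. For orientation, a minimal sketch of what a TATR loader of this shape typically looks like — the checkpoint name here is an assumption, not something this diff shows:

```python
# Hypothetical sketch; the actual load_model() body is not shown in this diff.
@st.cache_resource
def load_model():
    # Assumed checkpoint -- the app may pin a different one.
    name = "microsoft/table-transformer-structure-recognition"
    proc = DetrImageProcessor.from_pretrained(name)
    model = TableTransformerForObjectDetection.from_pretrained(name)
    model.eval()  # inference only, matching the context lines above
    return model, proc
```

With the lazy-init change, this download/initialization cost is only paid if the Transformer method is actually selected.
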
@@ -63,33 +66,27 @@ def touches_border(cnt, w, h, m=12):
     return sides >= 3
 
 def find_page_quad(image, min_area_ratio=0.85):
-    """Return 4-point page quad only if it looks like the outer page."""
     h, w = image.shape[:2]
     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
     gray = cv2.GaussianBlur(gray, (5,5), 0)
     edges = cv2.Canny(gray, 50, 150)
     edges = cv2.dilate(edges, np.ones((3,3), np.uint8), 1)
-
     cnts, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
     if not cnts:
         return None
-    best = None
-    best_area = 0
-    img_area = w * h
+    best, best_area = None, 0
+    img_area = w*h
     for c in cnts:
         peri = cv2.arcLength(c, True)
-        approx = cv2.approxPolyDP(c, 0.02*peri, True)
-        if len(approx) != 4:
+        approx = cv2.approxPolyDP(c, 0.02*peri, True)
+        if len(approx) != 4:
             continue
         area = cv2.contourArea(approx)
-        if area > best_area and touches_border(approx, w, h) and (area/img_area) >= min_area_ratio:
-            best_area = area
-            best = approx.reshape(4, 2)
-
-    return best # None if not a real page quad
+        if area > best_area and touches_border(approx, w, h) and (area/img_area) >= min_area_ratio:
+            best_area, best = area, approx.reshape(4,2)
+    return best
 
 def correct_orientation(image):
-    """Rotate according to Tesseract OSD (CW angle), fallback heuristic."""
     try:
         osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
         rotation = int(osd.get("rotate", 0))
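
`find_page_quad` returns the quad unordered; `order_points` (context above) sorts it, presumably into the usual top-left, top-right, bottom-right, bottom-left order. The `four_point_warp` that `preprocess` calls is not shown anywhere in this diff, so treat the following as an illustrative sketch of the standard recipe it most likely follows, not the file's actual code:

```python
# Sketch of a standard four-point perspective warp, assuming order_points
# yields (tl, tr, br, bl). The real four_point_warp is outside this diff.
def four_point_warp(image, pts):
    rect = order_points(np.asarray(pts, dtype="float32"))
    (tl, tr, br, bl) = rect
    # Output size = longest opposing edges of the quad
    wA, wB = np.linalg.norm(br - bl), np.linalg.norm(tr - tl)
    hA, hB = np.linalg.norm(tr - br), np.linalg.norm(tl - bl)
    W, H = int(max(wA, wB)), int(max(hA, hB))
    dst = np.array([[0, 0], [W - 1, 0], [W - 1, H - 1], [0, H - 1]],
                   dtype="float32")
    M = cv2.getPerspectiveTransform(rect, dst)
    return cv2.warpPerspective(image, M, (W, H))
```
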
@@ -98,6 +95,7 @@ def correct_orientation(image):
             return cv2.rotate(image, rot_map[rotation])
         return image
     except Exception:
+        # Fallback heuristic
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
        rots = {0: thr, 90: cv2.rotate(thr, cv2.ROTATE_90_CLOCKWISE),
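
The `rot_map` used in the context line above is defined outside this hunk. A plausible definition, assuming (as the docstring removed in the previous hunk stated) that Tesseract OSD's `rotate` value is the clockwise correction angle:

```python
# Assumed mapping; the actual rot_map definition is not shown in this diff.
# OSD reports the clockwise rotation needed to upright the page, and these
# cv2.rotate codes apply exactly that correction. rotation == 0 never
# reaches the lookup because the code falls through to `return image`.
rot_map = {
    90: cv2.ROTATE_90_CLOCKWISE,
    180: cv2.ROTATE_180,
    270: cv2.ROTATE_90_COUNTERCLOCKWISE,
}
```
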
@@ -116,12 +114,10 @@ def correct_orientation(image):
     return image
 
 def deskew_hough(image):
-    """Use dominant Hough-line angle for small residual tilt (no cropping)."""
     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
     edges = cv2.Canny(gray, 50, 150, apertureSize=3)
     lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=200)
     if lines is None:
-        # fallback: minAreaRect on ink pixels
         thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
         coords = np.column_stack(np.where(thr == 0))
         if len(coords) < 100:
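
The middle of `deskew_hough` is unchanged and elided here; the visible tail feeds a rotation matrix `M` into `cv2.warpAffine`, so the elided part evidently reduces the Hough lines to one small median tilt angle. A sketch of that step, under the assumption it follows the usual recipe:

```python
# Hypothetical sketch of the elided angle computation in deskew_hough.
def _median_tilt_deg(lines, max_tilt=10.0):
    # cv2.HoughLines returns shape (N, 1, 2) pairs of (rho, theta), with
    # theta in [0, pi); a perfectly horizontal line has theta == pi/2.
    tilts = []
    for rho, theta in lines[:, 0]:
        t = np.degrees(theta) - 90.0
        if abs(t) <= max_tilt:       # keep only small residual tilts
            tilts.append(t)
    return float(np.median(tilts)) if tilts else 0.0

# ...which would feed the warp seen as context in the hunk below:
# M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
```
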
@@ -143,27 +139,117 @@ def deskew_hough(image):
     return cv2.warpAffine(image, M, (w,h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
 
 def preprocess(image):
-    """Full pipeline that NEVER crops to the inner table."""
-    # 1) fix upside-down / sideways
     oriented = correct_orientation(image)
-
-    # 2) warp only if we really found the OUTER page; otherwise keep full image
     quad = find_page_quad(oriented, min_area_ratio=0.85)
     if quad is not None:
         warped = four_point_warp(oriented, quad)
     else:
-        warped = oriented
-
-    # 3) small deskew using Hough median angle so grid aligns to axes
+        warped = oriented
     return deskew_hough(warped)
 
-
-…
+# ==============================================================================
+# LINE-BASED table structure (precise on ruled tables)
+# ==============================================================================
+def merge_close_positions(positions, tol=8):
+    if not positions:
+        return []
+    positions = sorted(positions)
+    merged, cluster = [], [positions[0]]
+    for p in positions[1:]:
+        if abs(p - cluster[-1]) <= tol:
+            cluster.append(p)
+        else:
+            merged.append(int(round(np.mean(cluster))))
+            cluster = [p]
+    merged.append(int(round(np.mean(cluster))))
+    return merged
+
+def detect_grid_lines(image_bgr):
+    """Return x and y grid line positions (pixels)."""
+    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
+    # Good default for scans/photos
+    binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
+                                   cv2.THRESH_BINARY_INV, 15, 7)
+
+    h, w = gray.shape
+    # Kernels sized to image
+    vert_len = max(10, h // 40)
+    horz_len = max(10, w // 40)
+
+    vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_len))
+    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (horz_len, 1))
+
+    # Extract vertical lines
+    vtemp = cv2.erode(binary, vertical_kernel, iterations=1)
+    vlines = cv2.dilate(vtemp, vertical_kernel, iterations=1)
+    # Extract horizontal lines
+    htemp = cv2.erode(binary, horizontal_kernel, iterations=1)
+    hlines = cv2.dilate(htemp, horizontal_kernel, iterations=1)
+
+    # Find vertical positions by components that span most of the height
+    xs = []
+    cnts, _ = cv2.findContours(vlines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    for c in cnts:
+        x, y, wc, hc = cv2.boundingRect(c)
+        if hc >= 0.65*h and wc <= 0.04*w:
+            xs.append(x + wc//2)
+
+    # Find horizontal positions by components that span most of the width
+    ys = []
+    cnts, _ = cv2.findContours(hlines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    for c in cnts:
+        x, y, wc, hc = cv2.boundingRect(c)
+        if wc >= 0.65*w and hc <= 0.04*h:
+            ys.append(y + hc//2)
+
+    # Merge duplicates / double rules.
+    # Adaptive tolerance: 8px or 30% of median spacing
+    def adaptive_tol(vals):
+        if len(vals) < 3:
+            return 8
+        diffs = np.diff(sorted(vals))
+        med = np.median(diffs)
+        return int(max(6, min(20, 0.3*med)))
+
+    xs = merge_close_positions(xs, tol=adaptive_tol(xs))
+    ys = merge_close_positions(ys, tol=adaptive_tol(ys))
+
+    # Require at least 2 lines each to make a grid
+    if len(xs) < 2 or len(ys) < 2:
+        return [], []
+    return xs, ys
+
+def draw_grid(image_bgr, xs, ys):
+    img = image_bgr.copy()
+    # Outer table box (min/max)
+    cv2.rectangle(img, (xs[0], ys[0]), (xs[-1], ys[-1]), (255, 0, 255), 2)
+    # Horizontal lines (green)
+    for y in ys:
+        cv2.line(img, (xs[0], y), (xs[-1], y), (0, 255, 0), 2)
+    # Vertical lines (blue)
+    for x in xs:
+        cv2.line(img, (x, ys[0]), (x, ys[-1]), (255, 0, 0), 2)
+    return img
+
+def line_based_table_structure(image_bgr):
+    xs, ys = detect_grid_lines(image_bgr)
+    if not xs or not ys:
+        # Nothing reliable found; just return original
+        return image_bgr.copy()
+    return draw_grid(image_bgr, xs, ys)
+
+# ==============================================================================
+# Transformer (TATR) structure (kept for comparison)
+# ==============================================================================
+def transformer_table_structure(image_bgr):
+    global model, processor
+    if model is None or processor is None:
+        model, processor = load_model()
+
     image_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
     inputs = processor(images=image_pil, return_tensors="pt")
     with torch.inference_mode():
         outputs = model(**inputs)
-
     h, w = image_bgr.shape[:2]
     target_sizes = torch.tensor([[h, w]], dtype=torch.float32)
     results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
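
`merge_close_positions` collapses double-ruled or slightly wavy lines into one coordinate by averaging runs of nearby positions (each new position is compared against the last member of the current cluster). A quick worked example, plus a small hypothetical helper showing how the `xs`/`ys` grid would translate into per-cell crop boxes, e.g. for OCR-ing each cell:

```python
merge_close_positions([103, 100, 257, 250, 256], tol=8)
# -> [102, 254]: clusters {100, 103} and {250, 256, 257}
#    collapse to their rounded means

# Hypothetical helper, not part of the commit: consecutive grid lines
# bound the cells, so adjacent (x, y) pairs give cell rectangles.
def cell_boxes(xs, ys):
    return [(xs[i], ys[j], xs[i + 1], ys[j + 1])   # (x0, y0, x1, y1)
            for j in range(len(ys) - 1)
            for i in range(len(xs) - 1)]
```
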
@@ -191,6 +277,12 @@ if "stage" not in st.session_state:
 with st.sidebar:
     st.title("🤖 Document AI Toolkit")
     st.markdown("---")
+
+    method = st.radio(
+        "Structure method",
+        ["Line-based (ruled tables) – Recommended", "Transformer (TATR)"],
+    )
+
     if st.button("🔄 Start Over", use_container_width=True):
         for k in list(st.session_state.keys()): del st.session_state[k]
         st.rerun()
@@ -206,17 +298,22 @@ with st.sidebar:
 elif st.session_state.stage == "processing":
     st.header("Step 2: Pre-process")
     if st.button("▶️ Start Pre-processing", use_container_width=True, type="primary"):
-        with st.spinner("Correcting orientation • detecting…"):
+        with st.spinner("Correcting orientation • detecting page • deskewing…"):
             st.session_state.processed_image = preprocess(st.session_state.original_image)
         st.session_state.stage = "analysis"; st.rerun()
 
 elif st.session_state.stage == "analysis":
     st.header("Step 3: Analyze Table")
     if st.button("📊 Find Table Structure", use_container_width=True, type="primary"):
-        with st.spinner("…"):
-            …
-            st.session_state.…
-            …
+        with st.spinner("Detecting grid…"):
+            if method.startswith("Line-based"):
+                st.session_state.annotated_image = line_based_table_structure(
+                    st.session_state.processed_image
+                )
+            else:
+                st.session_state.annotated_image = transformer_table_structure(
+                    st.session_state.processed_image
+                )
         st.session_state.stage = "done"; st.rerun()
 
 st.title("Document Processing Workflow")
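
Taken together, the commit leaves a pipeline that can be smoke-tested without the Streamlit UI; a minimal sketch (file paths are placeholders):

```python
# Standalone usage sketch; "scan.jpg" is a placeholder path.
img = cv2.imread("scan.jpg")                    # BGR, as the pipeline expects
clean = preprocess(img)                         # orient -> page warp -> deskew
annotated = line_based_table_structure(clean)   # grid overlay, or a copy if no grid found
cv2.imwrite("annotated.jpg", annotated)
```
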