import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import torch
from transformers import TableTransformerForObjectDetection, DetrImageProcessor
import pytesseract
# ==============================================================================
# UI config
# ==============================================================================
st.set_page_config(page_title="Document AI Toolkit", page_icon="🤖", layout="wide")
st.markdown("""
<style>
.main .block-container { max-width: 900px; padding: 2rem; }
</style>
""", unsafe_allow_html=True)
# ==============================================================================
# Load model (only if you use the Transformer option)
# ==============================================================================
@st.cache_resource
def load_model():
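    """Load the TATR structure-recognition model and its processor (cached across reruns)."""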
model = TableTransformerForObjectDetection.from_pretrained(
"microsoft/table-transformer-structure-recognition"
)
proc = DetrImageProcessor.from_pretrained(
"microsoft/table-transformer-structure-recognition"
)
model.eval()
return model, proc
# Lazy-init: the app still runs if you only use the line-based method
model, processor = None, None
# ==============================================================================
# Page-safe preprocessing (no inner-table cropping) + robust deskew
# ==============================================================================
def order_points(pts):
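    """Order 4 points as top-left, top-right, bottom-right, bottom-left.

    Sorts by x to split the left/right pairs, then uses y (left pair) and
    the distance from the top-left point (right pair) to disambiguate.
    """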
xSorted = pts[np.argsort(pts[:, 0]), :]
leftMost, rightMost = xSorted[:2, :], xSorted[2:, :]
leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
(tl, bl) = leftMost
D = np.linalg.norm(rightMost - tl, axis=1)
(br, tr) = rightMost[np.argsort(D)[::-1], :]
return np.array([tl, tr, br, bl], dtype="float32")
def four_point_warp(image, pts):
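    """Perspective-warp the quadrilateral `pts` to an upright rectangle."""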
pts = order_points(pts.astype("float32"))
(tl, tr, br, bl) = pts
widthA = np.linalg.norm(br - bl)
widthB = np.linalg.norm(tr - tl)
heightA = np.linalg.norm(tr - br)
heightB = np.linalg.norm(tl - bl)
maxW, maxH = int(max(widthA, widthB)), int(max(heightA, heightB))
dst = np.array([[0,0],[maxW-1,0],[maxW-1,maxH-1],[0,maxH-1]], dtype="float32")
M = cv2.getPerspectiveTransform(pts, dst)
return cv2.warpPerspective(image, M, (maxW, maxH))
def touches_border(cnt, w, h, m=12):
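    """True if the contour hugs at least 3 of the image borders within margin m (px)."""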
xs = cnt[:,:,0]; ys = cnt[:,:,1]
sides = int(xs.min() < m) + int(w - xs.max() < m) + int(ys.min() < m) + int(h - ys.max() < m)
return sides >= 3
def find_page_quad(image, min_area_ratio=0.85):
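    """Find the page boundary as a 4-point contour.

    Only quads that cover >= min_area_ratio of the image *and* touch at
    least 3 borders qualify, so inner tables are never mistaken for the page.
    """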
h, w = image.shape[:2]
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
gray = cv2.GaussianBlur(gray, (5,5), 0)
edges = cv2.Canny(gray, 50, 150)
    # 'iterations' must be a keyword argument here (the third positional arg is dst)
    edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
cnts, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
if not cnts:
return None
best, best_area = None, 0
img_area = w*h
for c in cnts:
peri = cv2.arcLength(c, True)
approx = cv2.approxPolyDP(c, 0.02*peri, True)
if len(approx) != 4:
continue
area = cv2.contourArea(approx)
if area > best_area and touches_border(approx, w, h) and (area/img_area) >= min_area_ratio:
best_area, best = area, approx.reshape(4,2)
return best
def correct_orientation(image):
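    """Rotate the image upright using Tesseract OSD; if OSD fails, fall back
    to scoring each 90-degree rotation by OCR word confidence."""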
try:
osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
rotation = int(osd.get("rotate", 0))
if rotation:
rot_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
return cv2.rotate(image, rot_map[rotation])
return image
except Exception:
# Fallback heuristic
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
rots = {0: thr, 90: cv2.rotate(thr, cv2.ROTATE_90_CLOCKWISE),
180: cv2.rotate(thr, cv2.ROTATE_180), 270: cv2.rotate(thr, cv2.ROTATE_90_COUNTERCLOCKWISE)}
best = 0; best_count = -1
for ang, img in rots.items():
try:
data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT, timeout=5)
cnt = sum(1 for i,c in enumerate(data['conf'])
if str(c).isdigit() and int(c) > 10 and data['width'][i] > data['height'][i])
if cnt > best_count: best, best_count = ang, cnt
except Exception: pass
if best:
rot_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
return cv2.rotate(image, rot_map[best])
return image
def deskew_hough(image):
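    """Deskew by the median angle of detected Hough lines; if no lines are
    found, fall back to the minAreaRect angle of the text pixels."""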
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
edges = cv2.Canny(gray, 50, 150, apertureSize=3)
lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=200)
if lines is None:
thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        # minAreaRect needs int32/float32 points, not the int64 that np.where yields
        coords = np.column_stack(np.where(thr == 0)).astype(np.float32)
        if len(coords) < 100:
            return image
        # Note: OpenCV >= 4.5 reports this angle in [0, 90); the correction
        # below assumes the classic (-90, 0] convention of older releases.
        angle = cv2.minAreaRect(coords)[-1]
        angle = -(90 + angle) if angle < -45 else -angle
else:
angs = []
for rho,theta in lines[:,0]:
ang = np.degrees(theta) - 90.0
if ang < -45: ang += 90
if ang > 45: ang -= 90
angs.append(ang)
angle = float(np.median(angs))
if abs(angle) < 0.2:
return image
(h,w) = image.shape[:2]
M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1.0)
return cv2.warpAffine(image, M, (w,h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
def preprocess(image):
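    """Full pipeline: orientation fix -> page detection/warp -> fine deskew."""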
oriented = correct_orientation(image)
quad = find_page_quad(oriented, min_area_ratio=0.85)
if quad is not None:
warped = four_point_warp(oriented, quad)
else:
warped = oriented
return deskew_hough(warped)
# ==============================================================================
# LINE-BASED table structure (precise on ruled tables)
# ==============================================================================
def merge_close_positions(positions, tol=8):
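    """Collapse positions closer than `tol` px into their rounded mean.

    e.g. merge_close_positions([100, 103, 240], tol=8) -> [102, 240]
    """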
if not positions:
return []
positions = sorted(positions)
merged, cluster = [], [positions[0]]
for p in positions[1:]:
if abs(p - cluster[-1]) <= tol:
cluster.append(p)
else:
merged.append(int(round(np.mean(cluster))))
cluster = [p]
merged.append(int(round(np.mean(cluster))))
return merged
def detect_grid_lines(image_bgr):
"""Return x and y grid line positions (pixels)."""
gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
# Good default for scans/photos
binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
cv2.THRESH_BINARY_INV, 15, 7)
h, w = gray.shape
# Kernels sized to image
vert_len = max(10, h // 40)
horz_len = max(10, w // 40)
vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_len))
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (horz_len, 1))
# Extract vertical lines
vtemp = cv2.erode(binary, vertical_kernel, iterations=1)
vlines = cv2.dilate(vtemp, vertical_kernel, iterations=1)
# Extract horizontal lines
htemp = cv2.erode(binary, horizontal_kernel, iterations=1)
hlines = cv2.dilate(htemp, horizontal_kernel, iterations=1)
# Find vertical positions by components that span most of the height
xs = []
cnts,_ = cv2.findContours(vlines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for c in cnts:
x,y,wc,hc = cv2.boundingRect(c)
if hc >= 0.65*h and wc <= 0.04*w:
xs.append(x + wc//2)
# Find horizontal positions by components that span most of the width
ys = []
cnts,_ = cv2.findContours(hlines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for c in cnts:
x,y,wc,hc = cv2.boundingRect(c)
if wc >= 0.65*w and hc <= 0.04*h:
ys.append(y + hc//2)
# Merge duplicates / double rules
# Adaptive tolerance: 8px or 30% of median spacing
def adaptive_tol(vals):
if len(vals) < 3:
return 8
diffs = np.diff(sorted(vals))
med = np.median(diffs)
return int(max(6, min(20, 0.3*med)))
xs = merge_close_positions(xs, tol=adaptive_tol(xs))
ys = merge_close_positions(ys, tol=adaptive_tol(ys))
# Require at least 2 lines each to make a grid
if len(xs) < 2 or len(ys) < 2:
return [], []
return xs, ys
def draw_grid(image_bgr, xs, ys):
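    """Draw the detected grid (BGR): outer box in magenta, horizontal rules
    in green, vertical rules in blue."""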
img = image_bgr.copy()
# Outer table box (min/max)
cv2.rectangle(img, (xs[0], ys[0]), (xs[-1], ys[-1]), (255, 0, 255), 2)
# Horizontal lines (green)
for y in ys:
cv2.line(img, (xs[0], y), (xs[-1], y), (0, 255, 0), 2)
# Vertical lines (blue)
for x in xs:
cv2.line(img, (x, ys[0]), (x, ys[-1]), (255, 0, 0), 2)
return img
def line_based_table_structure(image_bgr):
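    """Annotate ruled-table structure found by morphological line detection;
    returns a plain copy if no reliable grid is found."""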
xs, ys = detect_grid_lines(image_bgr)
if not xs or not ys:
# Nothing reliable found; just return original
return image_bgr.copy()
return draw_grid(image_bgr, xs, ys)
# ==============================================================================
# Transformer (TATR) structure (kept for comparison)
# ==============================================================================
def transformer_table_structure(image_bgr):
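    """Annotate row/column/table boxes predicted by Table Transformer (TATR)."""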
global model, processor
if model is None or processor is None:
model, processor = load_model()
image_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
inputs = processor(images=image_pil, return_tensors="pt")
with torch.inference_mode():
outputs = model(**inputs)
h, w = image_bgr.shape[:2]
target_sizes = torch.tensor([[h, w]], dtype=torch.float32)
results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
img = image_bgr.copy()
colors = {"table row": (0,255,0), "table column": (255,0,0), "table": (255,0,255)}
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
cls = model.config.id2label[label.item()]
if cls in colors:
x1,y1,x2,y2 = [int(round(v)) for v in box.tolist()]
x1 = max(0,min(x1,w-1)); x2 = max(0,min(x2,w-1))
y1 = max(0,min(y1,h-1)); y2 = max(0,min(y2,h-1))
cv2.rectangle(img, (x1,y1), (x2,y2), colors[cls], 2)
return img
# ==============================================================================
# Streamlit app
# ==============================================================================
if "stage" not in st.session_state:
st.session_state.stage = "upload"
st.session_state.original_image = None
st.session_state.processed_image = None
st.session_state.annotated_image = None
with st.sidebar:
st.title("🤖 Document AI Toolkit")
st.markdown("---")
method = st.radio(
"Structure method",
["Line-based (ruled tables) – Recommended", "Transformer (TATR)"],
)
if st.button("🔄 Start Over", use_container_width=True):
for k in list(st.session_state.keys()): del st.session_state[k]
st.rerun()
if st.session_state.stage == "upload":
st.header("Step 1: Upload Image")
uploaded = st.file_uploader("Upload your document", type=["jpg","jpeg","png"], label_visibility="collapsed")
if uploaded:
file_bytes = np.asarray(bytearray(uploaded.read()), dtype=np.uint8)
        st.session_state.original_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
st.session_state.stage = "processing"; st.rerun()
elif st.session_state.stage == "processing":
st.header("Step 2: Pre-process")
if st.button("▶️ Start Pre-processing", use_container_width=True, type="primary"):
with st.spinner("Correcting orientation • detecting page • deskewing…"):
st.session_state.processed_image = preprocess(st.session_state.original_image)
st.session_state.stage = "analysis"; st.rerun()
elif st.session_state.stage == "analysis":
st.header("Step 3: Analyze Table")
if st.button("📊 Find Table Structure", use_container_width=True, type="primary"):
with st.spinner("Detecting grid…"):
if method.startswith("Line-based"):
st.session_state.annotated_image = line_based_table_structure(
st.session_state.processed_image
)
else:
st.session_state.annotated_image = transformer_table_structure(
st.session_state.processed_image
)
st.session_state.stage = "done"; st.rerun()
st.title("Document Processing Workflow")
exp1 = st.expander("Step 1: Upload Original Image", expanded=(st.session_state.stage=="upload"))
with exp1:
if st.session_state.original_image is None:
st.info("Please upload a document image using the sidebar to begin.")
else:
st.image(cv2.cvtColor(st.session_state.original_image, cv2.COLOR_BGR2RGB), use_container_width=True)
st.success("Image uploaded.")
if st.session_state.original_image is not None:
exp2 = st.expander("Step 2: Pre-process Document", expanded=(st.session_state.stage in ["processing","analysis"]))
with exp2:
if st.session_state.processed_image is None:
st.info("Click 'Start Pre-processing' in the sidebar.")
else:
st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB),
caption="Oriented • Page-safe (no inner crop) • Deskewed", use_container_width=True)
st.success("Pre-processing complete.")
if st.session_state.processed_image is not None:
exp3 = st.expander("Step 3: Analyze Table Structure", expanded=(st.session_state.stage=="done"))
with exp3:
if st.session_state.annotated_image is None:
st.info("Click 'Find Table Structure' in the sidebar.")
else:
tab1, tab2 = st.tabs(["✅ Corrected Document", "📊 Table Structure"])
with tab1:
st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), use_container_width=True)
_, buf = cv2.imencode(".jpg", st.session_state.processed_image)
st.download_button("📥 Download Clean Image", data=buf.tobytes(),
file_name="corrected.jpg", mime="image/jpeg", use_container_width=True)
with tab2:
st.image(cv2.cvtColor(st.session_state.annotated_image, cv2.COLOR_BGR2RGB), use_container_width=True)
st.success("Analysis complete.")