import os

# Silence the HuggingFace tokenizers fork-parallelism warning; this must be
# set before transformers loads the tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import streamlit as st
import cv2
import numpy as np
from PIL import Image
import torch
from transformers import TableTransformerForObjectDetection, DetrImageProcessor
import pytesseract
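

# Page chrome: wide layout, with the main content column capped at 900px.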
st.set_page_config(page_title="Document AI Toolkit", page_icon="🤖", layout="wide")
st.markdown("""
<style>
.main .block-container { max-width: 900px; padding: 2rem; }
</style>
""", unsafe_allow_html=True)
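

# Cached so the Table Transformer weights are fetched and initialised only
# once across Streamlit reruns.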
@st.cache_resource
def load_model():
    model = TableTransformerForObjectDetection.from_pretrained(
        "microsoft/table-transformer-structure-recognition"
    )
    proc = DetrImageProcessor.from_pretrained(
        "microsoft/table-transformer-structure-recognition"
    )
    model.eval()
    return model, proc
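

# Populated lazily by transformer_table_structure() on first use.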
model, processor = None, None
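

# ---------- Pre-processing: orientation, page detection, deskew ----------

# Order four corner points as top-left, top-right, bottom-right, bottom-left.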
def order_points(pts):
    xSorted = pts[np.argsort(pts[:, 0]), :]
    leftMost, rightMost = xSorted[:2, :], xSorted[2:, :]
    leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
    (tl, bl) = leftMost
    D = np.linalg.norm(rightMost - tl, axis=1)
    (br, tr) = rightMost[np.argsort(D)[::-1], :]
    return np.array([tl, tr, br, bl], dtype="float32")
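

# Warp the quadrilateral `pts` to a flat, axis-aligned rectangle.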
def four_point_warp(image, pts):
    pts = order_points(pts.astype("float32"))
    (tl, tr, br, bl) = pts
    widthA = np.linalg.norm(br - bl)
    widthB = np.linalg.norm(tr - tl)
    heightA = np.linalg.norm(tr - br)
    heightB = np.linalg.norm(tl - bl)
    maxW, maxH = int(max(widthA, widthB)), int(max(heightA, heightB))
    dst = np.array([[0,0],[maxW-1,0],[maxW-1,maxH-1],[0,maxH-1]], dtype="float32")
    M = cv2.getPerspectiveTransform(pts, dst)
    return cv2.warpPerspective(image, M, (maxW, maxH))
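

# True when a contour hugs at least three image borders (within m pixels).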
def touches_border(cnt, w, h, m=12):
    xs = cnt[:,:,0]; ys = cnt[:,:,1]
    sides = int(xs.min() < m) + int(w - xs.max() < m) + int(ys.min() < m) + int(h - ys.max() < m)
    return sides >= 3
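

# Look for a page-sized 4-point contour (at least min_area_ratio of the frame
# and touching at least three borders); returns its corners, or None.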
def find_page_quad(image, min_area_ratio=0.85):
    h, w = image.shape[:2]
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.GaussianBlur(gray, (5,5), 0)
    edges = cv2.Canny(gray, 50, 150)
    edges = cv2.dilate(edges, np.ones((3,3), np.uint8), iterations=1)
    cnts, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None
    best, best_area = None, 0
    img_area = w*h
    for c in cnts:
        peri = cv2.arcLength(c, True)
        approx = cv2.approxPolyDP(c, 0.02*peri, True)
        if len(approx) != 4:
            continue
        area = cv2.contourArea(approx)
        if area > best_area and touches_border(approx, w, h) and (area/img_area) >= min_area_ratio:
            best_area, best = area, approx.reshape(4,2)
    return best
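

# Upright the page with Tesseract OSD; on failure, fall back to scoring all
# four 90° rotations by how much horizontal text each one yields.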
def correct_orientation(image):
    try:
        osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
        rotation = int(osd.get("rotate", 0))
        if rotation:
            rot_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
            return cv2.rotate(image, rot_map[rotation])
        return image
    except Exception:
        # Fallback: count confidently detected, wider-than-tall word boxes at
        # each rotation and keep the best-scoring one.
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
        rots = {0: thr, 90: cv2.rotate(thr, cv2.ROTATE_90_CLOCKWISE),
                180: cv2.rotate(thr, cv2.ROTATE_180), 270: cv2.rotate(thr, cv2.ROTATE_90_COUNTERCLOCKWISE)}
        best = 0; best_count = -1
        for ang, img in rots.items():
            try:
                data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT, timeout=5)
                cnt = sum(1 for i, c in enumerate(data['conf'])
                          if str(c).isdigit() and int(c) > 10 and data['width'][i] > data['height'][i])
                if cnt > best_count: best, best_count = ang, cnt
            except Exception: pass
        if best:
            rot_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
            return cv2.rotate(image, rot_map[best])
        return image
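

# Estimate residual skew from Hough lines (or, when none are found, from the
# min-area rectangle of the dark pixels) and rotate it away.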
def deskew_hough(image):
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)
    lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=200)
    if lines is None:
        thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
        # float32: cv2.minAreaRect does not accept int64 point arrays.
        coords = np.column_stack(np.where(thr == 0)).astype(np.float32)
        if len(coords) < 100:
            return image
        angle = cv2.minAreaRect(coords)[-1]
        angle = -(90 + angle) if angle < -45 else -angle
    else:
        angs = []
        for rho, theta in lines[:, 0]:
            ang = np.degrees(theta) - 90.0
            if ang < -45: ang += 90
            if ang > 45:  ang -= 90
            angs.append(ang)
        angle = float(np.median(angs))
    if abs(angle) < 0.2:
        return image
    (h, w) = image.shape[:2]
    M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1.0)
    return cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
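

# Full pre-processing pipeline: orient, warp to the page (if found), deskew.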
def preprocess(image):
    oriented = correct_orientation(image)
    quad = find_page_quad(oriented, min_area_ratio=0.85)
    if quad is not None:
        warped = four_point_warp(oriented, quad)
    else:
        warped = oriented
    return deskew_hough(warped)
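

# ---------- Table structure: line-based (ruled tables) ----------

# Cluster pixel positions belonging to the same ruling line and average each
# cluster into a single position.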
def merge_close_positions(positions, tol=8):
    if not positions:
        return []
    positions = sorted(positions)
    merged, cluster = [], [positions[0]]
    for p in positions[1:]:
        if abs(p - cluster[-1]) <= tol:
            cluster.append(p)
        else:
            merged.append(int(round(np.mean(cluster))))
            cluster = [p]
    merged.append(int(round(np.mean(cluster))))
    return merged
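

# Extract ruled lines by morphological opening: eroding with a long, thin
# kernel keeps only strokes of that orientation, and dilating restores them.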
def detect_grid_lines(image_bgr):
    """Return x and y grid line positions (pixels)."""
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    # Adaptive threshold so uneven lighting does not break the lines.
    binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                   cv2.THRESH_BINARY_INV, 15, 7)

    h, w = gray.shape
    # Kernel lengths scale with image size.
    vert_len = max(10, h // 40)
    horz_len = max(10, w // 40)

    vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_len))
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (horz_len, 1))

    # Opening keeps only long vertical / horizontal strokes.
    vtemp = cv2.erode(binary, vertical_kernel, iterations=1)
    vlines = cv2.dilate(vtemp, vertical_kernel, iterations=1)

    htemp = cv2.erode(binary, horizontal_kernel, iterations=1)
    hlines = cv2.dilate(htemp, horizontal_kernel, iterations=1)

    # Vertical lines: tall, narrow contours give x positions.
    xs = []
    cnts, _ = cv2.findContours(vlines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for c in cnts:
        x, y, wc, hc = cv2.boundingRect(c)
        if hc >= 0.65*h and wc <= 0.04*w:
            xs.append(x + wc//2)

    # Horizontal lines: wide, short contours give y positions.
    ys = []
    cnts, _ = cv2.findContours(hlines, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for c in cnts:
        x, y, wc, hc = cv2.boundingRect(c)
        if wc >= 0.65*w and hc <= 0.04*h:
            ys.append(y + hc//2)

    # Merge near-duplicate positions; tolerance adapts to the median spacing.
    def adaptive_tol(vals):
        if len(vals) < 3:
            return 8
        diffs = np.diff(sorted(vals))
        med = np.median(diffs)
        return int(max(6, min(20, 0.3*med)))

    xs = merge_close_positions(xs, tol=adaptive_tol(xs))
    ys = merge_close_positions(ys, tol=adaptive_tol(ys))

    # Need at least two lines in each direction to form a grid.
    if len(xs) < 2 or len(ys) < 2:
        return [], []
    return xs, ys
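

# Overlay the grid: magenta outer frame, green row lines, blue column lines.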
def draw_grid(image_bgr, xs, ys):
    img = image_bgr.copy()
    cv2.rectangle(img, (xs[0], ys[0]), (xs[-1], ys[-1]), (255, 0, 255), 2)
    for y in ys:
        cv2.line(img, (xs[0], y), (xs[-1], y), (0, 255, 0), 2)
    for x in xs:
        cv2.line(img, (x, ys[0]), (x, ys[-1]), (255, 0, 0), 2)
    return img
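

# Line-based pipeline: detect the grid and draw it; returns an unannotated
# copy when no grid is found.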
def line_based_table_structure(image_bgr):
    xs, ys = detect_grid_lines(image_bgr)
    if not xs or not ys:
        return image_bgr.copy()
    return draw_grid(image_bgr, xs, ys)
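

# ---------- Table structure: Table Transformer (TATR) ----------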
def transformer_table_structure(image_bgr):
    global model, processor
    if model is None or processor is None:
        model, processor = load_model()

    image_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    inputs = processor(images=image_pil, return_tensors="pt")
    with torch.inference_mode():
        outputs = model(**inputs)
    h, w = image_bgr.shape[:2]
    target_sizes = torch.tensor([[h, w]], dtype=torch.float32)
    results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]

    img = image_bgr.copy()
    colors = {"table row": (0,255,0), "table column": (255,0,0), "table": (255,0,255)}
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        cls = model.config.id2label[label.item()]
        if cls in colors:
            x1,y1,x2,y2 = [int(round(v)) for v in box.tolist()]
            # Clamp predicted boxes to the image bounds before drawing.
            x1 = max(0,min(x1,w-1)); x2 = max(0,min(x2,w-1))
            y1 = max(0,min(y1,h-1)); y2 = max(0,min(y2,h-1))
            cv2.rectangle(img, (x1,y1), (x2,y2), colors[cls], 2)
    return img
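

# ---------- Streamlit UI ----------

# Workflow state survives reruns via st.session_state.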
if "stage" not in st.session_state:
    st.session_state.stage = "upload"
    st.session_state.original_image = None
    st.session_state.processed_image = None
    st.session_state.annotated_image = None
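

# Sidebar: method choice plus the controls for the current workflow stage.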
with st.sidebar:
    st.title("🤖 Document AI Toolkit")
    st.markdown("---")

    method = st.radio(
        "Structure method",
        ["Line-based (ruled tables) – Recommended", "Transformer (TATR)"],
    )

    if st.button("🔄 Start Over", use_container_width=True):
        for k in list(st.session_state.keys()): del st.session_state[k]
        st.rerun()

    if st.session_state.stage == "upload":
        st.header("Step 1: Upload Image")
        uploaded = st.file_uploader("Upload your document", type=["jpg","jpeg","png"], label_visibility="collapsed")
        if uploaded:
            file_bytes = np.asarray(bytearray(uploaded.read()), dtype=np.uint8)
            st.session_state.original_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
            st.session_state.stage = "processing"; st.rerun()

    elif st.session_state.stage == "processing":
        st.header("Step 2: Pre-process")
        if st.button("▶️ Start Pre-processing", use_container_width=True, type="primary"):
            with st.spinner("Correcting orientation • detecting page • deskewing…"):
                st.session_state.processed_image = preprocess(st.session_state.original_image)
            st.session_state.stage = "analysis"; st.rerun()

    elif st.session_state.stage == "analysis":
        st.header("Step 3: Analyze Table")
        if st.button("📊 Find Table Structure", use_container_width=True, type="primary"):
            with st.spinner("Detecting grid…"):
                if method.startswith("Line-based"):
                    st.session_state.annotated_image = line_based_table_structure(
                        st.session_state.processed_image
                    )
                else:
                    st.session_state.annotated_image = transformer_table_structure(
                        st.session_state.processed_image
                    )
            st.session_state.stage = "done"; st.rerun()
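

# Main panel: one expander per workflow step.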
st.title("Document Processing Workflow")

exp1 = st.expander("Step 1: Upload Original Image", expanded=(st.session_state.stage == "upload"))
with exp1:
    if st.session_state.original_image is None:
        st.info("Please upload a document image using the sidebar to begin.")
    else:
        st.image(cv2.cvtColor(st.session_state.original_image, cv2.COLOR_BGR2RGB), use_container_width=True)
        st.success("Image uploaded.")

if st.session_state.original_image is not None:
    exp2 = st.expander("Step 2: Pre-process Document", expanded=(st.session_state.stage in ["processing", "analysis"]))
    with exp2:
        if st.session_state.processed_image is None:
            st.info("Click 'Start Pre-processing' in the sidebar.")
        else:
            st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB),
                     caption="Oriented • Page-safe (no inner crop) • Deskewed", use_container_width=True)
            st.success("Pre-processing complete.")

if st.session_state.processed_image is not None:
    exp3 = st.expander("Step 3: Analyze Table Structure", expanded=(st.session_state.stage == "done"))
    with exp3:
        if st.session_state.annotated_image is None:
            st.info("Click 'Find Table Structure' in the sidebar.")
        else:
            tab1, tab2 = st.tabs(["✅ Corrected Document", "📊 Table Structure"])
            with tab1:
                st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), use_container_width=True)
                _, buf = cv2.imencode(".jpg", st.session_state.processed_image)
                st.download_button("📥 Download Clean Image", data=buf.tobytes(),
                                   file_name="corrected.jpg", mime="image/jpeg", use_container_width=True)
            with tab2:
                st.image(cv2.cvtColor(st.session_state.annotated_image, cv2.COLOR_BGR2RGB), use_container_width=True)
            st.success("Analysis complete.")