"""Streamlit demo: YOLO text-line segmentation followed by PyLaia recognition."""

import logging
import multiprocessing
import re
import sys
import warnings
from contextlib import redirect_stdout
from pathlib import Path
from tempfile import NamedTemporaryFile, mkdtemp
from typing import List, Optional
from uuid import uuid4

# Installed before the third-party imports so that UserWarnings emitted while
# importing them are silenced too.
warnings.simplefilter("ignore", UserWarning)

import cv2
import numpy as np
import pandas as pd
import streamlit as st
from bidi.algorithm import get_display
from laia.common.arguments import CommonArgs, DataArgs, DecodeArgs, TrainerArgs
from laia.scripts.htr.decode_ctc import run as decode
from PIL import Image
from ultralytics import YOLO

# Page configuration is issued first so that no other Streamlit call (for
# example the spinner shown by a cached loader) can precede it.
st.set_page_config(layout="wide")  # Use full page width

# Configure logging
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)


@st.cache_resource
def _load_segmentation_model() -> YOLO:
    """Load the YOLO weights once instead of on every script rerun."""
    return YOLO("model.pt")


@st.cache_resource
def _get_image_dir() -> Path:
    """Create the scratch directory for recognition inputs once per process."""
    return Path(mkdtemp())


# Load YOLO segmentation model
model = _load_segmentation_model()
images = _get_image_dir()

DEFAULT_HEIGHT = 128
TEXT_DIRECTION = "LTR"
NUM_WORKERS = multiprocessing.cpu_count()

# Regex pattern for extracting results.
# Each decoder output line is expected to be "<image id> <confidence> <text>".
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # For line
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(
    rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}"
)


def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that keeps *image*'s aspect ratio at the given height.

    Args:
        image: Object exposing ``width`` and ``height`` attributes.
        height: Target height in pixels.

    Returns:
        The (possibly fractional) width matching ``height``.
    """
    aspect_ratio = image.width / image.height
    return height * aspect_ratio


def simplify_polygons(
    polygons: List[np.ndarray], approx_level: float = 0.01
) -> List[Optional[np.ndarray]]:
    """Simplify polygon contours with ``cv2.approxPolyDP``.

    Args:
        polygons: List of polygon contours.
        approx_level: Fraction of the contour perimeter used as the
            approximation tolerance. Higher values simplify more; lower
            values stay closer to the original contour.

    Returns:
        One entry per input polygon: the simplified polygon, or ``None`` when
        the input or the simplified result has fewer than 4 points.
    """
    result: List[Optional[np.ndarray]] = []
    for polygon in polygons:
        if len(polygon) < 4:
            result.append(None)
            continue
        perimeter = cv2.arcLength(polygon, True)
        approx = cv2.approxPolyDP(polygon, approx_level * perimeter, True)
        if len(approx) < 4:
            result.append(None)
            continue
        result.append(approx.squeeze())
    return result


def predict(model_name, input_img):
    """Recognize the text of a single line image with PyLaia.

    Args:
        model_name: Currently unused; the model directory is hard-coded to
            ``catmus-medieval``.
        input_img: PIL image of one text line.

    Returns:
        A tuple ``(resized_image, {"text": ..., "score": ...})`` where
        ``score`` is the confidence string parsed from the decoder output.

    Raises:
        RuntimeError: If the decoder prints nothing, or no printed line
            matches ``LINE_PREDICTION``.
    """
    model_dir = "catmus-medieval"
    temperature = 2.0
    batch_size = 1
    syms_path = f"{model_dir}/syms.txt"

    use_language_model = True
    language_model_params = {"language_model_weight": 1.0}
    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })

    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    image_id = uuid4()
    # Very narrow inputs would otherwise round down to a zero-pixel width.
    target_width = max(1, int(get_width(input_img)))
    input_img = input_img.resize((target_width, DEFAULT_HEIGHT))
    image_path = images / f"{image_id}.jpg"

    try:
        with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
            input_img.save(str(image_path))
            Path(img_list.name).write_text(str(image_id))

            # The capture file is closed (and therefore flushed) before it is
            # read back, and the handle is not leaked.
            with open(pred_stdout.name, mode="w", encoding="utf-8") as captured, \
                    redirect_stdout(captured):
                decode(
                    syms=str(syms_path),
                    img_list=img_list.name,
                    img_dirs=[str(images)],
                    common=common_args,
                    data=data_args,
                    trainer=trainer_args,
                    decode=decode_args,
                    num_workers=1,
                )
                sys.stdout.flush()

            predictions = (
                Path(pred_stdout.name)
                .read_text(encoding="utf-8")
                .strip()
                .splitlines()
            )
    finally:
        # The saved line image is only needed while the decoder runs.
        image_path.unlink(missing_ok=True)

    if not predictions:
        raise RuntimeError("PyLaia decoder produced no output")

    # Take the first line that has the expected shape.
    match = next(
        (m for m in map(LINE_PREDICTION.match, predictions) if m is not None),
        None,
    )
    if match is None:
        raise RuntimeError(f"Unexpected PyLaia output: {predictions[0]!r}")

    score = match.group("confidence")
    text = match.group("text")
    if TEXT_DIRECTION == "RTL":
        text = get_display(text)
    return input_img, {"text": text, "score": score}


def process_image(image):
    """Segment text lines in *image* and recognize the text of each one.

    Args:
        image: PIL image of a page.

    Returns:
        A tuple ``(annotated_image, table)``: the page as a PIL image with
        each accepted line polygon outlined, and a DataFrame with the columns
        ``"Polygons"`` and ``"Recognized Text"``.
    """
    # Uploaded files may be greyscale, palette or RGBA; the colour conversion
    # below needs plain RGB.
    rgb_image = image.convert("RGB")

    # Perform inference on an image, select textline only
    results = model(rgb_image, classes=0)
    img_cv2 = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
    # Draw on a separate canvas so that outlines of earlier lines never end up
    # inside the crops sent to the recognizer.
    canvas = img_cv2.copy()

    polygons = []
    texts = []
    masks = results[0].masks

    if masks is not None:
        # Get masks data and original image dimensions
        mask_data = masks.data.cpu().numpy()
        img_height, img_width = img_cv2.shape[:2]

        # Get bounding boxes in xyxy format
        boxes = results[0].boxes.xyxy.cpu().numpy()

        # Sort by y-coordinate of the top-left corner
        order = np.argsort(boxes[:, 1])
        mask_data = mask_data[order]
        boxes = boxes[order]

        for i, (mask, box) in enumerate(zip(mask_data, boxes)):
            # Scale the mask to original image size, then binarize it
            mask = cv2.resize(
                mask.squeeze(),
                (img_width, img_height),
                interpolation=cv2.INTER_LINEAR,
            )
            mask = (mask > 0.5).astype(np.uint8) * 255

            # Convert mask to polygon
            contours, _ = cv2.findContours(
                mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            if not contours:
                continue

            # Keep only the largest contour
            largest_contour = max(contours, key=cv2.contourArea)
            simplified_polygon = simplify_polygons([largest_contour])[0]
            if simplified_polygon is None:
                continue

            # Clamp the box to the image; negative indices would wrap around
            # in the slice and an empty crop cannot be resized.
            x1, y1, x2, y2 = map(int, box)
            x1, x2 = max(0, x1), min(img_width, x2)
            y1, y2 = max(0, y1), min(img_height, y2)
            if x2 <= x1 or y2 <= y1:
                continue

            # Crop the untouched image using the bounding box
            crop_img = img_cv2[y1:y2, x1:x2]
            crop_pil = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))

            # Recognize text using PyLaia model
            predicted = predict('pylaia-samaritan_v1', crop_pil)
            texts.append(predicted[1]["text"])

            # Convert polygon to list of points for display
            poly_points = simplified_polygon.reshape(-1, 2).astype(int).tolist()
            polygons.append(f"Line {i+1}: {poly_points}")

            # Draw polygon on the display canvas
            outline = simplified_polygon.reshape(-1, 1, 2).astype(np.int32)
            cv2.polylines(canvas, [outline], True, (0, 255, 0), 2)

    # Convert image back to RGB for display in Streamlit
    img_result = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

    # Combine polygons and texts into a DataFrame for table display
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
    return Image.fromarray(img_result), table_data


def segment_and_recognize(image):
    """Run ``process_image`` and return its ``(image, table)`` result."""
    return process_image(image)


# Streamlit app layout
st.title("YOLOv11 Text Line Segmentation & PyLaia Text Recognition on CATMuS/medieval")

# File uploader
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Process the image if uploaded
if uploaded_image is not None:
    image = Image.open(uploaded_image)

    if st.button("Segment and Recognize"):
        # Perform segmentation and recognition
        segmented_image, table_data = segment_and_recognize(image)

        # Layout: Image on the left, Table on the right
        col1, col2 = st.columns([2, 3])  # Adjust the ratio if needed

        with col1:
            st.image(
                segmented_image,
                caption="Segmented Image with Polygon Masks",
                use_container_width=True,
            )

        with col2:
            st.table(table_data)