import streamlit as st
import warnings
warnings.simplefilter("ignore", UserWarning)
from uuid import uuid4
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from PIL import Image
from bidi.algorithm import get_display
import multiprocessing
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
import logging
from typing import List, Optional
# Configure logging
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
# Load YOLOv11 segmentation model (see app title below)
model = YOLO('model.pt')
images = Path(mkdtemp())
DEFAULT_HEIGHT = 128
TEXT_DIRECTION = "LTR"
NUM_WORKERS = multiprocessing.cpu_count()
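# NUM_WORKERS is available for parallel data loading; the decode call below
# keeps num_workers=1, since each request decodes a single line image.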
# Regex pattern for extracting results
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # Line-level confidence score
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
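# With include_img_ids and line confidence scores enabled (see DecodeArgs below),
# each decoded line is expected to look like, e.g.:
#   "b9a5c1a2-7d13-4f2e-8c3a-0e5d9f6a1b2c 0.9321 recognized text here"
# (hypothetical values; format: "<uuid> <confidence> <text>").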
def get_width(image, height=DEFAULT_HEIGHT):
    aspect_ratio = image.width / image.height
    return height * aspect_ratio
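# Example: a 600x200 px line image gives aspect_ratio = 3.0, so get_width
# returns 128 * 3.0 = 384 and the line is resized to 384x128 before decoding.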
def simplify_polygons(polygons: List[np.ndarray], approx_level: float = 0.01) -> List[Optional[np.ndarray]]:
    """Simplify polygon contours using the Douglas-Peucker algorithm.

    Args:
        polygons: List of polygon contours.
        approx_level: Approximation level as a fraction of the contour
            perimeter; higher values mean more aggressive simplification.

    Returns:
        List of simplified polygons (or None for degenerate polygons with
        fewer than 4 points before or after simplification).
    """
    result = []
    for polygon in polygons:
        if len(polygon) < 4:
            result.append(None)
            continue
        perimeter = cv2.arcLength(polygon, True)
        approx = cv2.approxPolyDP(polygon, approx_level * perimeter, True)
        if len(approx) < 4:
            result.append(None)
            continue
        result.append(approx.squeeze())
    return result
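# Usage sketch with a hypothetical contour from cv2.findContours:
#   simplified = simplify_polygons([contour])[0]
#   if simplified is not None:
#       cv2.polylines(img, [simplified.reshape(-1, 1, 2).astype(int)], True, (0, 255, 0), 2)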
def predict(model_name, input_img):
    model_dir = 'catmus-medieval'
    temperature = 2.0
    batch_size = 1

    weights_path = f"{model_dir}/weights.ckpt"
    syms_path = f"{model_dir}/syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    use_language_model = True
    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })
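    # Assumed layout of the model directory, inferred from the paths above:
    #   catmus-medieval/
    #     weights.ckpt            # PyLaia checkpoint
    #     syms.txt                # symbol table
    #     language_model.binary   # language model used for decoding
    #     lexicon.txt
    #     tokens.txt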
    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
        image_id = uuid4()
        # Resize to the fixed line height expected by the recognizer, keeping the aspect ratio
        input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
        input_img.save(f"{images}/{image_id}.jpg")
        Path(img_list.name).write_text("\n".join([str(image_id)]))

        with open(pred_stdout.name, mode="w") as pred_file, redirect_stdout(pred_file):
            decode(
                syms=str(syms_path),
                img_list=img_list.name,
                img_dirs=[str(images)],
                common=common_args,
                data=data_args,
                trainer=trainer_args,
                decode=decode_args,
                num_workers=1,
            )
            sys.stdout.flush()

        predictions = Path(pred_stdout.name).read_text().strip().splitlines()
        _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
        if TEXT_DIRECTION == "RTL":
            return input_img, {"text": get_display(text), "score": score}
        else:
            return input_img, {"text": text, "score": score}
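# predict() returns (resized_line_image, {"text": ..., "score": ...});
# process_image below uses only the "text" field of the second element.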
def process_image(image):
    # Perform inference on an image, keeping text-line detections only (class 0)
    results = model(image, classes=0)
    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    masks = results[0].masks
    polygons = []
    texts = []

    if masks is not None:
        # Get mask data and original image dimensions
        masks = masks.data.cpu().numpy()
        img_height, img_width = img_cv2.shape[:2]

        # Get bounding boxes in xyxy format
        boxes = results[0].boxes.xyxy.cpu().numpy()

        # Sort lines into reading order by the y-coordinate of the top-left corner
        sorted_indices = np.argsort(boxes[:, 1])
        masks = masks[sorted_indices]
        boxes = boxes[sorted_indices]

        for i, (mask, box) in enumerate(zip(masks, boxes)):
            # Scale the mask back up to the original image size
            mask = cv2.resize(mask.squeeze(), (img_width, img_height), interpolation=cv2.INTER_LINEAR)
            mask = (mask > 0.5).astype(np.uint8) * 255  # Binarize

            # Convert the mask to a polygon
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                # Keep only the largest contour
                largest_contour = max(contours, key=cv2.contourArea)
                simplified_polygon = simplify_polygons([largest_contour])[0]
                if simplified_polygon is not None:
                    # Crop the image using the bounding box for text recognition
                    x1, y1, x2, y2 = map(int, box)
                    crop_img = img_cv2[y1:y2, x1:x2]
                    crop_pil = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))

                    # Recognize text with the PyLaia model
                    predicted = predict('pylaia-samaritan_v1', crop_pil)
                    texts.append(predicted[1]["text"])

                    # Convert the polygon to a list of points for display
                    poly_points = simplified_polygon.reshape(-1, 2).astype(int).tolist()
                    polygons.append(f"Line {i+1}: {poly_points}")

                    # Draw the polygon on the image
                    cv2.polylines(img_cv2, [simplified_polygon.reshape(-1, 1, 2).astype(int)],
                                  True, (0, 255, 0), 2)

    # Convert the image back to RGB for display in Streamlit
    img_result = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)

    # Combine polygons and texts into a DataFrame for table display
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})

    return Image.fromarray(img_result), table_data
def segment_and_recognize(image):
    segmented_image, table_data = process_image(image)
    return segmented_image, table_data
# Streamlit app layout
st.set_page_config(layout="wide")  # Use full page width
st.title("YOLOv11 Text Line Segmentation & PyLaia Text Recognition on CATMuS/medieval")

# File uploader
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Process the image once uploaded
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    if st.button("Segment and Recognize"):
        # Perform segmentation and recognition
        segmented_image, table_data = segment_and_recognize(image)

        # Layout: image on the left, table on the right
        col1, col2 = st.columns([2, 3])  # Adjust the ratio as needed
        with col1:
            st.image(segmented_image, caption="Segmented Image with Polygon Masks", use_container_width=True)
        with col2:
            st.table(table_data)
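# Launch locally (assuming this file is saved as app.py):
#   streamlit run app.py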