from typing import Dict, List, Optional

import streamlit as st
import supervision as sv
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from PIL import Image
import torch

# Force CPU inference: make torch report that no CUDA device is available so
# the YOLO models never try to move onto a GPU.
torch.cuda.is_available = lambda: False

st.set_page_config(
    page_title="Medieval Manuscript Segmentation",
    page_icon="π",
    layout="wide"
)

# Display name -> checkpoint filename in the Hugging Face model repository.
MODEL_OPTIONS = {
    "YOLOv11-Nano": "medieval-yolo11n-seg.pt",
    "YOLOv11-Small": "medieval-yolo11s-seg.pt",
    "YOLOv11-Medium": "medieval-yolo11m-seg.pt",
    "YOLOv11-Large": "medieval-yolo11l-seg.pt",
    "YOLOv11-XLarge": "medieval-yolo11x-seg.pt",
    "YOLOv11-Medium Zones": "medieval_zones-yolo11m-seg.pt",
    "YOLOv11-Medium Lines": "medieval_lines-yolo11m-seg.pt",
    "ms_yolo11m-seg4-YTG": "ms_yolo11m-seg4-YTG.pt",
    "ms_yolo11m-seg5-swin_t": "ms_yolo11m-seg5-swin_t.pt",
    "ms_yolo11x-seg2-swin_t": "ms_yolo11x-seg2-swin_t.pt",
    "ms_yolo11m-seg6-convnext_tiny": "ms_yolo11m-seg6-convnext_tiny.pt",
    "yolo11m-seg-gpt": "yolo11m-seg-gpt.pt",
    "ms_yolo11x-seg3-swin_t-fpn": "ms_yolo11x-seg3-swin_t-fpn.pt",
    "yolo11x-seg-gpt7": "yolo11x-seg-gpt7.pt"
}


@st.cache_resource
def load_models() -> Dict[str, YOLO]:
    """Load all models from the Hugging Face Hub and cache them across reruns."""
    models: Dict[str, YOLO] = {}
    for name, model_file in MODEL_OPTIONS.items():
        try:
            model_path = hf_hub_download(
                repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
                filename=model_file
            )
            models[name] = YOLO(model_path)
        except Exception as e:
            st.warning(f"Error loading model {name}: {str(e)}")
    return models


def simplify_polygons(polygons: List[np.ndarray], approx_level: float = 0.01) -> List[Optional[np.ndarray]]:
    """Simplify polygon contours using the Douglas-Peucker algorithm.

    Args:
        polygons: List of polygon contours as returned by cv2.findContours
        approx_level: Approximation level (0-1); the epsilon passed to
            cv2.approxPolyDP is approx_level * contour perimeter, so higher
            values produce stronger simplification

    Returns:
        List of simplified polygons (or None for contours that are degenerate
        before or after simplification)
    """
    result = []
    for polygon in polygons:
        # Skip contours with too few points to form a meaningful polygon.
        if len(polygon) < 4:
            result.append(None)
            continue

        perimeter = cv2.arcLength(polygon, True)
        approx = cv2.approxPolyDP(polygon, approx_level * perimeter, True)
        if len(approx) < 4:
            result.append(None)
            continue

        result.append(approx.squeeze())
    return result

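
# A rough, commented-out sketch of the behaviour above (illustrative only; the
# rectangle coordinates are made up and this is never executed by the app): a
# contour with redundant points along straight edges should collapse back to
# roughly its corner points once approx_level * perimeter exceeds how far those
# extra points deviate from the simplified shape.
#
#   rect = np.array(
#       [[0, 0], [50, 1], [100, 0], [100, 100], [50, 99], [0, 100]],
#       dtype=np.int32,
#   ).reshape(-1, 1, 2)
#   simplified = simplify_polygons([rect], approx_level=0.02)[0]
#   # Expected: only the four corners remain (or None if fewer than 4 points survive).
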

class OutlineMaskAnnotator:
    """Draw mask outlines (contours) instead of filled masks."""

    def __init__(self, color: tuple = (255, 0, 0), thickness: int = 2, simplify: bool = False):
        self.color = color
        self.thickness = thickness
        self.simplify = simplify

    def annotate(self, scene: np.ndarray, detections: sv.Detections) -> np.ndarray:
        if detections.mask is None:
            return scene

        scene = scene.copy()
        for mask in detections.mask:
            # Extract the outer contour of each boolean mask.
            contours, _ = cv2.findContours(
                mask.astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE
            )
            if self.simplify:
                contours = simplify_polygons(contours)
                contours = [c for c in contours if c is not None]

            cv2.drawContours(
                scene,
                contours,
                -1,
                self.color,
                self.thickness
            )
        return scene

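
# Illustrative, commented-out usage outside the app (the canvas, mask, and box
# values are placeholders, not from the original code):
#
#   canvas = np.zeros((200, 200, 3), dtype=np.uint8)
#   toy_mask = np.zeros((200, 200), dtype=bool)
#   toy_mask[50:150, 50:150] = True
#   toy_detections = sv.Detections(
#       xyxy=np.array([[50.0, 50.0, 150.0, 150.0]]),
#       mask=toy_mask[None, ...],
#   )
#   outlined = OutlineMaskAnnotator(color=(0, 255, 0), simplify=True).annotate(
#       canvas, toy_detections
#   )
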

# Shared label annotator used for every model.
LABEL_ANNOTATOR = sv.LabelAnnotator(
    text_color=sv.Color.BLACK,
    text_scale=0.35,
    text_thickness=1,
    text_padding=2
)


def detect_and_annotate(
    image: np.ndarray,
    model_name: str,
    conf_threshold: float,
    iou_threshold: float,
    simplify_polygons_option: bool
) -> np.ndarray:
    # `models` is the global dict built by load_models() below.
    model = models[model_name]

    # Run inference with the selected model.
    results = model.predict(
        image,
        conf=conf_threshold,
        iou=iou_threshold
    )[0]

    # Extract boxes, confidences, and class ids as numpy arrays.
    boxes = results.boxes.xyxy.cpu().numpy()
    confidence = results.boxes.conf.cpu().numpy()
    class_ids = results.boxes.cls.cpu().numpy().astype(int)

    # Resize the predicted masks to the original image resolution.
    masks = None
    if results.masks is not None:
        masks = results.masks.data.cpu().numpy()
        masks = np.transpose(masks, (1, 2, 0))  # (N, H, W) -> (H, W, N)
        h, w = image.shape[:2]
        resized_masks = []
        for i in range(masks.shape[-1]):
            resized_mask = cv2.resize(masks[..., i], (w, h), interpolation=cv2.INTER_LINEAR)
            resized_masks.append(resized_mask > 0.5)
        masks = np.stack(resized_masks) if resized_masks else None

    # Wrap everything in a supervision Detections object.
    detections = sv.Detections(
        xyxy=boxes,
        confidence=confidence,
        class_id=class_ids,
        mask=masks
    )

    # Build "class_name (confidence)" labels.
    labels = [
        f"{results.names[class_id]} ({conf:.2f})"
        for class_id, conf
        in zip(class_ids, confidence)
    ]

    # Draw mask outlines first, then the labels on top.
    mask_annotator = OutlineMaskAnnotator(
        color=(255, 0, 0),
        thickness=2,
        simplify=simplify_polygons_option
    )

    annotated_image = image.copy()
    if masks is not None:
        annotated_image = mask_annotator.annotate(scene=annotated_image, detections=detections)
    annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections, labels=labels)

    return annotated_image

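
# Illustrative, commented-out call outside the Streamlit flow (the file name
# "page.jpg" is a placeholder, not part of the original app):
#
#   page = np.array(Image.open("page.jpg"))
#   preview = detect_and_annotate(
#       page,
#       model_name="YOLOv11-Nano",
#       conf_threshold=0.25,
#       iou_threshold=0.45,
#       simplify_polygons_option=False,
#   )
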

# Load (and cache) all models once, then build the page.
models = load_models()

st.title("Medieval Manuscript Segmentation with YOLO")

with st.sidebar:
    st.header("Detection Settings")

    model_name = st.selectbox(
        "Model",
        options=list(MODEL_OPTIONS.keys()),
        index=0,
        help="Select YOLO model variant"
    )

    conf_threshold = st.slider(
        "Confidence Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.25,
        step=0.05,
        help="Minimum confidence score for detections"
    )

    iou_threshold = st.slider(
        "IoU Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.45,
        step=0.05,
        help="Decrease for stricter detection, increase for more overlapping masks"
    )

    simplify_polygons_option = st.checkbox(
        "Simplify Polygons",
        value=False,
        help="Simplify polygon contours for cleaner outlines"
    )

col1, col2 = st.columns(2)

with col1:
    st.subheader("Input Image")
    uploaded_file = st.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png"],
        key="file_uploader"
    )

    if uploaded_file is not None:
        image = np.array(Image.open(uploaded_file))
        st.image(image, caption="Uploaded Image", use_container_width=True)
    else:
        image = None
        st.info("Please upload an image file")

with col2:
    st.subheader("Detection Result")

    if st.button("Detect", type="primary") and image is not None:
        with st.spinner("Processing image..."):
            annotated_image = detect_and_annotate(
                image,
                model_name,
                conf_threshold,
                iou_threshold,
                simplify_polygons_option
            )
            st.image(annotated_image, caption="Detection Result", use_container_width=True)
    elif image is None:
        st.warning("Please upload an image first")
    else:
        st.info("Click the Detect button to process the image")