Spaces:
Sleeping
Sleeping
from typing import Tuple, Dict, List, Optional | |
import gradio as gr | |
import supervision as sv | |
import numpy as np | |
import cv2 | |
from huggingface_hub import hf_hub_download | |
from ultralytics import YOLO | |
# Dropdown display name -> checkpoint filename inside the Hugging Face Hub repo.
# Order matters: the UI lists the models in this order and preselects the first.
MODEL_OPTIONS = dict(
    [
        ("YOLOv11-Nano", "medieval-yolo11n-seg.pt"),
        ("YOLOv11-Small", "medieval-yolo11s-seg.pt"),
        ("YOLOv11-Medium", "medieval-yolo11m-seg.pt"),
        ("YOLOv11-Large", "medieval-yolo11l-seg.pt"),
        ("YOLOv11-XLarge", "medieval-yolo11x-seg.pt"),
        ("YOLOv11-Medium Zones", "medieval_zones-yolo11m-seg.pt"),
        ("YOLOv11-Medium Lines", "medieval_lines-yolo11m-seg.pt"),
        ("ms_yolo11m-seg4-YTG", "ms_yolo11m-seg4-YTG.pt"),
        ("ms_yolo11m-seg5-swin_t", "ms_yolo11m-seg5-swin_t.pt"),
        ("ms_yolo11x-seg2-swin_t", "ms_yolo11x-seg2-swin_t.pt"),
        ("ms_yolo11m-seg6-convnext_tiny", "ms_yolo11m-seg6-convnext_tiny.pt"),
        ("yolo11m-seg-gpt", "yolo11m-seg-gpt.pt"),
        ("ms_yolo11x-seg3-swin_t-fpn", "ms_yolo11x-seg3-swin_t-fpn.pt"),
        ("yolo11x-seg-gpt7", "yolo11x-seg-gpt7.pt"),
    ]
)
# Display name -> loaded YOLO model. Only checkpoints that downloaded and
# loaded successfully end up here; failures are reported and skipped.
models: Dict[str, YOLO] = {}

for name, weights_file in MODEL_OPTIONS.items():
    try:
        # Download (or reuse the cached copy of) the checkpoint, then load it.
        models[name] = YOLO(
            hf_hub_download(
                repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
                filename=weights_file,
            )
        )
    except Exception as e:
        # Best effort: one broken checkpoint must not abort start-up.
        print(f"Error loading model {name}: {str(e)}")
def simplify_polygons(
    polygons: List[np.ndarray],
    approx_level: float = 0.01,
    min_points: int = 4,
) -> List[Optional[np.ndarray]]:
    """Simplify polygon contours with the Douglas-Peucker algorithm (cv2.approxPolyDP).

    Each contour is approximated with a tolerance of ``approx_level`` times its
    closed perimeter. *Larger* values therefore remove more vertices, and
    smaller values keep the outline closer to the original.

    Args:
        polygons: Point arrays such as the contours returned by ``cv2.findContours``.
        approx_level: Tolerance as a fraction of each contour's perimeter.
        min_points: Contours with fewer points than this, either before or
            after simplification, are reported as ``None``. Defaults to 4.

    Returns:
        One entry per input contour, in input order: the simplified contour as
        an ``(M, 2)`` point array, or ``None`` if it was too small to keep.
    """
    simplified: List[Optional[np.ndarray]] = []
    for polygon in polygons:
        if len(polygon) < min_points:
            simplified.append(None)
            continue
        epsilon = approx_level * cv2.arcLength(polygon, True)
        approx = cv2.approxPolyDP(polygon, epsilon, True)
        # approxPolyDP returns (M, 1, 2). Use reshape rather than squeeze so
        # that a small M (e.g. 1) still yields a 2-D (M, 2) array.
        simplified.append(approx.reshape(-1, 2) if len(approx) >= min_points else None)
    return simplified
class OutlineMaskAnnotator:
    """Draw segmentation masks as outlines instead of filled regions.

    The external contours of every mask are traced with OpenCV and drawn onto
    a copy of the scene. When ``simplify`` is set, the contours first pass
    through ``simplify_polygons``, and any contour it rejects is not drawn.
    """

    def __init__(self, color: tuple = (255, 0, 0), thickness: int = 2, simplify: bool = False):
        self.color = color          # colour tuple handed to cv2.drawContours
        self.thickness = thickness  # outline thickness passed to cv2.drawContours
        self.simplify = simplify    # whether to run simplify_polygons before drawing

    def _outline_contours(self, mask: np.ndarray):
        """Return the (optionally simplified) external contours of one mask."""
        found, _hierarchy = cv2.findContours(
            mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not self.simplify:
            return found
        return [poly for poly in simplify_polygons(found) if poly is not None]

    def annotate(self, scene: np.ndarray, detections: sv.Detections) -> np.ndarray:
        """Return a copy of *scene* with every mask outline drawn on it.

        If *detections* carries no masks, *scene* itself is returned unchanged.
        """
        if detections.mask is None:
            return scene
        canvas = scene.copy()
        for mask in detections.mask:
            cv2.drawContours(canvas, self._outline_contours(mask), -1, self.color, self.thickness)
        return canvas
# Shared label annotator: small black text with tight padding. It draws the
# "class (confidence)" tags that detect_and_annotate builds for each detection.
LABEL_ANNOTATOR = sv.LabelAnnotator(
    text_padding=2,
    text_thickness=1,
    text_scale=0.35,
    text_color=sv.Color.BLACK,
)
def detect_and_annotate(
    image: np.ndarray,
    model_name: str,
    conf_threshold: float,
    iou_threshold: float,
    simplify_polygons_option: bool
) -> np.ndarray:
    """Run instance segmentation on *image* and draw mask outlines and labels.

    Args:
        image: Input image from the Gradio ``gr.Image(type='numpy')`` component
            (RGB channel order).
        model_name: Key into ``models`` selecting the YOLO checkpoint.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to the model's NMS step.
        simplify_polygons_option: Simplify mask outlines before drawing them.

    Returns:
        A copy of *image* with red mask outlines and "class (confidence)" labels.

    Raises:
        gr.Error: If *model_name* has no loaded model (e.g. its download failed).
    """
    model = models.get(model_name)
    if model is None:
        raise gr.Error(f"Model '{model_name}' is not available; it may have failed to load.")

    # Ultralytics interprets NumPy arrays as BGR (OpenCV convention), while the
    # Gradio image component supplies RGB. Swap the channels for inference only;
    # annotations are still drawn on the original RGB image.
    if image.ndim == 3 and image.shape[2] == 3:
        model_input = np.ascontiguousarray(image[..., ::-1])
    else:
        model_input = image

    # retina_masks=True makes Ultralytics return masks at the original image
    # resolution, already mapped out of the letterboxed inference canvas.
    # Plainly stretching canvas-sized masks would misalign them whenever
    # letterbox padding was added.
    results = model.predict(
        model_input,
        conf=conf_threshold,
        iou=iou_threshold,
        retina_masks=True,
    )[0]

    boxes = results.boxes.xyxy.cpu().numpy()
    confidence = results.boxes.conf.cpu().numpy()
    class_ids = results.boxes.cls.cpu().numpy().astype(int)

    h, w = image.shape[:2]
    masks = None
    if results.masks is not None:
        # Masks arrive as an (N, H', W') stack. Resize any that still differ from
        # the image size (a fallback only), then binarize.
        binary_masks = [
            (
                mask if mask.shape == (h, w)
                else cv2.resize(mask.astype(np.float32), (w, h), interpolation=cv2.INTER_LINEAR)
            ) > 0.5
            for mask in results.masks.data.cpu().numpy()
        ]
        masks = np.stack(binary_masks) if binary_masks else None

    detections = sv.Detections(
        xyxy=boxes,
        confidence=confidence,
        class_id=class_ids,
        mask=masks,
    )

    labels = [
        f"{results.names[class_id]} ({conf:.2f})"
        for class_id, conf in zip(class_ids, confidence)
    ]

    annotated_image = image.copy()
    if masks is not None:
        mask_annotator = OutlineMaskAnnotator(
            color=(255, 0, 0),
            thickness=2,
            simplify=simplify_polygons_option,
        )
        annotated_image = mask_annotator.annotate(scene=annotated_image, detections=detections)
    annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections, labels=labels)
    return annotated_image
# Gradio user interface.
# Offer only the models that actually loaded. If none did, fall back to the full
# catalogue so the interface still renders; selecting an unloaded model then
# fails at inference time instead of at start-up.
available_models = [name for name in MODEL_OPTIONS if name in models] or list(MODEL_OPTIONS)

with gr.Blocks() as demo:
    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type='numpy'
            )
            with gr.Accordion("Detection Settings", open=True):
                model_selector = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0],
                    label="Model",
                    info="Select YOLO model variant"
                )
                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.25,
                    )
                    iou_threshold = gr.Slider(
                        label="IoU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.45,
                        info="Decrease for stricter detection, increase for more overlapping masks"
                    )
                simplify_polygons_option = gr.Checkbox(
                    label="Simplify Polygons",
                    value=False,
                    info="Simplify polygon contours for cleaner outlines"
                )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                detect_btn = gr.Button("Detect", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                label="Detection Result",
                type='numpy'
            )

    def process_image(
        image: Optional[np.ndarray],
        model_name: str,
        conf_threshold: float,
        iou_threshold: float,
        simplify_polygons_option: bool
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Handle the "Detect" button.

        Returns the input image and its annotated copy, or ``(None, None)``
        when no image has been provided.
        """
        if image is None:
            return None, None
        annotated_image = detect_and_annotate(image, model_name, conf_threshold, iou_threshold, simplify_polygons_option)
        return image, annotated_image

    def clear():
        """Handle the "Clear" button: empty both image components."""
        return None, None

    detect_btn.click(
        process_image,
        inputs=[input_image, model_selector, conf_threshold, iou_threshold, simplify_polygons_option],
        outputs=[input_image, output_image]
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[input_image, output_image]
    )

if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)