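"""Gradio Space: medieval manuscript segmentation with an Ultralytics YOLOv11
segmentation model, visualized with the supervision library."""
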
from typing import Tuple, Dict
import gradio as gr
import supervision as sv
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Define models
MODEL_OPTIONS = {
    "YOLOv11-Small": "medieval-yolo11s-seg.pt"
}

# Dictionary to store loaded models
models: Dict[str, YOLO] = {}

# Load all models
for name, model_file in MODEL_OPTIONS.items():
    try:
        model_path = hf_hub_download(
            repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
            filename=model_file
        )
        models[name] = YOLO(model_path)
    except Exception as e:
        print(f"Error loading model {name}: {str(e)}")

# Create annotators
LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
MASK_ANNOTATOR = sv.MaskAnnotator()


def process_masks(masks: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray:
    """Process and resize masks to target shape"""
    if masks is None:
        return None

    processed_masks = []
    h, w = target_shape
    for mask in masks:
        # Resize mask to target dimensions
        resized_mask = cv2.resize(mask.astype(float), (w, h), interpolation=cv2.INTER_LINEAR)
        # Threshold to create binary mask
        processed_masks.append(resized_mask > 0.5)

    return np.array(processed_masks)
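
# NOTE: process_masks above is a standalone helper; detect_and_annotate below
# resizes masks inline rather than calling it.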


def detect_and_annotate(
    image: np.ndarray,
    model_name: str,
    conf_threshold: float,
    iou_threshold: float
) -> np.ndarray:
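    """Run segmentation with the selected model and annotate the result.

    Returns the annotated image; on any error the input image is returned unchanged.
    """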
    try:
        if image is None:
            return None

        model = models.get(model_name)
        if model is None:
            raise ValueError(f"Model {model_name} not loaded")

        # Perform inference
        results = model.predict(
            image,
            conf=conf_threshold,
            iou=iou_threshold
        )[0]

        # Convert results to supervision Detections
        boxes = results.boxes.xyxy.cpu().numpy()
        confidence = results.boxes.conf.cpu().numpy()
        class_ids = results.boxes.cls.cpu().numpy().astype(int)

        # Process masks
        masks = None
        if results.masks is not None:
            masks = results.masks.data.cpu().numpy()
            print(f"Original mask shape: {masks.shape}")  # Debug

            # Fix the shape mismatch - should be (num_masks, H, W)
            if masks.shape[0] != len(boxes):
                masks = np.transpose(masks, (2, 0, 1))  # Convert from (H, W, N) to (N, H, W)
            print(f"Processed mask shape: {masks.shape}")  # Debug

            # Resize masks to original image dimensions
            h, w = image.shape[:2]
            resized_masks = []
            for mask in masks:
                resized_mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
                resized_masks.append(resized_mask)
            masks = np.array(resized_masks)
            masks = masks > 0.5  # Convert to boolean

        # Create Detections object
        detections = sv.Detections(
            xyxy=boxes,
            confidence=confidence,
            class_id=class_ids,
            mask=masks
        )

        # Create labels with confidence scores
        labels = [
            f"{results.names[class_id]} ({conf:.2f})"
            for class_id, conf in zip(class_ids, confidence)
        ]

        # Annotate image
        annotated_image = image.copy()
        if masks is not None:
            annotated_image = MASK_ANNOTATOR.annotate(
                scene=annotated_image,
                detections=detections
            )
        annotated_image = LABEL_ANNOTATOR.annotate(
            scene=annotated_image,
            detections=detections,
            labels=labels
        )
        return annotated_image

    except Exception as e:
        print(f"Error during detection: {str(e)}")
        return image


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type='numpy')
            with gr.Accordion("Detection Settings", open=True):
                model_selector = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    value=list(MODEL_OPTIONS.keys())[0],
                    label="Model"
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25
                )
                iou_threshold = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.45
                )
            detect_btn = gr.Button("Detect", variant="primary")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result", type='numpy')

    def process_image(image, model_name, conf_threshold, iou_threshold):
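        """Gradio callback: run detection and return the (original, annotated) image pair."""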
        try:
            if image is None:
                return None, None
            annotated_image = detect_and_annotate(image, model_name, conf_threshold, iou_threshold)
            return image, annotated_image
        except Exception as e:
            print(f"Error in process_image: {str(e)}")
            return image, image  # Fall back to the original image

    def clear():
        return None, None

    detect_btn.click(
        process_image,
        inputs=[input_image, model_selector, conf_threshold, iou_threshold],
        outputs=[input_image, output_image]
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[input_image, output_image]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        debug=True
    )