medieval-yolo11-seg / app.py.bak
johnlockejrr's picture
Upload app.py.bak
29358f1 verified
from typing import Dict, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import supervision as sv
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
# Hugging Face Hub repository that hosts all model weight files below.
MODEL_REPO_ID = "johnlockejrr/medieval-manuscript-yolov11-seg"

# Define models: display name shown in the UI -> weight filename in MODEL_REPO_ID.
MODEL_OPTIONS = {
    "YOLOv11-Small": "medieval-yolo11s-seg.pt"
}


def _load_model(model_file: str) -> YOLO:
    """Download ``model_file`` from ``MODEL_REPO_ID`` and load it as a YOLO model."""
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=model_file)
    return YOLO(model_path)


# Dictionary to store loaded models, keyed by the same display names as MODEL_OPTIONS.
models: Dict[str, YOLO] = {}
for name, model_file in MODEL_OPTIONS.items():
    try:
        models[name] = _load_model(model_file)
    except Exception as e:
        # Best effort: a single failing model must not prevent the app from starting.
        print(f"Error loading model {name}: {str(e)}")

if not models:
    # Without a loud warning here the UI still starts, but every detection request
    # fails with "not loaded" and falls back to returning the unmodified input image.
    print("Warning: no models could be loaded; detection will be unavailable.")

# Create annotators (shared by all requests)
LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
MASK_ANNOTATOR = sv.MaskAnnotator()
def process_masks(
    masks: Optional[np.ndarray],
    target_shape: Tuple[int, int]
) -> Optional[np.ndarray]:
    """Resize a stack of masks to ``target_shape`` and binarize them.

    Args:
        masks: Per-instance masks, iterated along the first axis
            (presumably ``(N, H, W)`` -- confirm against callers), or ``None``.
        target_shape: Output size as ``(height, width)``.

    Returns:
        Boolean array of shape ``(N, height, width)`` in which a pixel is True
        when its bilinearly resized value exceeds 0.5, or ``None`` when
        ``masks`` is ``None``. An empty stack yields a correctly shaped
        ``(0, height, width)`` array.
    """
    if masks is None:
        return None
    h, w = target_shape
    if len(masks) == 0:
        # np.array([]) would have shape (0,), which downstream code cannot index as masks.
        return np.zeros((0, h, w), dtype=bool)
    # cv2.resize takes dsize as (width, height); float32 is sufficient for thresholding.
    return np.stack([
        cv2.resize(mask.astype(np.float32), (w, h), interpolation=cv2.INTER_LINEAR) > 0.5
        for mask in masks
    ])
def _rgb_to_bgr(image: np.ndarray) -> np.ndarray:
    """Return a contiguous BGR copy of an RGB, RGBA or grayscale image.

    Gradio hands over ``type='numpy'`` images in RGB order, whereas Ultralytics
    interprets raw numpy input as BGR, so the channel order must be reversed
    before inference. A fourth (alpha) channel, if present, is dropped.
    """
    if image.ndim == 2 or image.shape[-1] == 1:
        gray = image.reshape(image.shape[:2])
        return np.repeat(gray[..., np.newaxis], 3, axis=-1)
    # Take channels 2, 1, 0 -> BGR; this also discards an alpha channel.
    return np.ascontiguousarray(image[..., 2::-1])


def detect_and_annotate(
    image: Optional[np.ndarray],
    model_name: str,
    conf_threshold: float,
    iou_threshold: float
) -> Optional[np.ndarray]:
    """Run instance segmentation on ``image`` and draw masks and labels on a copy.

    Args:
        image: RGB image as delivered by ``gr.Image(type='numpy')``, or ``None``.
        model_name: Key into the module-level ``models`` dict.
        conf_threshold: Minimum confidence passed to ``model.predict``.
        iou_threshold: NMS IoU threshold passed to ``model.predict``.

    Returns:
        The annotated copy of ``image``; ``None`` if ``image`` is ``None``.
        On any error the error is printed and the original ``image`` is
        returned unchanged (best-effort behavior for the UI).
    """
    if image is None:
        return None
    try:
        model = models.get(model_name)
        if model is None:
            raise ValueError(f"Model {model_name} not loaded")

        # Perform inference on a BGR copy; the displayed image stays RGB.
        results = model.predict(
            _rgb_to_bgr(image),
            conf=conf_threshold,
            iou=iou_threshold
        )[0]

        # from_ultralytics strips the letterbox padding from the inference-size
        # masks before scaling them back to the original image size; a plain
        # resize of masks.data would leave the masks misaligned.
        detections = sv.Detections.from_ultralytics(results)

        # Create labels with confidence scores
        labels = [
            f"{results.names[class_id]} ({conf:.2f})"
            for class_id, conf in zip(detections.class_id, detections.confidence)
        ]

        # Annotate image
        annotated_image = image.copy()
        if detections.mask is not None:
            annotated_image = MASK_ANNOTATOR.annotate(
                scene=annotated_image,
                detections=detections
            )
        annotated_image = LABEL_ANNOTATOR.annotate(
            scene=annotated_image,
            detections=detections,
            labels=labels
        )
        return annotated_image
    except Exception as e:
        print(f"Error during detection: {str(e)}")
        return image
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type='numpy')
            with gr.Accordion("Detection Settings", open=True):
                model_selector = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    # First configured model, or no selection if none are configured.
                    value=next(iter(MODEL_OPTIONS), None),
                    label="Model"
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25
                )
                iou_threshold = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.45
                )
            detect_btn = gr.Button("Detect", variant="primary")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result", type='numpy')

    def process_image(image, model_name, conf_threshold, iou_threshold):
        """Click handler for "Detect": return the annotated image for ``output_image``.

        Only the result component is updated; echoing ``image`` back into the
        input component would just resend the same full image to the browser.
        """
        if image is None:
            return None
        try:
            return detect_and_annotate(image, model_name, conf_threshold, iou_threshold)
        except Exception as e:
            print(f"Error in process_image: {str(e)}")
            return image  # Fallback to original image

    def clear():
        """Click handler for "Clear": empty both image components."""
        return None, None

    detect_btn.click(
        process_image,
        inputs=[input_image, model_selector, conf_threshold, iou_threshold],
        outputs=output_image
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[input_image, output_image]
    )
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860, surface handler errors in
    # the UI, and run Gradio in debug mode.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "debug": True,
    }
    demo.launch(**launch_options)