"""Gradio demo: instance segmentation of medieval manuscript pages with YOLO."""

import logging
from typing import Dict, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import supervision as sv
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

logger = logging.getLogger(__name__)

# Display name shown in the UI -> weights filename inside MODEL_REPO_ID.
MODEL_OPTIONS = {
    "YOLOv11-Small": "medieval-yolo11s-seg.pt",
}
MODEL_REPO_ID = "johnlockejrr/medieval-manuscript-yolov11-seg"

# Loaded models keyed by display name. A model that fails to download or load
# is simply absent, so look-ups must tolerate a missing key.
models: Dict[str, YOLO] = {}

for name, model_file in MODEL_OPTIONS.items():
    try:
        model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=model_file)
        models[name] = YOLO(model_path)
    except Exception:
        # Best effort: keep the app running with whatever models did load.
        logger.exception("Error loading model %s", name)

# Annotators are stateless, so one shared instance of each is enough.
LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
MASK_ANNOTATOR = sv.MaskAnnotator()


def process_masks(
    masks: Optional[np.ndarray], target_shape: Tuple[int, int]
) -> Optional[np.ndarray]:
    """Resize a stack of soft masks to ``target_shape`` and binarise them.

    Args:
        masks: Array of shape ``(N, h, w)`` holding per-pixel mask scores,
            or ``None``.
        target_shape: ``(height, width)`` of the output masks.

    Returns:
        Boolean array of shape ``(N, height, width)`` (``N`` may be 0), or
        ``None`` when ``masks`` is ``None``.
    """
    if masks is None:
        return None
    h, w = target_shape
    if len(masks) == 0:
        # np.array([]) would collapse to shape (0,); keep the 3-D layout
        # that sv.Detections expects.
        return np.zeros((0, h, w), dtype=bool)
    resized = [
        # cv2.resize takes (width, height), not (height, width).
        cv2.resize(mask.astype(np.float32), (w, h), interpolation=cv2.INTER_LINEAR)
        for mask in masks
    ]
    # Thresholding turns the interpolated scores back into a binary mask.
    return np.stack(resized) > 0.5


def _rgb_to_bgr(image: np.ndarray) -> np.ndarray:
    """Return a BGR copy of an RGB, RGBA or greyscale image.

    Gradio delivers RGB arrays, while Ultralytics interprets NumPy input
    with OpenCV's BGR channel order.
    """
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def detect_and_annotate(
    image: np.ndarray,
    model_name: str,
    conf_threshold: float,
    iou_threshold: float,
) -> Optional[np.ndarray]:
    """Run segmentation on ``image`` and draw masks and labels onto a copy.

    Args:
        image: Input image as delivered by Gradio (RGB NumPy array).
        model_name: Key into ``models`` selecting the YOLO model.
        conf_threshold: Minimum detection confidence passed to YOLO.
        iou_threshold: NMS IoU threshold passed to YOLO.

    Returns:
        The annotated image, ``None`` if ``image`` is ``None``, or the
        unmodified ``image`` if anything fails (the error is logged).
    """
    try:
        if image is None:
            return None

        model = models.get(model_name)
        if model is None:
            raise ValueError(f"Model {model_name} not loaded")

        # retina_masks=True makes masks.data match the original image size,
        # so letterbox padding is already accounted for by Ultralytics.
        results = model.predict(
            _rgb_to_bgr(image),
            conf=conf_threshold,
            iou=iou_threshold,
            retina_masks=True,
        )[0]

        boxes = results.boxes.xyxy.cpu().numpy()
        confidence = results.boxes.conf.cpu().numpy()
        class_ids = results.boxes.cls.cpu().numpy().astype(int)

        masks = None
        if results.masks is not None:
            raw_masks = results.masks.data.cpu().numpy()  # (N, H, W)
            if raw_masks.shape[0] == len(boxes):
                masks = process_masks(raw_masks, image.shape[:2])
            else:
                # Should not happen; drop masks rather than build an invalid
                # Detections object.
                logger.warning(
                    "Mask count %d does not match box count %d; skipping masks",
                    raw_masks.shape[0],
                    len(boxes),
                )

        detections = sv.Detections(
            xyxy=boxes,
            confidence=confidence,
            class_id=class_ids,
            mask=masks,
        )

        labels = [
            f"{results.names[int(class_id)]} ({conf:.2f})"
            for class_id, conf in zip(class_ids, confidence)
        ]

        annotated_image = image.copy()
        if masks is not None:
            annotated_image = MASK_ANNOTATOR.annotate(
                scene=annotated_image, detections=detections
            )
        annotated_image = LABEL_ANNOTATOR.annotate(
            scene=annotated_image, detections=detections, labels=labels
        )
        return annotated_image
    except Exception:
        # UI boundary: log the full traceback but keep the app responsive.
        logger.exception("Error during detection")
        return image


with gr.Blocks() as demo:
    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")

            with gr.Accordion("Detection Settings", open=True):
                model_selector = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    value=list(MODEL_OPTIONS.keys())[0],
                    label="Model",
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                iou_threshold = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.45,
                )

            detect_btn = gr.Button("Detect", variant="primary")
            clear_btn = gr.Button("Clear")

        with gr.Column():
            output_image = gr.Image(label="Segmentation Result", type="numpy")

    def process_image(image, model_name, conf_threshold, iou_threshold):
        """Gradio callback: return (input image, annotated image)."""
        try:
            if image is None:
                return None, None
            annotated_image = detect_and_annotate(
                image, model_name, conf_threshold, iou_threshold
            )
            return image, annotated_image
        except Exception:
            logger.exception("Error in process_image")
            return image, image  # Fallback to original image

    def clear():
        """Gradio callback: reset both image panes."""
        return None, None

    detect_btn.click(
        process_image,
        inputs=[input_image, model_selector, conf_threshold, iou_threshold],
        outputs=[input_image, output_image],
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[input_image, output_image],
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        debug=True,
    )