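"""Gradio demo for medieval manuscript segmentation with Ultralytics YOLO.

Downloads YOLOv11 segmentation weights from the Hugging Face Hub repo
johnlockejrr/medieval-manuscript-yolov11-seg, runs inference on an uploaded
image at user-selected confidence/IoU thresholds, and overlays the predicted
masks and class labels using supervision annotators.
"""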
from typing import Tuple, Dict, Optional
import gradio as gr
import supervision as sv
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Define models
MODEL_OPTIONS = {
    "YOLOv11-Small": "medieval-yolo11s-seg.pt"
}

# Dictionary to store loaded models
models: Dict[str, YOLO] = {}

# Load all models
for name, model_file in MODEL_OPTIONS.items():
    try:
        model_path = hf_hub_download(
            repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
            filename=model_file
        )
        models[name] = YOLO(model_path)
    except Exception as e:
        print(f"Error loading model {name}: {str(e)}")

# Create annotators
LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
MASK_ANNOTATOR = sv.MaskAnnotator()

def process_masks(masks: np.ndarray, target_shape: Tuple[int, int]) -> Optional[np.ndarray]:
    """Resize each mask to target_shape (height, width) and threshold it to a boolean mask."""
    if masks is None:
        return None
        
    processed_masks = []
    h, w = target_shape
    for mask in masks:
        # Resize mask to target dimensions
        resized_mask = cv2.resize(mask.astype(float), (w, h), interpolation=cv2.INTER_LINEAR)
        # Threshold to create binary mask
        processed_masks.append(resized_mask > 0.5)
    
    return np.array(processed_masks)

def detect_and_annotate(
    image: np.ndarray,
    model_name: str,
    conf_threshold: float,
    iou_threshold: float
) -> Optional[np.ndarray]:
    try:
        if image is None:
            return None
            
        model = models.get(model_name)
        if model is None:
            raise ValueError(f"Model {model_name} not loaded")
        
        # Perform inference
        results = model.predict(
            image,
            conf=conf_threshold,
            iou=iou_threshold
        )[0]
        
        # Convert results to supervision Detections
        boxes = results.boxes.xyxy.cpu().numpy()
        confidence = results.boxes.conf.cpu().numpy()
        class_ids = results.boxes.cls.cpu().numpy().astype(int)
        
        # Process masks
        masks = None
        if results.masks is not None:
            masks = results.masks.data.cpu().numpy()
            print(f"Original mask shape: {masks.shape}")  # Debug
            
            # Fix the shape mismatch - should be (num_masks, H, W)
            if masks.shape[0] != len(boxes):
                masks = np.transpose(masks, (2, 0, 1))  # Convert from (H,W,N) to (N,H,W)
            
            print(f"Processed mask shape: {masks.shape}")  # Debug
            
            # Resize masks to the original image dimensions and binarize them
            masks = process_masks(masks, image.shape[:2])
            
        # Create Detections object
        detections = sv.Detections(
            xyxy=boxes,
            confidence=confidence,
            class_id=class_ids,
            mask=masks
        )
        
        # Create labels with confidence scores
        labels = [
            f"{results.names[class_id]} ({conf:.2f})"
            for class_id, conf in zip(class_ids, confidence)
        ]

        # Annotate image
        annotated_image = image.copy()
        if masks is not None:
            annotated_image = MASK_ANNOTATOR.annotate(
                scene=annotated_image, 
                detections=detections
            )
        annotated_image = LABEL_ANNOTATOR.annotate(
            scene=annotated_image,
            detections=detections,
            labels=labels
        )
        
        return annotated_image
        
    except Exception as e:
        print(f"Error during detection: {str(e)}")
        return image

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type='numpy')
            with gr.Accordion("Detection Settings", open=True):
                model_selector = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    value=list(MODEL_OPTIONS.keys())[0],
                    label="Model"
                )
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25
                )
                iou_threshold = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.45
                )
            detect_btn = gr.Button("Detect", variant="primary")
            clear_btn = gr.Button("Clear")
                
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result", type='numpy')

    def process_image(image, model_name, conf_threshold, iou_threshold):
        try:
            if image is None:
                return None, None
            annotated_image = detect_and_annotate(image, model_name, conf_threshold, iou_threshold)
            return image, annotated_image
        except Exception as e:
            print(f"Error in process_image: {str(e)}")
            return image, image  # Fallback to original image

    def clear():
        return None, None

    detect_btn.click(
        process_image,
        inputs=[input_image, model_selector, conf_threshold, iou_threshold],
        outputs=[input_image, output_image]
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[input_image, output_image]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        debug=True
    )