johnlockejrr committed · Commit 29358f1 (verified) · 1 parent: 5de898a

Upload app.py.bak

Files changed (1): app.py.bak (+190, -0)
app.py.bak ADDED
@@ -0,0 +1,190 @@
+from typing import Tuple, Dict
+import gradio as gr
+import supervision as sv
+import numpy as np
+import cv2
+from huggingface_hub import hf_hub_download
+from ultralytics import YOLO
+
+# Define models
+MODEL_OPTIONS = {
+    "YOLOv11-Small": "medieval-yolo11s-seg.pt"
+}
+
+# Dictionary to store loaded models
+models: Dict[str, YOLO] = {}
+
+# Load all models
+for name, model_file in MODEL_OPTIONS.items():
+    try:
+        model_path = hf_hub_download(
+            repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
+            filename=model_file
+        )
+        models[name] = YOLO(model_path)
+    except Exception as e:
+        print(f"Error loading model {name}: {str(e)}")
+
+# Create annotators
+LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
+MASK_ANNOTATOR = sv.MaskAnnotator()
+
+def process_masks(masks: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray:
+    """Process and resize masks to target shape"""
+    if masks is None:
+        return None
+
+    processed_masks = []
+    h, w = target_shape
+    for mask in masks:
+        # Resize mask to target dimensions
+        resized_mask = cv2.resize(mask.astype(float), (w, h), interpolation=cv2.INTER_LINEAR)
+        # Threshold to create binary mask
+        processed_masks.append(resized_mask > 0.5)
+
+    return np.array(processed_masks)
+
+def detect_and_annotate(
+    image: np.ndarray,
+    model_name: str,
+    conf_threshold: float,
+    iou_threshold: float
+) -> np.ndarray:
+    try:
+        if image is None:
+            return None
+
+        model = models.get(model_name)
+        if model is None:
+            raise ValueError(f"Model {model_name} not loaded")
+
+        # Perform inference
+        results = model.predict(
+            image,
+            conf=conf_threshold,
+            iou=iou_threshold
+        )[0]
+
+        # Convert results to supervision Detections
+        boxes = results.boxes.xyxy.cpu().numpy()
+        confidence = results.boxes.conf.cpu().numpy()
+        class_ids = results.boxes.cls.cpu().numpy().astype(int)
+
+        # Process masks
+        masks = None
+        if results.masks is not None:
+            masks = results.masks.data.cpu().numpy()
+            print(f"Original mask shape: {masks.shape}") # Debug
+
+            # Fix the shape mismatch - should be (num_masks, H, W)
+            if masks.shape[0] != len(boxes):
+                masks = np.transpose(masks, (2, 0, 1)) # Convert from (H,W,N) to (N,H,W)
+
+            print(f"Processed mask shape: {masks.shape}") # Debug
+
+            # Resize masks to original image dimensions
+            h, w = image.shape[:2]
+            resized_masks = []
+            for mask in masks:
+                resized_mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
+                resized_masks.append(resized_mask)
+            masks = np.array(resized_masks)
+            masks = masks > 0.5 # Convert to boolean
+
+        # Create Detections object
+        detections = sv.Detections(
+            xyxy=boxes,
+            confidence=confidence,
+            class_id=class_ids,
+            mask=masks
+        )
+
+        # Create labels with confidence scores
+        labels = [
+            f"{results.names[class_id]} ({conf:.2f})"
+            for class_id, conf in zip(class_ids, confidence)
+        ]
+
+        # Annotate image
+        annotated_image = image.copy()
+        if masks is not None:
+            annotated_image = MASK_ANNOTATOR.annotate(
+                scene=annotated_image,
+                detections=detections
+            )
+        annotated_image = LABEL_ANNOTATOR.annotate(
+            scene=annotated_image,
+            detections=detections,
+            labels=labels
+        )
+
+        return annotated_image
+
+    except Exception as e:
+        print(f"Error during detection: {str(e)}")
+        return image
+
+# Create Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(label="Input Image", type='numpy')
+            with gr.Accordion("Detection Settings", open=True):
+                model_selector = gr.Dropdown(
+                    choices=list(MODEL_OPTIONS.keys()),
+                    value=list(MODEL_OPTIONS.keys())[0],
+                    label="Model"
+                )
+                conf_threshold = gr.Slider(
+                    label="Confidence Threshold",
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.05,
+                    value=0.25
+                )
+                iou_threshold = gr.Slider(
+                    label="IoU Threshold",
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.05,
+                    value=0.45
+                )
+            detect_btn = gr.Button("Detect", variant="primary")
+            clear_btn = gr.Button("Clear")
+
+        with gr.Column():
+            output_image = gr.Image(label="Segmentation Result", type='numpy')
+
+    def process_image(image, model_name, conf_threshold, iou_threshold):
+        try:
+            if image is None:
+                return None, None
+            annotated_image = detect_and_annotate(image, model_name, conf_threshold, iou_threshold)
+            return image, annotated_image
+        except Exception as e:
+            print(f"Error in process_image: {str(e)}")
+            return image, image # Fallback to original image
+
+    def clear():
+        return None, None
+
+    detect_btn.click(
+        process_image,
+        inputs=[input_image, model_selector, conf_threshold, iou_threshold],
+        outputs=[input_image, output_image]
+    )
+    clear_btn.click(
+        clear,
+        inputs=None,
+        outputs=[input_image, output_image]
+    )
+
+if __name__ == "__main__":
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True,
+        debug=True
+    )
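
For quick local checks outside the Gradio UI, the same checkpoint and annotators can be driven directly. The sketch below is not part of the committed file: the input path "page.jpg" and the output path "annotated.jpg" are placeholders, and it assumes a recent supervision release in which sv.Detections.from_ultralytics() converts boxes, masks, and class names from the Ultralytics result (which would make the manual mask transposing and resizing in app.py.bak unnecessary).

    import cv2
    import supervision as sv
    from huggingface_hub import hf_hub_download
    from ultralytics import YOLO

    # Download the same checkpoint the Space loads at startup
    model_path = hf_hub_download(
        repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
        filename="medieval-yolo11s-seg.pt",
    )
    model = YOLO(model_path)

    # "page.jpg" is a placeholder for any manuscript image on disk
    image = cv2.imread("page.jpg")
    result = model.predict(image, conf=0.25, iou=0.45)[0]

    # from_ultralytics() builds the Detections object from the YOLO result;
    # exact mask handling may differ between supervision versions
    detections = sv.Detections.from_ultralytics(result)
    labels = [
        f"{result.names[class_id]} ({conf:.2f})"
        for class_id, conf in zip(detections.class_id, detections.confidence)
    ]

    # Same annotators as the app: masks first, then labels with confidences
    annotated = sv.MaskAnnotator().annotate(scene=image.copy(), detections=detections)
    annotated = sv.LabelAnnotator(text_color=sv.Color.BLACK).annotate(
        scene=annotated, detections=detections, labels=labels
    )
    cv2.imwrite("annotated.jpg", annotated)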