from typing import Tuple, Dict, List, Optional
import gradio as gr
import supervision as sv
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Available segmentation checkpoints.
# Keys are the display names offered in the model dropdown; values are the
# weight filenames downloaded from the Hugging Face Hub repository when the
# models are loaded below.
MODEL_OPTIONS: Dict[str, str] = {
    "YOLOv11-Nano": "medieval-yolo11n-seg.pt",
    "YOLOv11-Small": "medieval-yolo11s-seg.pt",
    "YOLOv11-Medium": "medieval-yolo11m-seg.pt",
    "YOLOv11-Large": "medieval-yolo11l-seg.pt",
    "YOLOv11-XLarge": "medieval-yolo11x-seg.pt",
    "YOLOv11-Medium Zones": "medieval_zones-yolo11m-seg.pt",
    "YOLOv11-Medium Lines": "medieval_lines-yolo11m-seg.pt",
    "ms_yolo11m-seg4-YTG": "ms_yolo11m-seg4-YTG.pt",
    "ms_yolo11m-seg5-swin_t": "ms_yolo11m-seg5-swin_t.pt",
    "ms_yolo11x-seg2-swin_t": "ms_yolo11x-seg2-swin_t.pt",
    "ms_yolo11m-seg6-convnext_tiny": "ms_yolo11m-seg6-convnext_tiny.pt",
    "yolo11m-seg-gpt": "yolo11m-seg-gpt.pt",
    "ms_yolo11x-seg3-swin_t-fpn": "ms_yolo11x-seg3-swin_t-fpn.pt",
    "yolo11x-seg-gpt7": "yolo11x-seg-gpt7.pt"
}

# Registry of ready-to-use models, keyed by the display names in MODEL_OPTIONS
models: Dict[str, YOLO] = {}

# Download and instantiate every checkpoint up front. A checkpoint that fails
# to download or load is reported on stdout and simply left out of the registry.
for display_name, weights_file in MODEL_OPTIONS.items():
    try:
        models[display_name] = YOLO(
            hf_hub_download(
                filename=weights_file,
                repo_id="johnlockejrr/medieval-manuscript-yolov11-seg",
            )
        )
    except Exception as exc:
        print(f"Error loading model {display_name}: {exc}")

def simplify_polygons(polygons: List[np.ndarray], approx_level: float = 0.01) -> List[Optional[np.ndarray]]:
    """Reduce the vertex count of closed contours with the Douglas-Peucker algorithm.

    Args:
        polygons: Contours to simplify, each an array of points.
        approx_level: Fraction of each contour's perimeter used as the maximum
            allowed deviation (epsilon); larger values simplify more aggressively.

    Returns:
        A list aligned with ``polygons``. An entry is ``None`` when the input
        contour, or its simplified version, has fewer than four points;
        otherwise it is the simplified contour with singleton axes squeezed out.
    """
    simplified: List[Optional[np.ndarray]] = []
    for contour in polygons:
        reduced = None
        if len(contour) >= 4:
            # Tolerance scales with the contour size so large and small shapes
            # are simplified proportionally.
            epsilon = approx_level * cv2.arcLength(contour, True)
            candidate = cv2.approxPolyDP(contour, epsilon, True)
            if len(candidate) >= 4:
                reduced = candidate.squeeze()
        simplified.append(reduced)
    return simplified

# Custom MaskAnnotator for outline-only masks with simplified polygons
class OutlineMaskAnnotator:
    """Annotator that draws only the outer boundary of each detection mask.

    Attributes are set from the constructor arguments:
        color: Colour tuple handed to ``cv2.drawContours``.
        thickness: Outline thickness in pixels.
        simplify: When true, contours are passed through ``simplify_polygons``
            before being drawn.
    """

    def __init__(self, color: tuple = (255, 0, 0), thickness: int = 2, simplify: bool = False):
        self.color, self.thickness, self.simplify = color, thickness, simplify

    def annotate(self, scene: np.ndarray, detections: sv.Detections) -> np.ndarray:
        """Return a copy of ``scene`` with every mask outline drawn on it.

        If ``detections`` carries no masks, ``scene`` itself is returned
        unchanged (no copy is made).
        """
        masks = detections.mask
        if masks is None:
            return scene

        canvas = scene.copy()
        for binary_mask in masks:
            # Only the external contours are wanted: holes inside a mask are ignored.
            outlines, _hierarchy = cv2.findContours(
                binary_mask.astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )
            if self.simplify:
                # Drop the contours that simplify_polygons rejected (None entries).
                outlines = [
                    outline
                    for outline in simplify_polygons(outlines)
                    if outline is not None
                ]
            cv2.drawContours(canvas, outlines, -1, self.color, self.thickness)
        return canvas

# Label annotator shared by every request: small black text with tight padding
LABEL_ANNOTATOR = sv.LabelAnnotator(
    text_padding=2,
    text_thickness=1,
    text_scale=0.35,
    text_color=sv.Color.BLACK,
)

def _binary_masks(mask_data: np.ndarray, height: int, width: int) -> Optional[np.ndarray]:
    """Convert raw (N, H', W') mask data into (N, height, width) boolean masks."""
    if len(mask_data) == 0:
        return None
    if tuple(mask_data.shape[1:]) != (height, width):
        # Defensive fallback: masks should already match the image size because
        # inference is run with retina_masks=True.
        mask_data = np.stack([
            cv2.resize(
                single_mask.astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_LINEAR,
            )
            for single_mask in mask_data
        ])
    return mask_data > 0.5


def detect_and_annotate(
    image: np.ndarray,
    model_name: str,
    conf_threshold: float,
    iou_threshold: float,
    simplify_polygons_option: bool,
    imgsz: Tuple[int, int] = (640, 512)
) -> np.ndarray:
    """Run the selected segmentation model on an image and draw the results.

    Args:
        image: Input image array. A 3-channel image is expected in RGB channel
            order, which is what Gradio's Image component delivers with
            ``type='numpy'``.
        model_name: Key into the module-level ``models`` registry.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold used by non-maximum suppression.
        simplify_polygons_option: Whether mask outlines are simplified before
            being drawn.
        imgsz: Inference size passed to the model.

    Returns:
        A copy of ``image`` with mask outlines (when the model produced masks)
        and "<class name> (<confidence>)" labels drawn on it.

    Raises:
        KeyError: If ``model_name`` is not in the registry, e.g. because that
            model failed to load at startup.
    """
    try:
        model = models[model_name]
    except KeyError:
        raise KeyError(
            f"Model '{model_name}' is not available; it may have failed to load at startup"
        ) from None

    # Ultralytics interprets numpy inputs as BGR, so convert the RGB image for
    # inference only; all drawing below still happens on the RGB original.
    if image.ndim == 3 and image.shape[2] == 3:
        inference_input = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        inference_input = image

    # retina_masks=True makes the returned masks match the original image size.
    # Without it they are at the letterboxed inference size, and a plain resize
    # would ignore the padding and shift the outlines.
    results = model.predict(
        inference_input,
        imgsz=imgsz,
        conf=conf_threshold,
        iou=iou_threshold,
        retina_masks=True
    )[0]

    # Convert results to plain numpy arrays for supervision
    boxes = results.boxes.xyxy.cpu().numpy()
    confidence = results.boxes.conf.cpu().numpy()
    class_ids = results.boxes.cls.cpu().numpy().astype(int)

    # Masks are optional: the model may return none for this image
    masks = None
    if results.masks is not None:
        height, width = image.shape[:2]
        masks = _binary_masks(results.masks.data.cpu().numpy(), height, width)

    detections = sv.Detections(
        xyxy=boxes,
        confidence=confidence,
        class_id=class_ids,
        mask=masks
    )

    # One label per detection: class name plus confidence score
    labels = [
        f"{results.names[class_id]} ({conf:.2f})"
        for class_id, conf
        in zip(class_ids, confidence)
    ]

    # Draw on a copy so the caller's image is left untouched
    annotated_image = image.copy()
    if masks is not None:
        mask_annotator = OutlineMaskAnnotator(
            color=(255, 0, 0),
            thickness=2,
            simplify=simplify_polygons_option
        )
        annotated_image = mask_annotator.annotate(scene=annotated_image, detections=detections)
    annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections, labels=labels)

    return annotated_image

# Gradio interface: input image and detection settings on the left,
# annotated detection result on the right
with gr.Blocks() as demo:
    gr.Markdown("# Medieval Manuscript Segmentation with YOLO")

    # Only offer models that actually loaded, so the user cannot pick one that
    # would fail at detection time. If nothing loaded, fall back to the full
    # list so the interface can still be built.
    available_models = [
        option for option in MODEL_OPTIONS if option in models
    ] or list(MODEL_OPTIONS)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type='numpy'
            )
            with gr.Accordion("Detection Settings", open=True):
                model_selector = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0],
                    label="Model",
                    info="Select YOLO model variant"
                )
                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.25,
                    )
                    iou_threshold = gr.Slider(
                        label="IoU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.45,
                        info="Decrease for stricter detection, increase for more overlapping masks"
                    )
                simplify_polygons_option = gr.Checkbox(
                    label="Simplify Polygons",
                    value=False,
                    info="Simplify polygon contours for cleaner outlines"
                )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                detect_btn = gr.Button("Detect", variant="primary")

        with gr.Column():
            output_image = gr.Image(
                label="Detection Result",
                type='numpy'
            )

    def process_image(
        image: np.ndarray,
        model_name: str,
        conf_threshold: float,
        iou_threshold: float,
        simplify_polygons_option: bool
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Run detection for the Detect button.

        Returns the values for (input_image, output_image): the unchanged input
        and its annotated copy, or (None, None) when no image was provided.
        """
        if image is None:
            return None, None
        annotated_image = detect_and_annotate(image, model_name, conf_threshold, iou_threshold, simplify_polygons_option)
        return image, annotated_image

    def clear():
        """Reset both image components for the Clear button."""
        return None, None

    detect_btn.click(
        process_image,
        inputs=[input_image, model_selector, conf_threshold, iou_threshold, simplify_polygons_option],
        outputs=[input_image, output_image]
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[input_image, output_image]
    )

if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script
    demo.launch(
        show_error=True,
        debug=True,
    )