import gradio as gr
import torch
import numpy as np
import cv2
from sam2.build_sam import build_sam2_video_predictor
import tempfile
import os
import shutil
from trajectory_service import TrajectoryService


class VideoTracker:
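    """Interactive SAM2-based video object tracker backing the Gradio UI.

    Holds the loaded video frames, the SAM2 inference state, user-clicked
    points, and the per-frame masks produced by click prompts and propagation.
    """
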
    def __init__(self):
        self.checkpoint = "./models/sam2.1_hiera_tiny.pt"
        self.model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        self.predictor = build_sam2_video_predictor(
            self.model_cfg, self.checkpoint, device="cpu", mode="eval"
        )
        self.state = None
        self.video_frames = None
        self.current_frame_idx = 0
        self.masks = []
        self.points = []
        self.frame_count = 0
        self.video_info = None
        self.obj_id = 1
        self.out_mask_logits = None
        self.frame_masks = {}  # Store masks for each frame
        self.trajectory_service = None

    def load_video(self, video_path):
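        """Load an uploaded video, cache its frames, and initialize the SAM2 state.

        Returns the first frame for display and a slider spanning the frame range.
        """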
        if video_path is None:
            return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)

        # Create a temporary file for the video
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        temp_file.close()

        # Copy the uploaded video to the temporary file
        shutil.copyfile(video_path, temp_file.name)

        # Load video frames using OpenCV; convert BGR -> RGB so Gradio displays
        # them with correct colors
        cap = cv2.VideoCapture(temp_file.name)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        if not frames:
            cap.release()
            os.unlink(temp_file.name)
            return None, gr.Slider(minimum=0, maximum=0, step=1, value=0)

        # Store video info
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            fps = 30.0  # fall back when the container does not report a frame rate
        self.video_info = {
            "path": temp_file.name,
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "fps": fps,
            "total_frames": len(frames),
        }
        cap.release()

        self.video_frames = frames
        self.frame_count = len(frames)
        self.trajectory_service = TrajectoryService(fps=fps)

        # Initialize SAM2 state with video path
        with torch.inference_mode():
            self.state = self.predictor.init_state(temp_file.name)

        # SAM2 has loaded the frames into its inference state, so the temp file can be removed
        os.unlink(temp_file.name)

        return frames[0], gr.Slider(minimum=0, maximum=len(frames) - 1, step=1, value=0)

    def update_frame(self, frame_number):
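        """Return the requested frame with any stored mask and clicked points drawn on it."""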
        if self.video_frames is None:
            return None

        # Gradio sliders may deliver floats; coerce to a valid list index
        frame_number = int(frame_number)
        self.current_frame_idx = frame_number
        frame = self.video_frames[frame_number].copy()

        # Apply any existing mask for this frame
        if frame_number in self.frame_masks:
            self.out_mask_logits = self.frame_masks[frame_number]
            frame = self._draw_tracking(frame)

        # Draw points (just the points, no trajectory)
        for point in self.points:
            if point[0] == frame_number:
                cv2.circle(
                    frame, (int(point[1]), int(point[2])), 5, (255, 255, 0), -1
                )  # Yellow dot
                cv2.circle(
                    frame, (int(point[1]), int(point[2])), 7, (0, 0, 0), 1
                )  # Black border

        return frame

    def add_point(self, frame, evt: gr.SelectData):
        """Add a point and get ball prediction with enhanced mask visualization"""
        if self.state is None:
            return frame

        x, y = evt.index[0], evt.index[1]
        self.points.append((self.current_frame_idx, x, y))

        # Add point to trajectory service for later use
        if self.trajectory_service:
            self.trajectory_service.add_point(self.current_frame_idx, x, y)

        frame_with_points = frame.copy()

        # Get ball prediction using SAM2.1
        with torch.inference_mode():
            # Convert points and labels to numpy arrays
            points = np.array([(x, y)], dtype=np.float32)
            labels = np.array([1], dtype=np.int32)  # 1 for positive click

            # Add point and get mask
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=self.state,
                frame_idx=self.current_frame_idx,
                obj_id=self.obj_id,
                points=points,
                labels=labels,
            )

            if out_mask_logits is not None and len(out_mask_logits) > 0:
                self.out_mask_logits = (
                    out_mask_logits[0]
                    if isinstance(out_mask_logits, list)
                    else out_mask_logits
                )
                # Store mask for this frame
                self.frame_masks[self.current_frame_idx] = self.out_mask_logits

        # Draw tracking visualization with enhanced mask
        frame_with_points = self._draw_tracking(frame_with_points)

        # Draw point on top of mask (just the point, no trajectory)
        cv2.circle(
            frame_with_points, (int(x), int(y)), 5, (255, 255, 0), -1
        )  # Yellow dot
        cv2.circle(frame_with_points, (int(x), int(y)), 7, (0, 0, 0), 1)  # Black border

        return frame_with_points

    def propagate_video(self):
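        """Propagate the clicked object through the whole video.

        First pass: run SAM2 propagation and collect per-frame masks and their
        centers for the trajectory. Second pass: render mask, glow, and
        trajectory overlays, then write the annotated frames to a temporary
        MP4 and return its path for the output video component.
        """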
        if self.state is None:
            return None

        output_frames = self.video_frames.copy()

        # Store all masks and their centers for trajectory calculation
        all_masks = []
        mask_centers = []

        # First pass: collect all masks and calculate centers
        with torch.inference_mode():
            for frame_idx, obj_ids, masks in self.predictor.propagate_in_video(
                self.state,
                start_frame_idx=0,
                reverse=False,
            ):
                if masks is not None and len(masks) > 0:
                    mask = masks[0] if isinstance(masks, list) else masks
                    all_masks.append((frame_idx, mask))

                    # Get mask center
                    mask_np = (mask > 0.0).cpu().numpy()
                    center = self._get_mask_center(mask_np)
                    if center is not None:
                        mask_centers.append((frame_idx, center[0], center[1]))

                    # Store mask for each frame
                    self.frame_masks[frame_idx] = mask

        # Add detected centers to trajectory service
        trajectory_points = []
        if self.trajectory_service:
            # Clear existing points and add user-selected points first
            self.trajectory_service.clear_points()
            for point in self.points:
                self.trajectory_service.add_point(point[0], point[1], point[2])

            # Add centers from mask detection
            for center in mask_centers:
                if center[0] not in [
                    p[0] for p in self.points
                ]:  # Don't duplicate user points
                    self.trajectory_service.add_point(center[0], center[1], center[2])

            # Calculate trajectory with all points
            trajectory_points = self.trajectory_service.get_trajectory()

        # Second pass: apply visualization with temporal smoothing and trajectory
        for i, frame in enumerate(output_frames):
            frame = frame.copy()

            # Find masks for this frame
            current_masks = [m[1] for m in all_masks if m[0] == i]

            if current_masks:
                self.out_mask_logits = current_masks[0]
                mask_np = (current_masks[0] > 0.0).cpu().numpy()
                mask_np = self._handle_mask_dimensions(mask_np)
                mask_np = mask_np.astype(np.uint8)
                frame = self._draw_tracking(frame, alpha=0.6)

                try:
                    # Dilate the mask and blend a yellow glow around the tracked object
                    kernel = np.ones((5, 5), np.uint8)
                    dilated_mask = cv2.dilate(mask_np, kernel, iterations=2)
                    glow = frame.copy()
                    glow[dilated_mask > 0] = [255, 255, 0]  # Yellow glow (RGB)
                    frame = cv2.addWeighted(frame, 0.7, glow, 0.3, 0)
                except cv2.error as e:
                    print(
                        f"Warning: Could not apply glow effect ({e}). "
                        f"Mask shape: {mask_np.shape}, Frame shape: {frame.shape}"
                    )

            # Draw trajectory
            if self.trajectory_service and trajectory_points:
                frame = self.trajectory_service.draw_trajectory(frame, i)

            output_frames[i] = frame

        # Save the annotated frames as a video at the original frame rate
        temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        height, width = output_frames[0].shape[:2]
        fps = self.video_info["fps"] if self.video_info else 30.0
        writer = cv2.VideoWriter(
            temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height), True
        )

        for frame in output_frames:
            # Frames are kept in RGB; VideoWriter expects BGR
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        writer.release()

        return temp_output

    def _get_mask_center(self, mask_np):
        """Get the center point of a mask"""
        if mask_np is None:
            return None

        # Ensure mask is 2D
        mask_np = self._handle_mask_dimensions(mask_np)
        mask_np = (mask_np > 0.0).astype(np.uint8)

        # Find contours
        contours, _ = cv2.findContours(
            mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            return None

        # Get largest contour
        largest_contour = max(contours, key=cv2.contourArea)

        # Calculate centroid
        M = cv2.moments(largest_contour)
        if M["m00"] == 0:
            return None

        cx = int(M["m10"] / M["m00"])
        cy = int(M["m01"] / M["m00"])

        return (cx, cy)

    def _handle_mask_dimensions(self, mask_np):
        """Helper function to handle various mask dimensions"""
        # Handle 4D tensor (1, 1, H, W)
        if len(mask_np.shape) == 4:
            mask_np = mask_np[0, 0]
        # Handle 3D tensor (1, H, W) or (H, W, 1)
        elif len(mask_np.shape) == 3:
            if mask_np.shape[0] == 1:  # (1, H, W) format
                mask_np = mask_np[0]
            elif mask_np.shape[2] == 1:  # (H, W, 1) format
                mask_np = mask_np[:, :, 0]
        return mask_np

    def _draw_tracking(self, frame, alpha=0.5):
        """Draw object mask on frame with enhanced visualization"""
        if self.out_mask_logits is not None:
            # Convert logits to binary mask
            if isinstance(self.out_mask_logits, list):
                mask = self.out_mask_logits[0]
            else:
                mask = self.out_mask_logits

            # Get binary mask and handle dimensions
            mask_np = (mask > 0.0).cpu().numpy()
            mask_np = self._handle_mask_dimensions(mask_np)

            if mask_np.shape[:2] == frame.shape[:2]:
                # Create a red overlay for the mask (frames are RGB)
                overlay = frame.copy()
                overlay[mask_np > 0] = [255, 0, 0]  # Red

                # Add a border around the mask for better visibility
                contours, _ = cv2.findContours(
                    mask_np.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )

                # Draw thicker contours for better visibility
                cv2.drawContours(
                    overlay, contours, -1, (255, 255, 0), 3
                )  # Thicker yellow border

                # Draw a second contour for emphasis
                cv2.drawContours(
                    frame, contours, -1, (0, 255, 255), 1
                )  # Thin cyan border

                # Blend the overlay with original frame
                frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)

        return frame


def create_interface():
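    """Build the Gradio Blocks UI and wire its events to a VideoTracker."""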
    tracker = VideoTracker()

    with gr.Blocks() as interface:
        gr.Markdown("# Object Tracking with SAM2")
        gr.Markdown("Upload a video and click on objects to track them")

        with gr.Row():
            with gr.Column(scale=2):
                video_input = gr.Video(label="Input Video")
                image_output = gr.Image(label="Current Frame", interactive=True)
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=0,
                    step=1,
                    value=0,
                    label="Frame Selection",
                    interactive=True,
                )

            with gr.Column(scale=1):
                propagate_btn = gr.Button("Propagate Through Video", variant="primary")
                video_output = gr.Video(label="Output Video")

        video_input.change(
            fn=tracker.load_video,
            inputs=[video_input],
            outputs=[image_output, frame_slider],
        )

        frame_slider.change(
            fn=tracker.update_frame,
            inputs=[frame_slider],
            outputs=[image_output],
        )

        image_output.select(
            fn=tracker.add_point,
            inputs=[image_output],
            outputs=[image_output],
        )

        propagate_btn.click(
            fn=tracker.propagate_video,
            inputs=[],
            outputs=[video_output],
        )

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)