import os
import tempfile

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine
st.title("Head Pose Estimation")
st.text("Just a heads up (pun intended)... The code used for this space is largely borrowed from https://github.com/yinguobing/head-pose-estimation. Slightly altered to fit image needs and make it work on huggingface.")

# File extensions accepted by the uploader for each upload mode. Keeping the
# uploader in sync with the selected mode prevents, e.g., a video file being
# handed to the image-decoding branch further down.
_ACCEPTED_EXTENSIONS = {
    "Image": ["jpg", "jpeg", "png"],
    "Video": ["mp4", "mov", "avi", "mkv"],
}

# Choose between Image or Video file upload
file_type = st.selectbox("Choose the type of file you want to upload", ("Image", "Video"))
uploaded_file = st.file_uploader(
    "Upload an image or video file of your face",
    type=_ACCEPTED_EXTENSIONS[file_type],
)

# Display placeholder for real-time video output
FRAME_WINDOW = st.image([])

def _estimate_and_draw_pose(frame, face_detector, mark_detector, pose_estimator,
                            frame_width, frame_height):
    """Run the face -> landmarks -> pose pipeline on one frame and draw the result.

    Only the first face returned by ``refine`` is processed. The pose is drawn
    onto ``frame`` via ``pose_estimator.visualize``.

    Args:
        frame: Image array, indexed as ``frame[y, x]``; passed to the detectors.
        face_detector: Object exposing ``detect(frame, threshold)``.
        mark_detector: Object exposing ``detect([patch])``.
        pose_estimator: Object exposing ``solve(marks)`` and ``visualize(...)``.
        frame_width: Width of ``frame`` in pixels.
        frame_height: Height of ``frame`` in pixels.

    Returns:
        True if a face was found and a pose was drawn, False otherwise.
    """
    faces, _ = face_detector.detect(frame, 0.7)
    # ``faces`` may be a NumPy array, so test its length rather than truthiness.
    if len(faces) == 0:
        return False

    # Use the first refined face box and crop the matching patch.
    face = refine(faces, frame_width, frame_height, 0.15)[0]
    x1, y1, x2, y2 = face[:4].astype(int)
    patch = frame[y1:y2, x1:x2]

    # Landmarks are returned relative to the patch (presumably normalised --
    # TODO confirm); scale by the patch width and shift to frame coordinates.
    marks = mark_detector.detect([patch])[0].reshape([68, 2])
    marks *= (x2 - x1)
    marks[:, 0] += x1
    marks[:, 1] += y1

    pose = pose_estimator.solve(marks)
    pose_estimator.visualize(frame, pose, color=(0, 255, 0))
    return True


if uploaded_file is not None:
    # Video processing
    if file_type == "Video":
        # Keep the original extension on the temp file and close it before
        # OpenCV opens it, so every byte is flushed to disk (and so the file
        # can be reopened by name on Windows).
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tfile.write(uploaded_file.read())
        video_path = tfile.name

        try:
            cap = cv2.VideoCapture(video_path)
            st.write(f"Video source: {video_path}")
            try:
                if not cap.isOpened():
                    st.error("Could not open the uploaded video.")

                # Getting frame sizes
                frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

                # Initialize face detection, landmark detection, and pose estimation models
                face_detector = FaceDetector("assets/face_detector.onnx")
                mark_detector = MarkDetector("assets/face_landmarks.onnx")
                pose_estimator = PoseEstimator(frame_width, frame_height)

                # Process each frame
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    _estimate_and_draw_pose(
                        frame, face_detector, mark_detector, pose_estimator,
                        frame_width, frame_height,
                    )

                    # Every frame is shown, with or without a detected face.
                    # Convert frame to RGB for Streamlit display
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    FRAME_WINDOW.image(frame_rgb)
            finally:
                # Release the capture even if a frame fails to process.
                cap.release()
        finally:
            # The temp file was created with delete=False; remove it ourselves.
            os.unlink(video_path)

    # Image processing
    elif file_type == "Image":
        # PIL decodes to RGB (after convert, which also normalises grayscale
        # and RGBA inputs to three channels). Convert to BGR so the models see
        # the same channel order as in the video branch, and so the final
        # BGR -> RGB conversion for display is correct.
        image_rgb_in = np.array(Image.open(uploaded_file).convert("RGB"))
        image = cv2.cvtColor(image_rgb_in, cv2.COLOR_RGB2BGR)
        frame_height, frame_width, _ = image.shape

        # Initialize models for detection and pose estimation
        face_detector = FaceDetector("assets/face_detector.onnx")
        mark_detector = MarkDetector("assets/face_landmarks.onnx")
        pose_estimator = PoseEstimator(frame_width, frame_height)

        # Detect face and landmarks, estimate pose and draw it on the image
        found = _estimate_and_draw_pose(
            image, face_detector, mark_detector, pose_estimator,
            frame_width, frame_height,
        )
        if found:
            # Convert image to RGB and display in Streamlit
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # NOTE(review): newer Streamlit releases deprecate use_column_width
            # in favour of use_container_width -- confirm the pinned version.
            st.image(image_rgb, caption="Pose Estimated Image", use_column_width=True)
        else:
            st.warning("No face detected in the uploaded image.")