import gradio as gr
import cv2
import numpy as np
from PIL import Image
import mediapipe as mp
from fer import FER  # Facial emotion recognition

# -----------------------------
# Configuration
# -----------------------------
# Run the heavy analysis on every SKIP_RATE-th call; 1 means every uploaded
# image is analysed.
SKIP_RATE = 1
# Target (width, height) handed to cv2.resize before running the detectors.
DESIRED_SIZE = (640, 480)

# -----------------------------
# Global caches for overlay info and frame counters
# -----------------------------
# Each cache holds the last computed overlay data, the last status text, and
# a call counter compared against SKIP_RATE.
posture_cache = dict(landmarks=None, text="Initializing...", counter=0)
emotion_cache = dict(text="Initializing...", counter=0)
faces_cache = dict(boxes=None, text="Initializing...", counter=0)

# -----------------------------
# Initialize Models and Helpers
# -----------------------------
# MediaPipe Pose and Face Detection are shared module-level instances; the
# Face Mesh model is created per call inside compute_facemesh_overlay.
mp_pose = mp.solutions.pose
# Every upload is an independent still image, so run person detection on each
# one (static_image_mode=True) instead of the default video-stream mode, which
# tracks landmarks from the previous input and can carry a stale region of
# interest over from an unrelated earlier upload.
pose = mp_pose.Pose(static_image_mode=True)
mp_drawing = mp.solutions.drawing_utils

mp_face_detection = mp.solutions.face_detection
face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5)

# Initialize the FER emotion detector (using the FER package), with MTCNN as
# its face detector.
emotion_detector = FER(mtcnn=True)

# -----------------------------
# Overlay Drawing Functions
# -----------------------------
def draw_posture_overlay(raw_frame, landmarks):
    """Draw a pose skeleton onto ``raw_frame`` in place and return it.

    Args:
        raw_frame: image array to draw on; it is modified in place.
        landmarks: sequence of (x, y) integer pixel coordinates, indexed the
            same way as MediaPipe Pose landmarks.

    Returns:
        The same ``raw_frame`` object, now carrying 2px connector lines and
        filled 4px-radius dots in RGB/BGR colour (50, 205, 50).

    Connections whose endpoint index is not present in ``landmarks`` are skipped.
    """
    color = (50, 205, 50)
    count = len(landmarks)
    for a, b in mp_pose.POSE_CONNECTIONS:
        if a >= count or b >= count:
            continue
        cv2.line(raw_frame, landmarks[a], landmarks[b], color, 2)
    for point in landmarks:
        x, y = point
        cv2.circle(raw_frame, (x, y), 4, color, -1)
    return raw_frame

def draw_boxes_overlay(raw_frame, boxes, color):
    """Outline each (x1, y1, x2, y2) box on ``raw_frame`` in place.

    Args:
        raw_frame: image array to draw on; it is modified in place.
        boxes: iterable of (x1, y1, x2, y2) pixel-coordinate tuples.
        color: colour tuple passed straight to cv2.rectangle.

    Returns:
        The same ``raw_frame`` object with 2px rectangles drawn on it.
    """
    thickness = 2
    for box in boxes:
        left, top, right, bottom = box
        cv2.rectangle(raw_frame, (left, top), (right, bottom), color, thickness)
    return raw_frame

# -----------------------------
# Heavy (Synchronous) Detection Functions
# -----------------------------
def compute_posture_overlay(image):
    """Run MediaPipe Pose on a resized copy of ``image``.

    Args:
        image: RGB input image (anything np.array() turns into an HxWxC array).

    Returns:
        (landmarks, text): ``landmarks`` is a list of (x, y) int pixel
        coordinates scaled back to the original image size (empty when no pose
        is found); ``text`` is "Posture detected" or "No posture detected".
    """
    bgr_full = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    full_h, full_w = bgr_full.shape[:2]
    bgr_small = cv2.resize(bgr_full, DESIRED_SIZE)
    small_h, small_w = bgr_small.shape[:2]
    results = pose.process(cv2.cvtColor(bgr_small, cv2.COLOR_BGR2RGB))

    if not results.pose_landmarks:
        return [], "No posture detected"

    # Normalised landmark -> small-frame pixels -> original-frame pixels.
    scale_x = full_w / small_w
    scale_y = full_h / small_h
    points = [
        (int(lm.x * small_w * scale_x), int(lm.y * small_h * scale_y))
        for lm in results.pose_landmarks.landmark
    ]
    return points, "Posture detected"

def compute_emotion_overlay(image):
    """Return a "<emotion> (<score>)" label for the first face FER reports.

    Args:
        image: RGB input image (anything np.array() turns into an HxWxC array).

    Returns:
        e.g. "happy (0.87)", or "No face detected" when FER finds no face.

    FER's ``detect_emotions`` documents its numpy input as BGR (or grayscale)
    and builds its grayscale classifier input with a BGR->GRAY conversion, so
    the resized frame is handed over in BGR order rather than RGB.
    """
    frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    frame_bgr_small = cv2.resize(frame_bgr, DESIRED_SIZE)
    emotions = emotion_detector.detect_emotions(frame_bgr_small)
    if not emotions:
        return "No face detected"
    # Only the first face in FER's result list is reported.
    top_emotion, score = max(emotions[0]["emotions"].items(), key=lambda kv: kv[1])
    return f"{top_emotion} ({score:.2f})"

def compute_faces_overlay(image):
    """Detect faces with MediaPipe Face Detection.

    Detection runs on a copy resized to DESIRED_SIZE, but the returned boxes
    are expressed in pixel coordinates of the ORIGINAL image, because callers
    draw them directly onto np.array(image). MediaPipe's relative bounding box
    is normalised, so it is scaled by the original width/height.

    Args:
        image: RGB input image (anything np.array() turns into an HxWxC array).

    Returns:
        (boxes, text): ``boxes`` is a list of (x1, y1, x2, y2) int tuples in
        original-image pixels; ``text`` is "Detected N face(s)" or
        "No faces detected".
    """
    frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    h, w, _ = frame_bgr.shape
    frame_bgr_small = cv2.resize(frame_bgr, DESIRED_SIZE)
    frame_rgb_small = cv2.cvtColor(frame_bgr_small, cv2.COLOR_BGR2RGB)
    face_results = face_detection.process(frame_rgb_small)

    boxes = []
    for detection in face_results.detections or []:
        bbox = detection.location_data.relative_bounding_box
        # Scale to the original image, not the resized one, so the boxes line
        # up with the frame they are drawn on.
        x = int(bbox.xmin * w)
        y = int(bbox.ymin * h)
        box_w = int(bbox.width * w)
        box_h = int(bbox.height * h)
        boxes.append((x, y, x + box_w, y + box_h))

    text = f"Detected {len(boxes)} face(s)" if boxes else "No faces detected"
    return boxes, text

# -----------------------------
# New Facemesh Functions (with connected red lines and mask output)
# -----------------------------
def compute_facemesh_overlay(image):
    """
    Uses MediaPipe Face Mesh to detect and draw facial landmarks.
    Draws green dots for landmarks and connects them with thin red lines.

    Args:
        image: RGB input image (as supplied by gr.Image, or anything
            np.array() turns into an HxWxC array).

    Returns:
        (annotated, mask, text):
          - annotated: RGB copy of the input overlaid with the facemesh
          - mask: black RGB image of the same size with only the facemesh drawn
          - text: "Facemesh detected" or "No facemesh detected"
    """
    # Same BGR round-trip as before to obtain the RGB array MediaPipe consumes.
    # All drawing then happens on that RGB array, because Gradio displays
    # returned numpy images as RGB. Drawing on, and returning, the BGR copy
    # would show the photo with its red and blue channels swapped.
    frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    h, w = frame_rgb.shape[:2]
    annotated = frame_rgb.copy()
    mask = np.zeros_like(frame_rgb)

    line_color = (255, 0, 0)  # red in RGB
    dot_color = (0, 255, 0)   # green
    canvases = (annotated, mask)

    # Static mode: each call handles an independent image. The context manager
    # closes the graph even if processing or drawing raises.
    with mp.solutions.face_mesh.FaceMesh(
        static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5
    ) as face_mesh:
        results = face_mesh.process(frame_rgb)
        faces = results.multi_face_landmarks or []
        for face_landmarks in faces:
            # Normalised landmarks -> pixel coordinates of the input image.
            points = [(int(lm.x * w), int(lm.y * h)) for lm in face_landmarks.landmark]
            n_points = len(points)
            for start_idx, end_idx in mp.solutions.face_mesh.FACEMESH_TESSELATION:
                if start_idx < n_points and end_idx < n_points:
                    for canvas in canvases:
                        cv2.line(canvas, points[start_idx], points[end_idx], line_color, 1)
            for pt in points:
                for canvas in canvases:
                    cv2.circle(canvas, pt, 2, dot_color, -1)

    text = "Facemesh detected" if faces else "No facemesh detected"
    return annotated, mask, text

def analyze_facemesh(image):
    """Gradio handler for the Facemesh tab.

    Args:
        image: uploaded RGB image from gr.Image, or None when nothing was uploaded.

    Returns:
        (annotated, mask, html): the facemesh overlay, the mask-only image,
        and an HTML status line. Both images are None when ``image`` is None.
    """
    if image is None:
        # Gradio passes None when Submit is pressed without an upload; the
        # OpenCV conversion in the compute step would fail on it.
        return (None, None,
                "<div style='color: #00ff00 !important;'>Facemesh Analysis: No image provided</div>")
    annotated_image, mask_image, text = compute_facemesh_overlay(image)
    return (annotated_image, mask_image,
            f"<div style='color: #00ff00 !important;'>Facemesh Analysis: {text}</div>")

# -----------------------------
# Main Analysis Functions for Single Image
# -----------------------------
def analyze_posture_current(image):
    """Gradio handler for the Posture tab: detect a pose and draw its skeleton.

    Args:
        image: uploaded RGB image from gr.Image, or None when nothing was uploaded.

    Returns:
        (output, html): an annotated copy of the image (None when ``image`` is
        None) and an HTML status line.
    """
    if image is None:
        # Gradio passes None when Submit is pressed without an upload; the
        # OpenCV conversion in the compute step would fail on it.
        return None, "<div style='color: #00ff00 !important;'>Posture Analysis: No image provided</div>"
    posture_cache["counter"] += 1
    current_frame = np.array(image)
    # NOTE: with SKIP_RATE > 1, skipped calls draw the previous upload's
    # landmarks onto this (possibly unrelated) image.
    if posture_cache["counter"] % SKIP_RATE == 0 or posture_cache["landmarks"] is None:
        landmarks, text = compute_posture_overlay(image)
        posture_cache["landmarks"] = landmarks
        posture_cache["text"] = text
    output = current_frame.copy()
    if posture_cache["landmarks"]:
        output = draw_posture_overlay(output, posture_cache["landmarks"])
    return output, f"<div style='color: #00ff00 !important;'>Posture Analysis: {posture_cache['text']}</div>"

def analyze_emotion_current(image):
    """Gradio handler for the Emotion tab: report the top emotion for a face.

    Args:
        image: uploaded RGB image from gr.Image, or None when nothing was uploaded.

    Returns:
        (frame, html): the input as a numpy array, unmodified (None when
        ``image`` is None), and an HTML status line.
    """
    if image is None:
        # Gradio passes None when Submit is pressed without an upload; the
        # OpenCV conversion in the compute step would fail on it.
        return None, "<div style='color: #00ff00 !important;'>Emotion Analysis: No image provided</div>"
    emotion_cache["counter"] += 1
    current_frame = np.array(image)
    # Always analyse the very first call. The previous `text is None` check
    # could never fire because the cache starts with "Initializing...".
    if emotion_cache["counter"] % SKIP_RATE == 0 or emotion_cache["counter"] == 1:
        emotion_cache["text"] = compute_emotion_overlay(image)
    return current_frame, f"<div style='color: #00ff00 !important;'>Emotion Analysis: {emotion_cache['text']}</div>"

def analyze_faces_current(image):
    """Gradio handler for the Faces tab: detect faces and outline them.

    Args:
        image: uploaded RGB image from gr.Image, or None when nothing was uploaded.

    Returns:
        (output, html): a copy of the image with (0, 0, 255) 2px boxes, which
        show as blue in an RGB array (None when ``image`` is None), and an
        HTML status line.
    """
    if image is None:
        # Gradio passes None when Submit is pressed without an upload; the
        # OpenCV conversion in the compute step would fail on it.
        return None, "<div style='color: #00ff00 !important;'>Face Detection: No image provided</div>"
    faces_cache["counter"] += 1
    current_frame = np.array(image)
    # NOTE: with SKIP_RATE > 1, skipped calls draw the previous upload's
    # boxes onto this (possibly unrelated) image.
    if faces_cache["counter"] % SKIP_RATE == 0 or faces_cache["boxes"] is None:
        boxes, text = compute_faces_overlay(image)
        faces_cache["boxes"] = boxes
        faces_cache["text"] = text
    output = current_frame.copy()
    if faces_cache["boxes"]:
        output = draw_boxes_overlay(output, faces_cache["boxes"], (0, 0, 255))
    return output, f"<div style='color: #00ff00 !important;'>Face Detection: {faces_cache['text']}</div>"

# -----------------------------
# Custom CSS (Revamped High-Contrast Neon Theme with Green Glows)
# -----------------------------
# Stylesheet passed to gr.Blocks(css=...) below. It imports the Orbitron font
# from Google Fonts and applies a dark background with green (#00ff00) text,
# borders and glow shadows.
# NOTE(review): selectors such as .tab-item and .output target Gradio-internal
# class names — confirm they still match the installed Gradio version.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700&display=swap');
body {
    background-color: #121212;
    font-family: 'Orbitron', sans-serif;
    color: #00ff00;
}
.gradio-container {
    background: linear-gradient(135deg, #2d2d2d, #1a1a1a);
    border: 2px solid #00ff00;
    box-shadow: 0 0 15px #00ff00;
    border-radius: 10px;
    padding: 20px;
    max-width: 1200px;
    margin: auto;
}
.gradio-title, .gradio-description, .tab-item, .tab-item * {
    color: #00ff00 !important;
    text-shadow: 0 0 10px #00ff00;
}
input, button, .output {
    border: 1px solid #00ff00;
    box-shadow: 0 0 8px #00ff00;
    color: #00ff00;
    background-color: #1a1a1a;
}
"""

# -----------------------------
# Create Individual Interfaces for Image Processing
# -----------------------------
def _single_image_interface(fn, image_input, outputs, title, description):
    """Return a non-live gr.Interface that runs ``fn`` on one uploaded image.

    The caller constructs the input component before the output components,
    in the same order a direct ``gr.Interface(inputs=..., outputs=...)`` call
    evaluates them. They are passed through unchanged.
    """
    return gr.Interface(
        fn=fn,
        inputs=image_input,
        outputs=outputs,
        title=title,
        description=description,
        live=False,
    )


posture_interface = _single_image_interface(
    analyze_posture_current,
    gr.Image(label="Upload an Image for Posture Analysis"),
    [gr.Image(type="numpy", label="Annotated Output"), gr.HTML(label="Posture Analysis")],
    "Posture",
    "Detects your posture using MediaPipe with connector lines.",
)

emotion_interface = _single_image_interface(
    analyze_emotion_current,
    gr.Image(label="Upload an Image for Emotion Analysis"),
    [gr.Image(type="numpy", label="Annotated Output"), gr.HTML(label="Emotion Analysis")],
    "Emotion",
    "Detects facial emotions using FER.",
)

faces_interface = _single_image_interface(
    analyze_faces_current,
    gr.Image(label="Upload an Image for Face Detection"),
    [gr.Image(type="numpy", label="Annotated Output"), gr.HTML(label="Face Detection")],
    "Faces",
    "Detects faces using MediaPipe.",
)

facemesh_interface = _single_image_interface(
    analyze_facemesh,
    gr.Image(label="Upload an Image for Facemesh"),
    [
        gr.Image(type="numpy", label="Annotated Output"),
        gr.Image(type="numpy", label="Mask Output"),
        gr.HTML(label="Facemesh Analysis"),
    ],
    "Facemesh",
    "Detects facial landmarks using MediaPipe Face Mesh and outputs both an annotated image and a mask on a black background.",
)

# (tab label, interface) pairs in display order.
_TABS = (
    ("Posture", posture_interface),
    ("Emotion", emotion_interface),
    ("Faces", faces_interface),
    ("Facemesh", facemesh_interface),
)
tabbed_interface = gr.TabbedInterface(
    interface_list=[iface for _, iface in _TABS],
    tab_names=[label for label, _ in _TABS],
)

# -----------------------------
# Wrap in a Blocks Layout and Launch
# -----------------------------
with gr.Blocks(css=custom_css) as demo:
    # Page header followed by the tabbed analysis interfaces.
    gr.Markdown("<h1 class='gradio-title'>Multi-Analysis Image App</h1>")
    gr.Markdown("<p class='gradio-description'>Upload an image to run high-tech analysis for posture, emotions, faces, and facemesh landmarks.</p>")
    tabbed_interface.render()


def main():
    """Launch the combined Gradio app."""
    demo.launch()


if __name__ == "__main__":
    main()