import gradio as gr
import cv2
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from PIL import Image
import mediapipe as mp
from fer import FER  # Facial emotion recognition
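
# Assumed dependencies (installable via pip): gradio, opencv-python, numpy, torch,
# torchvision, pillow, mediapipe, and fer (which pulls in TensorFlow).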

# -----------------------------
# Initialize Models and Helpers
# -----------------------------

# MediaPipe Pose for posture analysis
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=True)  # each input is a single captured frame, not a video stream
mp_drawing = mp.solutions.drawing_utils

# MediaPipe Face Detection for locating faces in the frame
mp_face_detection = mp.solutions.face_detection
face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5)

# Object Detection Model: Faster R-CNN (pretrained on COCO)
object_detection_model = models.detection.fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
)
object_detection_model.eval()
obj_transform = transforms.Compose([transforms.ToTensor()])
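# ToTensor converts the PIL image to a float CHW tensor in [0, 1]; the detection model
# resizes and normalizes internally. Inference runs on CPU here, so a single frame may
# take a few seconds to process.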

# Facial Emotion Detection using FER (requires TensorFlow)
emotion_detector = FER(mtcnn=True)
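# mtcnn=True switches FER to the MTCNN face detector (slower but more accurate than
# its default OpenCV Haar-cascade detector).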

# -----------------------------
# Define Analysis Functions
# -----------------------------

def analyze_posture(image):
    """
    Takes an image (captured via the webcam), processes it with MediaPipe Pose,
    and returns an annotated image and a text summary.
    """
    # Convert from PIL (RGB) to OpenCV BGR format
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    output_frame = frame.copy()
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    
    posture_result = "No posture detected"
    pose_results = pose.process(frame_rgb)
    if pose_results.pose_landmarks:
        posture_result = "Posture detected"
        mp_drawing.draw_landmarks(
            output_frame, pose_results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
            mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2),
            mp_drawing.DrawingSpec(color=(0, 0, 255), thickness=2)
        )
    
    annotated_image = cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB)
    return annotated_image, f"Posture Analysis: {posture_result}"

def analyze_emotion(image):
    """
    Uses FER to detect facial emotions from the captured image.
    Returns the original image and a text summary.
    """
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # FER expects an RGB image
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    emotions = emotion_detector.detect_emotions(frame_rgb)
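    # detect_emotions() returns a list with one dict per detected face, each holding a
    # "box" and an "emotions" score dictionary; report the top emotion of the first face.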
    if emotions:
        top_emotion, score = max(emotions[0]["emotions"].items(), key=lambda x: x[1])
        emotion_text = f"{top_emotion} ({score:.2f})"
    else:
        emotion_text = "No face detected for emotion analysis"
    
    # For simplicity, we return the original image
    annotated_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return annotated_image, f"Emotion Analysis: {emotion_text}"

def analyze_objects(image):
    """
    Uses a pretrained Faster R-CNN to detect objects in the image.
    Returns an annotated image with bounding boxes and a text summary.
    """
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    output_frame = frame.copy()
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image_pil = Image.fromarray(frame_rgb)
    img_tensor = obj_transform(image_pil)
    
    with torch.no_grad():
        detections = object_detection_model([img_tensor])[0]
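    # The model returns a dict of "boxes", "labels", and "scores" tensors;
    # only boxes whose score clears the threshold below are drawn.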
    
    threshold = 0.8
    detected_boxes = detections["boxes"][detections["scores"] > threshold]
    for box in detected_boxes:
        x1, y1, x2, y2 = box.int().tolist()
        cv2.rectangle(output_frame, (x1, y1), (x2, y2), (255, 255, 0), 2)
    
    object_result = f"Detected {len(detected_boxes)} object(s)" if len(detected_boxes) else "No objects detected"
    annotated_image = cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB)
    return annotated_image, f"Object Detection: {object_result}"

def analyze_faces(image):
    """
    Uses MediaPipe face detection to identify faces in the image.
    Returns an annotated image with face bounding boxes and a text summary.
    """
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    output_frame = frame.copy()
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    face_results = face_detection.process(frame_rgb)
    
    face_result = "No faces detected"
    if face_results.detections:
        face_result = f"Detected {len(face_results.detections)} face(s)"
        h, w, _ = output_frame.shape
        for detection in face_results.detections:
            bbox = detection.location_data.relative_bounding_box
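            # relative_bounding_box is normalized to [0, 1]; scale by the frame size
            # to get pixel coordinates.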
            x = int(bbox.xmin * w)
            y = int(bbox.ymin * h)
            box_w = int(bbox.width * w)
            box_h = int(bbox.height * h)
            cv2.rectangle(output_frame, (x, y), (x + box_w, y + box_h), (0, 0, 255), 2)
    
    annotated_image = cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB)
    return annotated_image, f"Face Detection: {face_result}"

# -----------------------------
# Custom CSS for a High-Tech Look
# -----------------------------
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700&display=swap');

body {
    background-color: #0e0e0e;
    color: #e0e0e0;
    font-family: 'Orbitron', sans-serif;
}
.gradio-container {
    background: linear-gradient(135deg, #1e1e2f, #3e3e55);
    border-radius: 10px;
    padding: 20px;
}
.gradio-title {
    font-size: 2.5em;
    color: #66fcf1;
    text-align: center;
}
.gradio-description {
    font-size: 1.2em;
    text-align: center;
    margin-bottom: 20px;
}
"""

# -----------------------------
# Create Individual Interfaces for Each Analysis
# -----------------------------
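# Each tab takes a webcam capture via gr.Image (Gradio 4.x "sources" API assumed) and
# returns an annotated image plus a text summary; type="pil" matches the PIL-to-OpenCV
# conversion done at the top of each analysis function.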

posture_interface = gr.Interface(
    fn=analyze_posture,
    inputs=gr.Image(sources=["webcam"], type="pil", label="Capture Your Posture"),
    outputs=[gr.Image(type="numpy", label="Annotated Output"), gr.Textbox(label="Posture Analysis")],
    title="Posture Analysis",
    description="Detects your posture using MediaPipe."
)

emotion_interface = gr.Interface(
    fn=analyze_emotion,
    inputs=gr.Image(sources=["webcam"], type="pil", label="Capture Your Face"),
    outputs=[gr.Image(type="numpy", label="Annotated Output"), gr.Textbox(label="Emotion Analysis")],
    title="Emotion Analysis",
    description="Detects facial emotions using FER."
)

objects_interface = gr.Interface(
    fn=analyze_objects,
    inputs=gr.Image(sources=["webcam"], type="pil", label="Capture the Scene"),
    outputs=[gr.Image(type="numpy", label="Annotated Output"), gr.Textbox(label="Object Detection")],
    title="Object Detection",
    description="Detects objects using a pretrained Faster R-CNN."
)

faces_interface = gr.Interface(
    fn=analyze_faces,
    inputs=gr.Image(sources=["webcam"], type="pil", label="Capture Your Face"),
    outputs=[gr.Image(type="numpy", label="Annotated Output"), gr.Textbox(label="Face Detection")],
    title="Face Detection",
    description="Detects faces using MediaPipe."
)

# -----------------------------
# Create a Tabbed Interface for All Analyses
# -----------------------------

tabbed_interface = gr.TabbedInterface(
    interface_list=[posture_interface, emotion_interface, objects_interface, faces_interface],
    tab_names=["Posture", "Emotion", "Objects", "Faces"]
)

# -----------------------------
# Wrap Everything in a Blocks Layout with Custom CSS
# -----------------------------
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("<h1 class='gradio-title'>Real-Time Multi-Analysis App</h1>")
    gr.Markdown("<p class='gradio-description'>Experience a high-tech, cinematic interface for real-time analysis of your posture, emotions, objects, and faces using your webcam.</p>")
    # Render the tabbed interface inside this Blocks layout; merely assigning it to a
    # variable would not mount it on the page.
    tabbed_interface.render()

if __name__ == "__main__":
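    # Passing share=True here would additionally create a temporary public URL.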
    demo.launch()