File size: 2,691 Bytes
80d4d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import tempfile

import cv2
import streamlit as st

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine

def main():
    """Streamlit app: upload a video and flag per-frame distraction.

    Pipeline per frame: face detection -> facial-landmark detection on the
    face crop -> head-pose-based distraction check. Each processed frame is
    shown in a single Streamlit placeholder with a Focused/Distracted label.
    """
    st.title("Distraction Detection App on Hugging Face Spaces")
    st.sidebar.write("Please upload a video file for analysis.")

    video_file = st.sidebar.file_uploader(
        "Upload a Video File", type=["mp4", "avi", "mov"]
    )

    # Guard clause: nothing to do until the user supplies a video.
    if video_file is None:
        st.warning("Please upload a video file to proceed.")
        return

    # Persist the upload to disk so OpenCV can open it by path.
    # delete=False is required because cv2.VideoCapture reopens the path
    # after this context manager closes the handle.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(video_file.read())
        video_path = temp_file.name

    cap = cv2.VideoCapture(video_path)
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        face_detector = FaceDetector("assets/face_detector.onnx")
        mark_detector = MarkDetector("assets/face_landmarks.onnx")
        pose_estimator = PoseEstimator(frame_width, frame_height)

        # Single placeholder so frames replace each other instead of stacking.
        frame_placeholder = st.empty()

        while cap.isOpened():
            frame_got, frame = cap.read()
            if not frame_got:
                break

            # Detect faces with a 0.7 confidence threshold; process only
            # the first (presumably best) refined face box.
            faces, _ = face_detector.detect(frame, 0.7)
            if len(faces) > 0:
                face = refine(faces, frame_width, frame_height, 0.15)[0]
                x1, y1, x2, y2 = face[:4].astype(int)
                patch = frame[y1:y2, x1:x2]

                # Landmarks come back normalized to the crop; scale by the
                # crop width and shift into full-frame coordinates.
                marks = mark_detector.detect([patch])[0].reshape([68, 2])
                marks *= (x2 - x1)
                marks[:, 0] += x1
                marks[:, 1] += y1

                distraction_status, _ = pose_estimator.detect_distraction(marks)
                status_text = "Distracted" if distraction_status else "Focused"

                # Red text when distracted, green when focused.
                color = (0, 0, 255) if distraction_status else (0, 255, 0)
                cv2.putText(frame, f"Status: {status_text}", (10, 50),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color)

                # OpenCV frames are BGR; Streamlit expects RGB.
                frame_placeholder.image(
                    cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), channels="RGB"
                )
    finally:
        # Release the capture and remove the temporary copy of the upload
        # even if processing fails mid-stream (fixes a temp-file leak).
        cap.release()
        try:
            os.remove(video_path)
        except OSError:
            pass  # best-effort cleanup; file may already be gone

# Script entry point: run the Streamlit app only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()