yehtutmaung commited on
Commit
80d4d59
·
verified ·
1 Parent(s): c0e276c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -71
app.py CHANGED
@@ -1,71 +1,68 @@
1
- import cv2
2
- import streamlit as st
3
- from face_detection import FaceDetector
4
- from mark_detection import MarkDetector
5
- from pose_estimation import PoseEstimator
6
- from utils import refine
7
-
8
def main():
    """Streamlit app: detect driver/viewer distraction from a webcam stream
    or an uploaded video using face detection + landmark-based pose estimation.
    """
    # Local imports keep this fix self-contained within the function.
    import os
    import tempfile

    st.title("Distraction Detection App")
    video_src = st.sidebar.selectbox("Select Video Source", ("Webcam", "Video File"))

    temp_path = None
    if video_src == "Video File":
        video_file = st.sidebar.file_uploader("Upload a Video File", type=["mp4", "avi", "mov"])
        if video_file is None:
            st.warning("Please upload a video file.")
            return
        # BUG FIX: cv2.VideoCapture needs a filesystem path or a device index,
        # not Streamlit's in-memory UploadedFile object. Spill the upload to a
        # temp file, keeping the original extension so OpenCV picks the right demuxer.
        suffix = os.path.splitext(video_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(video_file.read())
            temp_path = tmp.name
        video_src = temp_path
    else:
        video_src = 0  # Default webcam index.

    # Set up the capture and detector components.
    cap = cv2.VideoCapture(video_src)
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        face_detector = FaceDetector("assets/face_detector.onnx")
        mark_detector = MarkDetector("assets/face_landmarks.onnx")
        pose_estimator = PoseEstimator(frame_width, frame_height)

        # Streamlit placeholder updated in place for each processed frame.
        frame_placeholder = st.empty()

        while cap.isOpened():
            frame_got, frame = cap.read()
            if not frame_got:
                break

            # Mirror webcam frames so the preview behaves like a mirror
            # (flip code 1 = horizontal flip; the original used the
            # nonstandard-but-equivalent positive code 2).
            if video_src == 0:
                frame = cv2.flip(frame, 1)

            # Face detection and pose estimation.
            faces, _ = face_detector.detect(frame, 0.7)
            if len(faces) > 0:
                face = refine(faces, frame_width, frame_height, 0.15)[0]
                x1, y1, x2, y2 = face[:4].astype(int)
                patch = frame[y1:y2, x1:x2]
                # Landmarks come back normalized to the square patch;
                # scale by the patch size and shift into frame coordinates.
                marks = mark_detector.detect([patch])[0].reshape([68, 2])
                marks *= (x2 - x1)
                marks[:, 0] += x1
                marks[:, 1] += y1

                distraction_status, _pose_vectors = pose_estimator.detect_distraction(marks)
                status_text = "Distracted" if distraction_status else "Focused"

                # Overlay status text: green when focused, red when distracted (BGR).
                cv2.putText(frame, f"Status: {status_text}", (10, 50),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                            (0, 255, 0) if not distraction_status else (0, 0, 255))

            # Streamlit expects RGB; OpenCV frames are BGR.
            frame_placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), channels="RGB")
    finally:
        cap.release()
        # We created the temp file with delete=False, so we own its cleanup.
        if temp_path is not None:
            os.unlink(temp_path)


if __name__ == "__main__":
    main()
 
import os
import tempfile

import cv2
import streamlit as st

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine
8
+
9
def main():
    """Streamlit app: detect distraction in an uploaded video using face
    detection + landmark-based head-pose estimation.

    The user uploads a video; each frame is annotated with a
    "Distracted"/"Focused" status and streamed back into the page.
    """
    st.title("Distraction Detection App on Hugging Face Spaces")
    st.sidebar.write("Please upload a video file for analysis.")

    # File uploader for the source video.
    video_file = st.sidebar.file_uploader("Upload a Video File", type=["mp4", "avi", "mov"])

    # Guard clause: nothing to do until a file is supplied.
    if video_file is None:
        st.warning("Please upload a video file to proceed.")
        return

    # cv2.VideoCapture needs a real filesystem path, not Streamlit's
    # in-memory UploadedFile, so persist the upload to a temp file.
    # Keep the original extension so OpenCV selects the right demuxer
    # (the original code created an extensionless temp file).
    suffix = os.path.splitext(video_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(video_file.read())
        video_path = temp_file.name

    cap = cv2.VideoCapture(video_path)
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        face_detector = FaceDetector("assets/face_detector.onnx")
        mark_detector = MarkDetector("assets/face_landmarks.onnx")
        pose_estimator = PoseEstimator(frame_width, frame_height)

        # Streamlit placeholder updated in place for each processed frame.
        frame_placeholder = st.empty()

        while cap.isOpened():
            frame_got, frame = cap.read()
            if not frame_got:
                break

            # Face detection and pose estimation.
            faces, _ = face_detector.detect(frame, 0.7)
            if len(faces) > 0:
                face = refine(faces, frame_width, frame_height, 0.15)[0]
                x1, y1, x2, y2 = face[:4].astype(int)
                patch = frame[y1:y2, x1:x2]
                # Landmarks come back normalized to the square patch;
                # scale by the patch size and shift into frame coordinates.
                marks = mark_detector.detect([patch])[0].reshape([68, 2])
                marks *= (x2 - x1)
                marks[:, 0] += x1
                marks[:, 1] += y1

                distraction_status, _pose_vectors = pose_estimator.detect_distraction(marks)
                status_text = "Distracted" if distraction_status else "Focused"

                # Overlay status text: green when focused, red when distracted (BGR).
                cv2.putText(frame, f"Status: {status_text}", (10, 50),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                            (0, 255, 0) if not distraction_status else (0, 0, 255))

            # Streamlit expects RGB; OpenCV frames are BGR.
            frame_placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), channels="RGB")
    finally:
        cap.release()
        # delete=False above means we own cleanup; the original leaked
        # one temp file per upload.
        os.unlink(video_path)


if __name__ == "__main__":
    main()