yehtutmaung committed
Commit b38122e · verified · 1 Parent(s): 0237338

Upload 14 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ doc/demo.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,101 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # dotenv
+ .env
+
+ # virtualenv
+ .venv
+ venv/
+ ENV/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
app.py ADDED
@@ -0,0 +1,71 @@
+ import cv2
+ import streamlit as st
+ from face_detection import FaceDetector
+ from mark_detection import MarkDetector
+ from pose_estimation import PoseEstimator
+ from utils import refine
+
+ def main():
+     # Streamlit Title and Sidebar for inputs
+     st.title("Distraction Detection App")
+     video_src = st.sidebar.selectbox("Select Video Source", ("Webcam", "Video File"))
+
+     # If a video file is chosen, provide file uploader
+     if video_src == "Video File":
+         video_file = st.sidebar.file_uploader("Upload a Video File", type=["mp4", "avi", "mov"])
+         if video_file is not None:
+             video_src = video_file
+         else:
+             st.warning("Please upload a video file.")
+             return
+     else:
+         video_src = 0  # Webcam index
+
+     # Setup the video capture and detector components
+     cap = cv2.VideoCapture(video_src if video_src == 0 else video_file)
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     face_detector = FaceDetector("assets/face_detector.onnx")
+     mark_detector = MarkDetector("assets/face_landmarks.onnx")
+     pose_estimator = PoseEstimator(frame_width, frame_height)
+
+     # Streamlit placeholders for images
+     frame_placeholder = st.empty()
+
+     while cap.isOpened():
+         # Capture a frame
+         frame_got, frame = cap.read()
+         if not frame_got:
+             break
+
+         # Flip the frame if from webcam
+         if video_src == 0:
+             frame = cv2.flip(frame, 2)
+
+         # Face detection and pose estimation
+         faces, _ = face_detector.detect(frame, 0.7)
+         if len(faces) > 0:
+             face = refine(faces, frame_width, frame_height, 0.15)[0]
+             x1, y1, x2, y2 = face[:4].astype(int)
+             patch = frame[y1:y2, x1:x2]
+             marks = mark_detector.detect([patch])[0].reshape([68, 2])
+             marks *= (x2 - x1)
+             marks[:, 0] += x1
+             marks[:, 1] += y1
+
+             distraction_status, pose_vectors = pose_estimator.detect_distraction(marks)
+             status_text = "Distracted" if distraction_status else "Focused"
+
+             # Overlay status text
+             cv2.putText(frame, f"Status: {status_text}", (10, 50),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                         (0, 255, 0) if not distraction_status else (0, 0, 255))
+
+         # Display the frame in Streamlit
+         frame_placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), channels="RGB")
+
+     cap.release()
+
+ if __name__ == "__main__":
+     main()
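The app is launched with `streamlit run app.py`. Note that when a file is uploaded, the code above hands the Streamlit `UploadedFile` object straight to `cv2.VideoCapture`, which only accepts a file path or a device index, so uploaded videos may fail to open. A minimal sketch of one common workaround, writing the upload to a temporary file first (the `tempfile` usage here is an illustration, not part of the commit):

```python
import tempfile

import cv2
import streamlit as st

video_file = st.sidebar.file_uploader("Upload a Video File", type=["mp4", "avi", "mov"])
if video_file is not None:
    # Persist the uploaded bytes so OpenCV can open them by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
        tmp.write(video_file.read())
        video_path = tmp.name
    cap = cv2.VideoCapture(video_path)
```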
assets/.gitattributes ADDED
@@ -0,0 +1,2 @@
+ face_detector.onnx filter=lfs diff=lfs merge=lfs -text
+ face_landmarks.onnx filter=lfs diff=lfs merge=lfs -text
assets/face_detector.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08bd3e3febd685ffb4fd7d9d16a101614cc7fc6ab08029d3cb6abe5fb12d3c64
+ size 3291589
assets/face_landmarks.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e848578c7ac2474b35e0c4b9a1498ff4145c525552b3d845bdb1f66c8a9d85c2
+ size 29402017
assets/model.txt ADDED
@@ -0,0 +1,204 @@
+ -73.393523
+ -72.775014
+ -70.533638
+ -66.850058
+ -59.790187
+ -48.368973
+ -34.121101
+ -17.875411
+ 0.098749
+ 17.477031
+ 32.648966
+ 46.372358
+ 57.343480
+ 64.388482
+ 68.212038
+ 70.486405
+ 71.375822
+ -61.119406
+ -51.287588
+ -37.804800
+ -24.022754
+ -11.635713
+ 12.056636
+ 25.106256
+ 38.338588
+ 51.191007
+ 60.053851
+ 0.653940
+ 0.804809
+ 0.992204
+ 1.226783
+ -14.772472
+ -7.180239
+ 0.555920
+ 8.272499
+ 15.214351
+ -46.047290
+ -37.674688
+ -27.883856
+ -19.648268
+ -28.272965
+ -38.082418
+ 19.265868
+ 27.894191
+ 37.437529
+ 45.170805
+ 38.196454
+ 28.764989
+ -28.916267
+ -17.533194
+ -6.684590
+ 0.381001
+ 8.375443
+ 18.876618
+ 28.794412
+ 19.057574
+ 8.956375
+ 0.381549
+ -7.428895
+ -18.160634
+ -24.377490
+ -6.897633
+ 0.340663
+ 8.444722
+ 24.474473
+ 8.449166
+ 0.205322
+ -7.198266
+ -29.801432
+ -10.949766
+ 7.929818
+ 26.074280
+ 42.564390
+ 56.481080
+ 67.246992
+ 75.056892
+ 77.061286
+ 74.758448
+ 66.929021
+ 56.311389
+ 42.419126
+ 25.455880
+ 6.990805
+ -11.666193
+ -30.365191
+ -49.361602
+ -58.769795
+ -61.996155
+ -61.033399
+ -56.686759
+ -57.391033
+ -61.902186
+ -62.777713
+ -59.302347
+ -50.190255
+ -42.193790
+ -30.993721
+ -19.944596
+ -8.414541
+ 2.598255
+ 4.751589
+ 6.562900
+ 4.661005
+ 2.643046
+ -37.471411
+ -42.730510
+ -42.711517
+ -36.754742
+ -35.134493
+ -34.919043
+ -37.032306
+ -43.342445
+ -43.110822
+ -38.086515
+ -35.532024
+ -35.484289
+ 28.612716
+ 22.172187
+ 19.029051
+ 20.721118
+ 19.035460
+ 22.394109
+ 28.079924
+ 36.298248
+ 39.634575
+ 40.395647
+ 39.836405
+ 36.677899
+ 28.677771
+ 25.475976
+ 26.014269
+ 25.326198
+ 28.323008
+ 30.596216
+ 31.408738
+ 30.844876
+ 47.667532
+ 45.909403
+ 44.842580
+ 43.141114
+ 38.635298
+ 30.750622
+ 18.456453
+ 3.609035
+ -0.881698
+ 5.181201
+ 19.176563
+ 30.770570
+ 37.628629
+ 40.886309
+ 42.281449
+ 44.142567
+ 47.140426
+ 14.254422
+ 7.268147
+ 0.442051
+ -6.606501
+ -11.967398
+ -12.051204
+ -7.315098
+ -1.022953
+ 5.349435
+ 11.615746
+ -13.380835
+ -21.150853
+ -29.284036
+ -36.948060
+ -20.132003
+ -23.536684
+ -25.944448
+ -23.695741
+ -20.858157
+ 7.037989
+ 3.021217
+ 1.353629
+ -0.111088
+ -0.147273
+ 1.476612
+ -0.665746
+ 0.247660
+ 1.696435
+ 4.894163
+ 0.282961
+ -1.172675
+ -2.240310
+ -15.934335
+ -22.611355
+ -23.748437
+ -22.721995
+ -15.610679
+ -3.217393
+ -14.987997
+ -22.554245
+ -23.591626
+ -22.406106
+ -15.121907
+ -4.785684
+ -20.893742
+ -22.220479
+ -21.025520
+ -5.712776
+ -20.671489
+ -21.903670
+ -20.328022
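The 204 values in `assets/model.txt` are the 68-point 3D face model consumed by `PoseEstimator`: the file stores all x coordinates, then all y, then all z, and `_get_full_model_points` rebuilds them with `np.reshape(raw, (3, -1)).T`. A small sketch of that layout (the printout is only illustrative):

```python
import numpy as np

# Read the flat list of 204 floats and rebuild the 68 x 3 point array,
# mirroring PoseEstimator._get_full_model_points.
raw = np.loadtxt("assets/model.txt", dtype=np.float32)  # shape (204,)
points = raw.reshape(3, -1).T                           # shape (68, 3): one (x, y, z) per landmark
points[:, 2] *= -1                                      # flip Z into a front view, as the class does
print(points.shape)                                     # (68, 3)
```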
doc/demo.gif ADDED

Git LFS Details

  • SHA256: acb666daeda4eb9e6ecce88a1e4561610ccf6a4c2e54d4158c02bba78377744b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
doc/demo1.gif ADDED
doc/wechat_logo.png ADDED
face_detection.py ADDED
@@ -0,0 +1,311 @@
+ """This module provides a face detection implementation backed by SCRFD.
+ https://github.com/deepinsight/insightface/tree/master/detection/scrfd
+ """
+ import os
+
+ import cv2
+ import numpy as np
+ import onnxruntime
+
+
+ def distance2bbox(points, distance, max_shape=None):
+     """Decode distance prediction to bounding box.
+
+     Args:
+         points (Tensor): Shape (n, 2), [x, y].
+         distance (Tensor): Distance from the given point to 4
+             boundaries (left, top, right, bottom).
+         max_shape (tuple): Shape of the image.
+
+     Returns:
+         Tensor: Decoded bboxes.
+     """
+     x1 = points[:, 0] - distance[:, 0]
+     y1 = points[:, 1] - distance[:, 1]
+     x2 = points[:, 0] + distance[:, 2]
+     y2 = points[:, 1] + distance[:, 3]
+     if max_shape is not None:
+         x1 = x1.clamp(min=0, max=max_shape[1])
+         y1 = y1.clamp(min=0, max=max_shape[0])
+         x2 = x2.clamp(min=0, max=max_shape[1])
+         y2 = y2.clamp(min=0, max=max_shape[0])
+     return np.stack([x1, y1, x2, y2], axis=-1)
+
+
+ def distance2kps(points, distance, max_shape=None):
+     """Decode distance prediction to key points.
+
+     Args:
+         points (Tensor): Shape (n, 2), [x, y].
+         distance (Tensor): Distance from the given point to 4
+             boundaries (left, top, right, bottom).
+         max_shape (tuple): Shape of the image.
+
+     Returns:
+         Tensor: Decoded key points.
+     """
+     preds = []
+     for i in range(0, distance.shape[1], 2):
+         px = points[:, i % 2] + distance[:, i]
+         py = points[:, i % 2 + 1] + distance[:, i + 1]
+         if max_shape is not None:
+             px = px.clamp(min=0, max=max_shape[1])
+             py = py.clamp(min=0, max=max_shape[0])
+         preds.append(px)
+         preds.append(py)
+     return np.stack(preds, axis=-1)
+
+
+ class FaceDetector:
+
+     def __init__(self, model_file):
+         """Initialize a face detector.
+
+         Args:
+             model_file (str): ONNX model file path.
+         """
+         assert os.path.exists(model_file), f"File not found: {model_file}"
+
+         self.center_cache = {}
+         self.nms_threshold = 0.4
+         self.session = onnxruntime.InferenceSession(
+             model_file, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+
+         # Get model configurations from the model file.
+         # What is the input like?
+         input_cfg = self.session.get_inputs()[0]
+         input_name = input_cfg.name
+         input_shape = input_cfg.shape
+         self.input_size = tuple(input_shape[2:4][::-1])
+
+         # How about the outputs?
+         outputs = self.session.get_outputs()
+         output_names = []
+         for o in outputs:
+             output_names.append(o.name)
+         self.input_name = input_name
+         self.output_names = output_names
+
+         # And any key points?
+         self._with_kps = False
+         self._anchor_ratio = 1.0
+         self._num_anchors = 1
+
+         if len(outputs) == 6:
+             self._offset = 3
+             self._strides = [8, 16, 32]
+             self._num_anchors = 2
+         elif len(outputs) == 9:
+             self._offset = 3
+             self._strides = [8, 16, 32]
+             self._num_anchors = 2
+             self._with_kps = True
+         elif len(outputs) == 10:
+             self._offset = 5
+             self._strides = [8, 16, 32, 64, 128]
+             self._num_anchors = 1
+         elif len(outputs) == 15:
+             self._offset = 5
+             self._strides = [8, 16, 32, 64, 128]
+             self._num_anchors = 1
+             self._with_kps = True
+
+     def _preprocess(self, image):
+         inputs = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
+         inputs = inputs - np.array([127.5, 127.5, 127.5])
+         inputs = inputs / 128
+         inputs = np.expand_dims(inputs, axis=0)
+         inputs = np.transpose(inputs, [0, 3, 1, 2])
+
+         return inputs.astype(np.float32)
+
+     def forward(self, img, threshold):
+         scores_list = []
+         bboxes_list = []
+         kpss_list = []
+
+         inputs = self._preprocess(img)
+         predictions = self.session.run(
+             self.output_names, {self.input_name: inputs})
+
+         input_height = inputs.shape[2]
+         input_width = inputs.shape[3]
+         offset = self._offset
+
+         for idx, stride in enumerate(self._strides):
+             scores_pred = predictions[idx]
+             bbox_preds = predictions[idx + offset] * stride
+             if self._with_kps:
+                 kps_preds = predictions[idx + offset * 2] * stride
+
+             # Generate the anchors.
+             height = input_height // stride
+             width = input_width // stride
+             key = (height, width, stride)
+
+             if key in self.center_cache:
+                 anchor_centers = self.center_cache[key]
+             else:
+                 # solution-3:
+                 anchor_centers = np.stack(
+                     np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
+                 anchor_centers = (anchor_centers * stride).reshape((-1, 2))
+
+                 if self._num_anchors > 1:
+                     anchor_centers = np.stack(
+                         [anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2))
+
+                 if len(self.center_cache) < 100:
+                     self.center_cache[key] = anchor_centers
+
+             # solution-1, c style:
+             # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
+             # for i in range(height):
+             #     anchor_centers[i, :, 1] = i
+             # for i in range(width):
+             #     anchor_centers[:, i, 0] = i
+
+             # solution-2:
+             # ax = np.arange(width, dtype=np.float32)
+             # ay = np.arange(height, dtype=np.float32)
+             # xv, yv = np.meshgrid(np.arange(width), np.arange(height))
+             # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
+
+             # Filter the results by scores and threshold.
+             pos_inds = np.where(scores_pred >= threshold)[0]
+             bboxes = distance2bbox(anchor_centers, bbox_preds)
+             pos_scores = scores_pred[pos_inds]
+             pos_bboxes = bboxes[pos_inds]
+             scores_list.append(pos_scores)
+             bboxes_list.append(pos_bboxes)
+
+             if self._with_kps:
+                 kpss = distance2kps(anchor_centers, kps_preds)
+                 kpss = kpss.reshape((kpss.shape[0], -1, 2))
+                 pos_kpss = kpss[pos_inds]
+                 kpss_list.append(pos_kpss)
+
+         return scores_list, bboxes_list, kpss_list
+
+     def _nms(self, detections):
+         """Non-maximum suppression."""
+         x1 = detections[:, 0]
+         y1 = detections[:, 1]
+         x2 = detections[:, 2]
+         y2 = detections[:, 3]
+         scores = detections[:, 4]
+
+         areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+         order = scores.argsort()[::-1]
+
+         keep = []
+         while order.size > 0:
+             i = order[0]
+             keep.append(i)
+
+             _x1 = np.maximum(x1[i], x1[order[1:]])
+             _y1 = np.maximum(y1[i], y1[order[1:]])
+             _x2 = np.minimum(x2[i], x2[order[1:]])
+             _y2 = np.minimum(y2[i], y2[order[1:]])
+
+             w = np.maximum(0.0, _x2 - _x1 + 1)
+             h = np.maximum(0.0, _y2 - _y1 + 1)
+             inter = w * h
+             ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+             inds = np.where(ovr <= self.nms_threshold)[0]
+             order = order[inds + 1]
+
+         return keep
+
+     def detect(self, img, threshold=0.5, input_size=None, max_num=0, metric='default'):
+         input_size = self.input_size if input_size is None else input_size
+
+         # Rescale the image?
+         img_height, img_width, _ = img.shape
+         ratio_img = float(img_height) / img_width
+
+         input_width, input_height = input_size
+         ratio_model = float(input_height) / input_width
+
+         if ratio_img > ratio_model:
+             new_height = input_height
+             new_width = int(new_height / ratio_img)
+         else:
+             new_width = input_width
+             new_height = int(new_width * ratio_img)
+
+         det_scale = float(new_height) / img_height
+         resized_img = cv2.resize(img, (new_width, new_height))
+
+         det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
+         det_img[:new_height, :new_width, :] = resized_img
+
+         scores_list, bboxes_list, kpss_list = self.forward(det_img, threshold)
+         scores = np.vstack(scores_list)
+         scores_ravel = scores.ravel()
+         order = scores_ravel.argsort()[::-1]
+
+         bboxes = np.vstack(bboxes_list) / det_scale
+
+         if self._with_kps:
+             kpss = np.vstack(kpss_list) / det_scale
+         pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
+         pre_det = pre_det[order, :]
+
+         keep = self._nms(pre_det)
+
+         det = pre_det[keep, :]
+
+         if self._with_kps:
+             kpss = kpss[order, :, :]
+             kpss = kpss[keep, :, :]
+         else:
+             kpss = None
+
+         if max_num > 0 and det.shape[0] > max_num:
+             area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
+             img_center = img.shape[0] // 2, img.shape[1] // 2
+             offsets = np.vstack([
+                 (det[:, 0] + det[:, 2]) / 2 - img_center[1],
+                 (det[:, 1] + det[:, 3]) / 2 - img_center[0]])
+
+             offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
+
+             if metric == 'max':
+                 values = area
+             else:
+                 # some extra weight on the centering
+                 values = area - offset_dist_squared * 2.0
+
+             # some extra weight on the centering
+             bindex = np.argsort(values)[::-1]
+             bindex = bindex[0:max_num]
+             det = det[bindex, :]
+
+             if kpss is not None:
+                 kpss = kpss[bindex, :]
+
+         return det, kpss
+
+     def visualize(self, image, results, box_color=(0, 255, 0), text_color=(0, 0, 0)):
+         """Visualize the detection results.
+
+         Args:
+             image (np.ndarray): image to draw marks on.
+             results (np.ndarray): face detection results.
+             box_color (tuple, optional): color of the face box. Defaults to (0, 255, 0).
+             text_color (tuple, optional): color of the label text. Defaults to (0, 0, 0).
+         """
+         for det in results:
+             bbox = det[0:4].astype(np.int32)
+             conf = det[-1]
+             cv2.rectangle(image, (bbox[0], bbox[1]),
+                           (bbox[2], bbox[3]), box_color)
+             label = f"face: {conf:.2f}"
+             label_size, base_line = cv2.getTextSize(
+                 label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+             cv2.rectangle(image, (bbox[0], bbox[1] - label_size[1]),
+                           (bbox[2], bbox[1] + base_line), box_color, cv2.FILLED)
+             cv2.putText(image, label, (bbox[0], bbox[1]),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color)
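A minimal usage sketch for `FaceDetector` on a single image; the image path and score threshold below are illustrative, the class itself is the one defined above:

```python
import cv2
from face_detection import FaceDetector

detector = FaceDetector("assets/face_detector.onnx")
image = cv2.imread("sample.jpg")                     # BGR image, as OpenCV loads it
faces, keypoints = detector.detect(image, threshold=0.7)

# Each row of `faces` is [x1, y1, x2, y2, score] after NMS.
detector.visualize(image, faces)
cv2.imwrite("sample_out.jpg", image)
```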
main2.py ADDED
@@ -0,0 +1,102 @@
+ from argparse import ArgumentParser
+ import cv2
+ from face_detection import FaceDetector
+ from mark_detection import MarkDetector
+ from pose_estimation import PoseEstimator
+ from utils import refine
+
+ # Parse arguments from user input.
+ parser = ArgumentParser()
+ parser.add_argument("--video", type=str, default=None,
+                     help="Video file to be processed.")
+ parser.add_argument("--cam", type=int, default=0,
+                     help="The webcam index.")
+ args = parser.parse_args()
+
+ print(__doc__)
+ print("OpenCV version: {}".format(cv2.__version__))
+
+ def run():
+     # Initialize the video source from webcam or video file.
+     video_src = args.cam if args.video is None else args.video
+     cap = cv2.VideoCapture(video_src)
+     print(f"Video source: {video_src}")
+
+     # Get the frame size. This will be used by the following detectors.
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     # Setup a face detector to detect human faces.
+     face_detector = FaceDetector("assets/face_detector.onnx")
+
+     # Setup a mark detector to detect landmarks.
+     mark_detector = MarkDetector("assets/face_landmarks.onnx")
+
+     # Setup a pose estimator to solve pose.
+     pose_estimator = PoseEstimator(frame_width, frame_height)
+
+     # Measure the performance with a tick meter.
+     tm = cv2.TickMeter()
+
+     while True:
+         # Read a frame.
+         frame_got, frame = cap.read()
+         if frame_got is False:
+             break
+
+         # If the frame comes from webcam, flip it so it looks like a mirror.
+         if video_src == 0:
+             frame = cv2.flip(frame, 2)
+
+         # Step 1: Get faces from current frame.
+         faces, _ = face_detector.detect(frame, 0.7)
+
+         if len(faces) > 0:
+             tm.start()
+
+             # Step 2: Detect landmarks.
+             face = refine(faces, frame_width, frame_height, 0.15)[0]
+             x1, y1, x2, y2 = face[:4].astype(int)
+             patch = frame[y1:y2, x1:x2]
+
+             # Run the mark detection.
+             marks = mark_detector.detect([patch])[0].reshape([68, 2])
+
+             # Convert to global image.
+             marks *= (x2 - x1)
+             marks[:, 0] += x1
+             marks[:, 1] += y1
+
+             # Step 3: Try pose estimation.
+             distraction_status, pose_vectors = pose_estimator.detect_distraction(marks)
+             rotation_vector, translation_vector = pose_vectors
+
+             # Check distraction
+             if distraction_status:
+                 status_text = "Distracted"
+             else:
+                 status_text = "Focused"
+
+             cv2.putText(frame, f"Status: {status_text}", (10, 50),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0) if not distraction_status else (0, 0, 255))
+
+             tm.stop()
+
+             # Visualize the pose
+             pose_estimator.visualize(frame, pose_vectors, color=(0, 255, 0))
+
+             # Draw axes
+             pose_estimator.draw_axes(frame, pose_vectors)
+
+         # Draw the FPS on the screen
+         cv2.rectangle(frame, (0, 0), (90, 30), (0, 0, 0), cv2.FILLED)
+         cv2.putText(frame, f"FPS: {tm.getFPS():.0f}", (10, 20),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
+
+         # Show preview
+         cv2.imshow("Preview", frame)
+         if cv2.waitKey(1) == 27:
+             break
+
+ if __name__ == '__main__':
+     run()
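`main2.py` is the desktop entry point: `python main2.py` uses the default webcam and `python main2.py --video clip.mp4` processes a file (the clip name is only an example). The FPS overlay relies on `cv2.TickMeter`, which averages the elapsed time over repeated `start()`/`stop()` pairs; a minimal sketch of that pattern, with a sleep standing in for the per-frame work:

```python
import time

import cv2

tm = cv2.TickMeter()
for _ in range(5):
    tm.start()
    time.sleep(0.02)   # stand-in for the landmark + pose work timed in main2.py
    tm.stop()

# getFPS() reports the average rate over all measured start/stop intervals.
print(f"average FPS: {tm.getFPS():.1f}")
```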
mark_detection.py ADDED
@@ -0,0 +1,57 @@
+ """Human facial landmark detector based on Convolutional Neural Network."""
+ import os
+
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
+
+
+ class MarkDetector:
+     """Facial landmark detector by Convolutional Neural Network"""
+
+     def __init__(self, model_file):
+         """Initialize a mark detector.
+
+         Args:
+             model_file (str): ONNX model path.
+         """
+         assert os.path.exists(model_file), f"File not found: {model_file}"
+         self._input_size = 128
+         self.model = ort.InferenceSession(
+             model_file, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+
+     def _preprocess(self, bgrs):
+         """Preprocess the inputs to meet the model's needs.
+
+         Args:
+             bgrs (np.ndarray): a list of input images in BGR format.
+
+         Returns:
+             list: the images resized and converted to RGB.
+         """
+         rgbs = []
+         for img in bgrs:
+             img = cv2.resize(img, (self._input_size, self._input_size))
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             rgbs.append(img)
+
+         return rgbs
+
+     def detect(self, images):
+         """Detect facial marks from a face image.
+
+         Args:
+             images: a list of face images.
+
+         Returns:
+             marks: the facial marks as a numpy array of shape [Batch, 68*2].
+         """
+         inputs = self._preprocess(images)
+         marks = self.model.run(["dense_1"], {"image_input": inputs})
+         return np.array(marks)
+
+     def visualize(self, image, marks, color=(255, 255, 255)):
+         """Draw mark points on image"""
+         for mark in marks:
+             cv2.circle(image, (int(mark[0]), int(
+                 mark[1])), 1, color, -1, cv2.LINE_AA)
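The marks returned by `MarkDetector.detect` are relative to the 128×128 face crop (the callers multiply them by the crop size), which is why both `app.py` and `main2.py` scale the 68 points by the face-box width and then add the box's top-left corner. A small sketch of that mapping, with a hypothetical detector output and face box:

```python
import numpy as np

# Hypothetical detector output: 68 (x, y) pairs in [0, 1], relative to the face crop.
marks = np.random.rand(68, 2).astype(np.float32)

# Hypothetical square face box in image coordinates, as produced by utils.refine.
x1, y1, x2, y2 = 100, 120, 228, 248

# Same conversion used in app.py / main2.py: scale by the box size, then offset.
marks_global = marks * (x2 - x1)
marks_global[:, 0] += x1
marks_global[:, 1] += y1
```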
pose_estimation.py ADDED
@@ -0,0 +1,230 @@
+ """Estimate head pose according to the facial landmarks"""
+ import cv2
+ import numpy as np
+
+
+ class PoseEstimator:
+     """Estimate head pose according to the facial landmarks"""
+
+     def __init__(self, image_width, image_height):
+         """Init a pose estimator.
+
+         Args:
+             image_width (int): input image width
+             image_height (int): input image height
+         """
+         self.size = (image_height, image_width)
+         self.model_points_68 = self._get_full_model_points()
+
+         # Camera internals
+         self.focal_length = self.size[1]
+         self.camera_center = (self.size[1] / 2, self.size[0] / 2)
+         self.camera_matrix = np.array(
+             [[self.focal_length, 0, self.camera_center[0]],
+              [0, self.focal_length, self.camera_center[1]],
+              [0, 0, 1]], dtype="double")
+
+         # Assuming no lens distortion
+         self.dist_coeefs = np.zeros((4, 1))
+
+         # Rotation vector and translation vector
+         self.r_vec = np.array([[0.01891013], [0.08560084], [-3.14392813]])
+         self.t_vec = np.array(
+             [[-14.97821226], [-10.62040383], [-2053.03596872]])
+
+     def _get_full_model_points(self, filename='assets/model.txt'):
+         """Get all 68 3D model points from file"""
+         raw_value = []
+         with open(filename) as file:
+             for line in file:
+                 raw_value.append(line)
+         model_points = np.array(raw_value, dtype=np.float32)
+         model_points = np.reshape(model_points, (3, -1)).T
+
+         # Transform the model into a front view.
+         model_points[:, 2] *= -1
+
+         return model_points
+
+     def solve(self, points):
+         """Solve pose with all the 68 image points
+         Args:
+             points (np.ndarray): points on image.
+
+         Returns:
+             Tuple: (rotation_vector, translation_vector) as pose.
+         """
+
+         if self.r_vec is None:
+             (_, rotation_vector, translation_vector) = cv2.solvePnP(
+                 self.model_points_68, points, self.camera_matrix, self.dist_coeefs)
+             self.r_vec = rotation_vector
+             self.t_vec = translation_vector
+
+         (_, rotation_vector, translation_vector) = cv2.solvePnP(
+             self.model_points_68,
+             points,
+             self.camera_matrix,
+             self.dist_coeefs,
+             rvec=self.r_vec,
+             tvec=self.t_vec,
+             useExtrinsicGuess=True)
+
+         return (rotation_vector, translation_vector)
+
+     def visualize(self, image, pose, color=(255, 255, 255), line_width=2):
+         """Draw a 3D box as annotation of pose"""
+         rotation_vector, translation_vector = pose
+         point_3d = []
+         rear_size = 75
+         rear_depth = 0
+         point_3d.append((-rear_size, -rear_size, rear_depth))
+         point_3d.append((-rear_size, rear_size, rear_depth))
+         point_3d.append((rear_size, rear_size, rear_depth))
+         point_3d.append((rear_size, -rear_size, rear_depth))
+         point_3d.append((-rear_size, -rear_size, rear_depth))
+
+         front_size = 100
+         front_depth = 100
+         point_3d.append((-front_size, -front_size, front_depth))
+         point_3d.append((-front_size, front_size, front_depth))
+         point_3d.append((front_size, front_size, front_depth))
+         point_3d.append((front_size, -front_size, front_depth))
+         point_3d.append((-front_size, -front_size, front_depth))
+         point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3)
+
+         # Map to 2d image points
+         (point_2d, _) = cv2.projectPoints(point_3d,
+                                           rotation_vector,
+                                           translation_vector,
+                                           self.camera_matrix,
+                                           self.dist_coeefs)
+         point_2d = np.int32(point_2d.reshape(-1, 2))
+
+         # Draw all the lines
+         cv2.polylines(image, [point_2d], True, color, line_width, cv2.LINE_AA)
+         cv2.line(image, tuple(point_2d[1]), tuple(
+             point_2d[6]), color, line_width, cv2.LINE_AA)
+         cv2.line(image, tuple(point_2d[2]), tuple(
+             point_2d[7]), color, line_width, cv2.LINE_AA)
+         cv2.line(image, tuple(point_2d[3]), tuple(
+             point_2d[8]), color, line_width, cv2.LINE_AA)
+
+     def draw_axes(self, img, pose):
+         R, t = pose
+         img = cv2.drawFrameAxes(img, self.camera_matrix,
+                                 self.dist_coeefs, R, t, 30)
+
+     def show_3d_model(self):
+         from matplotlib import pyplot
+         from mpl_toolkits.mplot3d import Axes3D
+         fig = pyplot.figure()
+         ax = Axes3D(fig)
+
+         x = self.model_points_68[:, 0]
+         y = self.model_points_68[:, 1]
+         z = self.model_points_68[:, 2]
+
+         ax.scatter(x, y, z)
+         ax.axis('square')
+         pyplot.xlabel('x')
+         pyplot.ylabel('y')
+         pyplot.show()
+
+     ###
+     # yhm : from chat gpt to detect distraction
+     ###
+     def rotation_matrix_to_angles(self, rotation_vector):
+         """Convert rotation vector to pitch, yaw, and roll angles."""
+         rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
+         sy = np.sqrt(rotation_matrix[0, 0]**2 + rotation_matrix[1, 0]**2)
+
+         singular = sy < 1e-6
+         if not singular:
+             pitch = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2])
+             yaw = np.arctan2(-rotation_matrix[2, 0], sy)
+             roll = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0])
+         else:
+             pitch = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1])
+             yaw = np.arctan2(-rotation_matrix[2, 0], sy)
+             roll = 0
+
+         return np.degrees(pitch), np.degrees(yaw), np.degrees(roll)
+
+     def is_distracted(self, rotation_vector):
+         """Determine if the user is distracted based on head pose angles."""
+         pitch, yaw, roll = self.rotation_matrix_to_angles(rotation_vector)
+
+         # Define thresholds (adjust based on further testing)
+         pitch_threshold = (-15, 10)   # Allow some variability in pitch
+         yaw_threshold = (-20, 16)     # Reasonable range for yaw
+         roll_threshold = (-180, 180)  # Full range: roll is effectively unconstrained
+         # print("pitch, yaw, roll", pitch, yaw, roll)
+         # Check if head is roughly considered 'facing forward'
+         focus_pitch = pitch_threshold[0] < pitch < pitch_threshold[1]
+         focus_yaw = yaw_threshold[0] < yaw < yaw_threshold[1]
+         focus_roll = roll_threshold[0] < roll < roll_threshold[1]
+
+         return not (focus_pitch and focus_yaw and focus_roll)
+
+     # """Determine if the user is distracted based on head pose angles."""
+     # pitch, yaw, roll = self.rotation_matrix_to_angles(rotation_vector)
+     # print("pitch, yaw, roll", pitch, yaw, roll)
+     # # Define thresholds (you may need to adjust these based on testing)
+     # pitch_threshold = 15  # Up/Down threshold
+     # yaw_threshold = 20  # Left/Right threshold
+     # roll_threshold = 10  # Tilt threshold
+
+     # # Check if head is facing roughly forward
+     # if abs(pitch) < pitch_threshold and abs(yaw) < yaw_threshold and abs(roll) < roll_threshold:
+     #     return False  # Focused
+     # else:
+     #     return True  # Distracted
+
+     def detect_distraction(self, points):
+         """Solve pose and detect distraction status based on pose."""
+         rotation_vector, translation_vector = self.solve(points)
+         distraction_status = self.is_distracted(rotation_vector)
+         return distraction_status, (rotation_vector, translation_vector)
+
+
+ # second part
+
+ # def rotation_matrix_to_angles(self, rotation_vector):
+ #     """Convert rotation vector to pitch, yaw, and roll angles."""
+ #     # Convert the rotation vector into a rotation matrix
+ #     rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
+
+ #     # Ensure no division by zero
+ #     sy = np.sqrt(rotation_matrix[0, 0]**2 + rotation_matrix[1, 0]**2)
+ #     singular = sy < 1e-6
+
+ #     if not singular:
+ #         pitch = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2])
+ #         yaw = np.arctan2(-rotation_matrix[2, 0], sy)
+ #         roll = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0])
+ #     else:
+ #         pitch = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1])
+ #         yaw = np.arctan2(-rotation_matrix[2, 0], sy)
+ #         roll = 0
+
+ #     # Return converted angles in degrees
+ #     return np.degrees(pitch), np.degrees(yaw), np.degrees(roll)
+
+ # def is_distracted(self, rotation_vector):
+ #     """Determine if the user is distracted based on head pose angles."""
+ #     pitch, yaw, roll = self.rotation_matrix_to_angles(rotation_vector)
+
+ #     # Test different thresholds based on specific requirements
+ #     pitch_threshold = 15  # Up/Down
+ #     yaw_threshold = 20  # Left/Right
+ #     roll_threshold = 10  # Tilt
+
+ #     # Determine distraction status
+ #     return not (abs(pitch) < pitch_threshold and abs(yaw) < yaw_threshold and abs(roll) < roll_threshold)
+
+ # def detect_distraction(self, points):
+ #     """Solve pose and detect distraction status based on pose."""
+ #     rotation_vector, translation_vector = self.solve(points)
+ #     distraction_status = self.is_distracted(rotation_vector)
+ #     return distraction_status, (rotation_vector, translation_vector)
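`rotation_matrix_to_angles` converts the solvePnP rotation vector into Euler angles via `cv2.Rodrigues`, and `is_distracted` simply checks those angles against fixed pitch and yaw windows (roll is effectively unconstrained). A quick sanity-check sketch, assuming a 640×480 frame and run from the repo root so `assets/model.txt` resolves; the rotation vectors below are made up for illustration:

```python
import numpy as np
from pose_estimation import PoseEstimator

# The constructor loads assets/model.txt for the 3D face model.
estimator = PoseEstimator(image_width=640, image_height=480)

# A zero rotation vector is the identity rotation: pitch = yaw = roll = 0,
# inside the pitch (-15, 10) and yaw (-20, 16) windows -> focused.
print(estimator.is_distracted(np.zeros((3, 1))))                  # False

# A rotation of 0.5 rad (~29 degrees) about the Y axis pushes yaw outside
# the (-20, 16) window -> distracted.
print(estimator.is_distracted(np.array([[0.0], [0.5], [0.0]])))   # True
```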
utils.py ADDED
@@ -0,0 +1,41 @@
+ """This module provides a bunch of helper functions."""
+ import numpy as np
+
+
+ def refine(boxes, max_width, max_height, shift=0.1):
+     """Refine the face boxes to suit the face landmark detection's needs.
+
+     Args:
+         boxes: [[x1, y1, x2, y2], ...]
+         max_width: Value larger than this will be clipped.
+         max_height: Value larger than this will be clipped.
+         shift (float, optional): How much to shift the face box down. Defaults to 0.1.
+
+     Returns:
+         Refined results.
+     """
+     refined = boxes.copy()
+     width = refined[:, 2] - refined[:, 0]
+     height = refined[:, 3] - refined[:, 1]
+
+     # Move the boxes in Y direction
+     shift = height * shift
+     refined[:, 1] += shift
+     refined[:, 3] += shift
+     center_x = (refined[:, 0] + refined[:, 2]) / 2
+     center_y = (refined[:, 1] + refined[:, 3]) / 2
+
+     # Make the boxes squares
+     square_sizes = np.maximum(width, height)
+     refined[:, 0] = center_x - square_sizes / 2
+     refined[:, 1] = center_y - square_sizes / 2
+     refined[:, 2] = center_x + square_sizes / 2
+     refined[:, 3] = center_y + square_sizes / 2
+
+     # Clip the boxes for safety
+     refined[:, 0] = np.clip(refined[:, 0], 0, max_width)
+     refined[:, 1] = np.clip(refined[:, 1], 0, max_height)
+     refined[:, 2] = np.clip(refined[:, 2], 0, max_width)
+     refined[:, 3] = np.clip(refined[:, 3], 0, max_height)
+
+     return refined
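`refine` shifts each face box down by `shift * height`, squares it around the new center, and clips it to the frame, which gives the landmark model the roughly square, slightly lowered crop it expects. A small worked example with one made-up box (the numbers are illustrative):

```python
import numpy as np
from utils import refine

# One hypothetical detection: x1, y1, x2, y2 in a 640x480 frame.
boxes = np.array([[100.0, 100.0, 200.0, 260.0]])

refined = refine(boxes, max_width=640, max_height=480, shift=0.1)
# Height is 160, so the box moves down by 16 px and becomes a 160 px square
# centered on (150, 196): roughly [70, 116, 230, 276].
print(refined)
```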