"""
Real-time People Detection Streamlit application.

This is the main entry point for the Hugging Face Space application.
"""

import time
from pathlib import Path
from typing import Tuple, Dict, Any, Optional, List

import cv2
import numpy as np
import streamlit as st
import torch
from ultralytics import YOLO

# Demo clips bundled with the Space under assets/.
ASSETS_DIR = Path(__file__).parent / "assets"
DEMO_VIDEOS = {
    "One Person": ASSETS_DIR / "one-by-one-person-detection.mp4",
    "Store Aisle": ASSETS_DIR / "store-aisle-detection.mp4",
    "People Detection": ASSETS_DIR / "people-detection.mp4",
}

# Resolution requested from the video source.
FRAME_WIDTH = 640
FRAME_HEIGHT = 480


class PeopleDetector:
    """
    A class for detecting people in images using a pre-trained YOLOv8 model.

    Attributes:
        model_name: Name or path of the YOLOv8 model to use
        threshold: Confidence threshold for detection
        device: Device to run inference on ("cuda" or "cpu")
        model: The underlying YOLO detection model
    """

    def __init__(
        self,
        model_name: str = "yolov8n.pt",
        threshold: float = 0.5,
        device: Optional[str] = None,
    ):
        """
        Initialize the people detector with a pre-trained model.

        Args:
            model_name: YOLOv8 model name to use ("yolov8n.pt" is the smallest)
            threshold: Confidence threshold for detection (0.0 to 1.0)
            device: Device to run inference on ("cuda" or "cpu").
                If None, CUDA is used when available.
        """
        self.model_name = model_name
        self.threshold = threshold

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model = YOLO(model_name)

        # Class 0 is "person" in the COCO label set used by the
        # pre-trained YOLOv8 models.
        self.person_class_id = 0

    def detect(self, image: np.ndarray) -> Tuple[List[Dict[str, Any]], float]:
        """
        Detect people in an image.

        Args:
            image: Input image as numpy array (BGR format from OpenCV)

        Returns:
            Tuple containing:
                - List of detection results with keys 'box', 'score', and 'label'
                - Inference time in seconds
        """
        start_time = time.time()

        results = self.model(image, conf=self.threshold, device=self.device)

        detections = []
        for result in results:
            for box in result.boxes:
                cls = int(box.cls.item())
                conf = float(box.conf.item())

                # Keep only "person" detections; other COCO classes are ignored.
                if cls == self.person_class_id:
                    x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                    detections.append({
                        'box': (x1, y1, x2, y2),
                        'score': conf,
                        'label': 'person',
                    })

        inference_time = time.time() - start_time
        return detections, inference_time

    def update_threshold(self, threshold: float) -> None:
        """
        Update the detection confidence threshold.

        Args:
            threshold: New threshold value (0.0 to 1.0)
        """
        self.threshold = threshold
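
# A minimal usage sketch for PeopleDetector outside Streamlit ("frame.jpg" is
# a placeholder path, not a file shipped with this Space):
#
#   detector = PeopleDetector(threshold=0.6)
#   frame = cv2.imread("frame.jpg")  # BGR, as detect() expects
#   detections, elapsed = detector.detect(frame)
#   print(f"{len(detections)} people in {elapsed * 1000:.1f} ms")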


class VideoSource:
    """
    A class for handling video input from different sources (webcam or file).

    Attributes:
        source: Camera index (int) or video file path (str)
        width: Frame width to request (if possible)
        height: Frame height to request (if possible)
        fps_buffer_size: Number of frames to average for FPS calculation
    """

    def __init__(
        self,
        source: Any = 0,
        width: int = 640,
        height: int = 480,
        fps_buffer_size: int = 30,
    ):
        """
        Initialize the video source.

        Args:
            source: Camera index (int) or video file path (str)
            width: Width to request for the captured frames
            height: Height to request for the captured frames
            fps_buffer_size: Number of frames to use for FPS averaging
        """
        self.source = source
        self.width = width
        self.height = height
        self.fps_buffer_size = fps_buffer_size

        self.cap = None
        self.frame_times = []
        self.is_running = False

    def start(self) -> bool:
        """
        Start the video capture.

        Returns:
            bool: True if capture was started successfully, False otherwise
        """
        if self.is_running:
            return True

        self.cap = cv2.VideoCapture(self.source)
        if not self.cap.isOpened():
            return False

        # Resolution hints only apply to live cameras; files keep their own size.
        if isinstance(self.source, int):
            self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
            self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)

        self.is_running = True
        self.frame_times = []
        return True

    def stop(self) -> None:
        """Stop the video capture and release resources."""
        if self.is_running and self.cap is not None:
            self.cap.release()
        self.is_running = False

    def read_frame(self) -> Tuple[bool, Optional[np.ndarray]]:
        """
        Read a single frame from the video source.

        Returns:
            Tuple containing:
                - Boolean indicating if a frame was successfully read
                - Image as numpy array (or None if no frame was read)
        """
        if not self.is_running or self.cap is None:
            return False, None

        current_time = time.time()
        ret, frame = self.cap.read()

        if ret:
            # Record the arrival time and keep only the most recent
            # fps_buffer_size timestamps for the rolling FPS estimate.
            self.frame_times.append(current_time)
            if len(self.frame_times) > self.fps_buffer_size:
                self.frame_times.pop(0)

        return ret, frame

    def get_fps(self) -> float:
        """
        Calculate the current FPS based on actual frame timings.

        Returns:
            float: Current frames per second
        """
        if len(self.frame_times) < 2:
            return 0.0

        # n timestamps span n - 1 frame intervals, so FPS = (n - 1) / elapsed.
        time_diff = self.frame_times[-1] - self.frame_times[0]
        if time_diff > 0:
            return (len(self.frame_times) - 1) / time_diff
        return 0.0
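
# A minimal capture-loop sketch for VideoSource, assuming a local webcam at
# index 0 (not available inside a Hugging Face Space container):
#
#   src = VideoSource(source=0, width=FRAME_WIDTH, height=FRAME_HEIGHT)
#   if src.start():
#       for _ in range(100):
#           ok, frame = src.read_frame()
#           if not ok:
#               break
#       print(f"measured {src.get_fps():.1f} FPS")
#       src.stop()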


def draw_detections(
    image: np.ndarray,
    detections: List[Dict[str, Any]],
    color: Tuple[int, int, int] = (0, 255, 0),
    thickness: int = 2,
    font_scale: float = 0.5,
) -> np.ndarray:
    """
    Draw bounding boxes and labels for detected people.

    Args:
        image: Input image to draw on
        detections: List of detection results from PeopleDetector
        color: BGR color tuple for bounding boxes
        thickness: Line thickness for bounding boxes
        font_scale: Font scale for text labels

    Returns:
        np.ndarray: Copy of the image with detections drawn
    """
    annotated_image = image.copy()

    for detection in detections:
        x_min, y_min, x_max, y_max = detection['box']

        # Bounding box.
        cv2.rectangle(
            annotated_image,
            (x_min, y_min),
            (x_max, y_max),
            color,
            thickness,
        )

        label = f"Person: {detection['score']:.2f}"

        # Size the filled label background to fit the text.
        (text_width, text_height), _ = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
        )
        cv2.rectangle(
            annotated_image,
            (x_min, y_min - text_height - 5),
            (x_min + text_width, y_min),
            color,
            -1,
        )

        # Label text in black on the filled background.
        cv2.putText(
            annotated_image,
            label,
            (x_min, y_min - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (0, 0, 0),
            thickness,
        )

    return annotated_image
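
# A quick way to eyeball draw_detections without running a model; the box and
# score below are made-up values:
#
#   canvas = np.zeros((FRAME_HEIGHT, FRAME_WIDTH, 3), dtype=np.uint8)
#   fake = [{'box': (100, 80, 220, 400), 'score': 0.91, 'label': 'person'}]
#   cv2.imwrite("annotated.png", draw_detections(canvas, fake))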


def add_performance_stats(
    image: np.ndarray,
    fps: float,
    inference_time: float,
    people_count: int,
    inference_fps: float = 0.0,
    bg_color: Tuple[int, int, int] = (0, 0, 0),
    text_color: Tuple[int, int, int] = (255, 255, 255),
    font_scale: float = 0.5,
    thickness: int = 1,
) -> np.ndarray:
    """
    Add performance statistics to the image.

    Args:
        image: Input image to add stats to
        fps: Current display FPS value
        inference_time: Model inference time in seconds
        people_count: Number of people detected
        inference_fps: Inference FPS (model predictions per second)
        bg_color: Background color for the stats box
        text_color: Text color for the stats
        font_scale: Font scale for the text
        thickness: Line thickness for the text

    Returns:
        np.ndarray: Copy of the image with performance stats drawn
    """
    stats_image = image.copy()

    lines = [
        f"FPS: {fps:.1f}",
        f"Inference: {inference_time * 1000:.1f}ms",
        f"People: {people_count}",
        f"Inference FPS: {inference_fps:.1f}",
    ]

    # Measure each line so the background box fits the widest one.
    sizes = [
        cv2.getTextSize(line, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)[0]
        for line in lines
    ]
    box_width = max(width for width, _ in sizes) + 20
    box_height = sum(height for _, height in sizes) + 30

    # Filled background box in the top-left corner.
    cv2.rectangle(
        stats_image,
        (10, 10),
        (10 + box_width, 10 + box_height),
        bg_color,
        -1,
    )

    # Stack the text lines inside the box.
    y_offset = 10
    for line, (_, text_height) in zip(lines, sizes):
        y_offset += text_height + 5
        cv2.putText(
            stats_image,
            line,
            (20, y_offset),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            text_color,
            thickness,
        )

    return stats_image
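
# A minimal end-to-end sketch tying the pieces above together on one frame
# ("frame.jpg" is again a placeholder path):
#
#   detector = PeopleDetector()
#   frame = cv2.imread("frame.jpg")
#   detections, elapsed = detector.detect(frame)
#   annotated = draw_detections(frame, detections)
#   annotated = add_performance_stats(
#       annotated, fps=30.0, inference_time=elapsed,
#       people_count=len(detections), inference_fps=1.0 / max(elapsed, 1e-6),
#   )
#   cv2.imwrite("annotated.png", annotated)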


class PeopleDetectionApp:
    """
    Streamlit application for real-time people detection.

    This class handles the Streamlit UI components and orchestrates
    the video capture and detection processes.
    """

    def __init__(self):
        """Initialize the Streamlit application components."""
        st.set_page_config(
            page_title="Real-time People Detection",
            page_icon="👁️",
            layout="wide",
        )

        # Streamlit reruns the script on every interaction, so anything that
        # must survive a rerun lives in st.session_state.
        defaults = {
            "video_source": None,
            "detector": None,
            "is_running": False,
            "frame_placeholder": None,
            "last_inference_time": 0.0,
            "last_inference_timestamp": 0.0,
            "frame_count": 0,
            "last_frame": None,
            "last_detections": [],
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def create_ui(self):
        """Create the Streamlit UI components and return the stats placeholders."""
        st.title("Real-time People Detection")
        st.markdown(
            "This application detects people in video streams using YOLOv8."
        )

        with st.sidebar:
            st.header("Settings")

            # Only the nano model is offered in this build of the Space.
            model_name = st.selectbox(
                "Select detection model",
                options=[
                    "yolov8n.pt",
                ],
                index=0,
            )

            detection_threshold = st.slider(
                "Detection threshold",
                min_value=0.1,
                max_value=1.0,
                value=0.5,
                step=0.05,
            )

            target_fps = st.slider(
                "Target inference FPS",
                min_value=1,
                max_value=30,
                value=10,
                step=1,
                help=(
                    "Control how many frames per second are sent to the model "
                    "for inference. Lower values use fewer resources but may "
                    "appear less smooth."
                ),
            )

            # Only the bundled demo videos are exposed as sources in this build.
            demo_selection = st.selectbox(
                "Select demo video",
                options=list(DEMO_VIDEOS.keys()),
                index=0,
            )
            source = str(DEMO_VIDEOS[demo_selection])

            col1, col2 = st.columns(2)
            with col1:
                start_button = st.button(
                    "Start" if not st.session_state.is_running else "Restart",
                    use_container_width=True,
                )
            with col2:
                stop_button = st.button(
                    "Stop",
                    use_container_width=True,
                    disabled=not st.session_state.is_running,
                )

        video_column, stats_column = st.columns([3, 1])

        with video_column:
            st.subheader("Detection Feed")
            frame_placeholder = st.empty()
            st.session_state.frame_placeholder = frame_placeholder

        with stats_column:
            st.subheader("Performance Stats")
            fps_text = st.empty()
            inference_text = st.empty()
            people_count = st.empty()
            inference_fps_text = st.empty()

        if start_button:
            self.start_detection(source, model_name, detection_threshold, target_fps)

        if stop_button:
            self.stop_detection()

        return fps_text, inference_text, people_count, inference_fps_text

    def start_detection(self, source, model_name, threshold, target_fps):
        """
        Start the detection process.

        Args:
            source: Video source (camera index or file path)
            model_name: YOLOv8 model to use
            threshold: Detection confidence threshold
            target_fps: Target frames per second for inference
        """
        # Release any previous source before starting a new one.
        self.stop_detection()

        video_source = VideoSource(
            source=source,
            width=FRAME_WIDTH,
            height=FRAME_HEIGHT,
        )

        detector = PeopleDetector(
            model_name=model_name,
            threshold=threshold,
        )

        if not video_source.start():
            st.error(f"Failed to open video source: {source}")
            return

        st.session_state.video_source = video_source
        st.session_state.detector = detector
        st.session_state.is_running = True
        st.session_state.target_fps = target_fps
        st.session_state.last_inference_timestamp = time.time()
        st.session_state.frame_count = 0
        st.session_state.last_frame = None
        st.session_state.last_detections = []

    def stop_detection(self):
        """Stop the detection process and release resources."""
        if st.session_state.video_source is not None:
            st.session_state.video_source.stop()
            st.session_state.video_source = None

        st.session_state.detector = None
        st.session_state.is_running = False
        st.session_state.last_frame = None
        st.session_state.last_detections = []

    def update_frame(self, fps_text, inference_text, people_count, inference_fps_text):
        """
        Update the video frame and stats.

        Args:
            fps_text: Streamlit element for FPS display
            inference_text: Streamlit element for inference time display
            people_count: Streamlit element for people count display
            inference_fps_text: Streamlit element for inference FPS display
        """
        if not st.session_state.is_running:
            return

        video_source = st.session_state.video_source
        detector = st.session_state.detector
        target_fps = st.session_state.target_fps

        if video_source is None or detector is None:
            return

        ret, frame = video_source.read_frame()

        if not ret:
            if not isinstance(video_source.source, int):
                # File sources run out of frames at the end of the clip;
                # restart the capture to loop the video.
                video_source.stop()
                if video_source.start():
                    ret, frame = video_source.read_frame()
                    if not ret:
                        st.error("Failed to restart video")
                        self.stop_detection()
                        return
                else:
                    st.error("Failed to restart video source")
                    self.stop_detection()
                    return
            else:
                st.error("Failed to read frame from camera")
                self.stop_detection()
                return

        fps = video_source.get_fps()

        # Throttle inference: with target_fps = 10 the minimum interval is
        # 1.0 / 10 = 0.1 s, so at most ten frames per second reach the model.
        current_time = time.time()
        time_since_last_inference = current_time - st.session_state.last_inference_timestamp
        inference_interval = 1.0 / target_fps

        detections = []
        inference_time = 0.0

        if time_since_last_inference >= inference_interval:
            detections, inference_time = detector.detect(frame)

            st.session_state.last_frame = frame.copy()
            st.session_state.last_detections = detections
            st.session_state.last_inference_time = inference_time
            st.session_state.last_inference_timestamp = current_time
        else:
            # Between inference runs, re-use the most recent detections so the
            # overlay stays stable on skipped frames.
            detections = st.session_state.last_detections
            inference_time = st.session_state.last_inference_time

        frame_with_detections = draw_detections(frame, detections)

        # Approximate inference rate from the time since the last model run.
        if time_since_last_inference > 0:
            inference_fps = 1.0 / time_since_last_inference
        else:
            inference_fps = 0.0

        frame_with_stats = add_performance_stats(
            frame_with_detections,
            fps,
            inference_time,
            len(detections),
            inference_fps,
        )

        st.session_state.frame_placeholder.image(
            frame_with_stats,
            channels="BGR",
            use_column_width=True,
        )

        fps_text.metric("FPS", f"{fps:.1f}")
        inference_text.metric("Inference Time", f"{inference_time * 1000:.1f} ms")
        people_count.metric("People Detected", len(detections))
        inference_fps_text.metric("Inference FPS", f"{inference_fps:.1f}")

        st.session_state.frame_count += 1


def main():
    """Main entry point for the application."""
    app = PeopleDetectionApp()
    fps_text, inference_text, people_count, inference_fps_text = app.create_ui()

    # Streamlit has no built-in frame loop, so drive updates manually until
    # the user presses Stop (which reruns the script with is_running False).
    while st.session_state.is_running:
        app.update_frame(fps_text, inference_text, people_count, inference_fps_text)
        time.sleep(0.01)


if __name__ == "__main__":
    main()
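
# Run locally (this file is assumed to be saved as app.py, the conventional
# entry point for a Streamlit-based Hugging Face Space):
#
#   streamlit run app.py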