"""
Real-time People Detection Streamlit application.
This is the main entry point for the Hugging Face Space application.
"""
import os
import time
from pathlib import Path
from typing import Tuple, Dict, Any, Optional, List
import cv2
import numpy as np
import streamlit as st
from PIL import Image
import torch
from ultralytics import YOLO
# Constants
ASSETS_DIR = Path(__file__).parent / "assets"
DEMO_VIDEOS = {
"One Person": ASSETS_DIR / "one-by-one-person-detection.mp4",
"Store Aisle": ASSETS_DIR / "store-aisle-detection.mp4",
"People Detection": ASSETS_DIR / "people-detection.mp4"
}
FRAME_WIDTH = 640
FRAME_HEIGHT = 480
class PeopleDetector:
"""
A class for detecting people in images using a pre-trained YOLOv8n model.
Attributes:
model_name: Name or path of the YOLOv8 model to use
threshold: Confidence threshold for detection
device: Device to run inference on (cuda/cpu)
model: The detection model
"""
def __init__(
self,
model_name: str = "yolov8n.pt",
threshold: float = 0.5,
device: Optional[str] = None,
):
"""
Initialize the people detector with a pre-trained model.
Args:
model_name: YOLOv8 weights to load ('yolov8n.pt', the nano model, is the smallest)
threshold: Confidence threshold for detection (0.0 to 1.0)
device: Device to run inference on ('cuda' or 'cpu'). If None, CUDA is used when available.
"""
self.model_name = model_name
self.threshold = threshold
# Determine the device to use
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
# Load the YOLOv8 model
self.model = YOLO(model_name)
# Class index 0 is 'person' in the COCO label set that YOLOv8 models are trained on
self.person_class_id = 0
def detect(self, image: np.ndarray) -> Tuple[List[Dict[str, Any]], float]:
"""
Detect people in an image.
Args:
image: Input image as numpy array (BGR format from OpenCV)
Returns:
Tuple containing:
- List of detection results with keys 'box', 'score', and 'label'
- Inference time in seconds
"""
# Start timing
start_time = time.time()
# Run inference with YOLOv8
results = self.model(image, conf=self.threshold, device=self.device)
# Extract detections of people only
detections = []
# Process the results
for result in results:
boxes = result.boxes
# Extract coordinates, confidence and class
for box in boxes:
cls = int(box.cls.item())
conf = float(box.conf.item())
# Check if it's a person (class 0)
if cls == self.person_class_id:
# Get bounding box
x1, y1, x2, y2 = map(int, box.xyxy.tolist()[0])
detections.append({
'box': (x1, y1, x2, y2),
'score': conf,
'label': 'person'
})
# Calculate inference time
inference_time = time.time() - start_time
return detections, inference_time
def update_threshold(self, threshold: float) -> None:
"""
Update the detection confidence threshold.
Args:
threshold: New threshold value (0.0 to 1.0)
"""
self.threshold = threshold
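# Illustrative usage of PeopleDetector outside the Streamlit app (not executed
# here; 'sample.jpg' is a hypothetical image path):
#
#   detector = PeopleDetector(threshold=0.5)
#   frame = cv2.imread("sample.jpg")            # BGR image, as detect() expects
#   people, latency = detector.detect(frame)
#   print(f"{len(people)} people found in {latency * 1000:.1f} ms")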
class VideoSource:
"""
A class for handling video input from different sources (webcam or file).
Attributes:
source: Camera index (int) or video file path (str)
width: Frame width to set (if possible)
height: Frame height to set (if possible)
fps_buffer_size: Number of frames to average for FPS calculation
"""
def __init__(
self,
source: Any = 0,
width: int = 640,
height: int = 480,
fps_buffer_size: int = 30,
):
"""
Initialize the video source.
Args:
source: Camera index (int) or video file path (str)
width: Width to set for the captured frames
height: Height to set for the captured frames
fps_buffer_size: Number of frames to use for FPS averaging
"""
self.source = source
self.width = width
self.height = height
self.fps_buffer_size = fps_buffer_size
self.cap = None
self.frame_times = []
self.is_running = False
def start(self) -> bool:
"""
Start the video capture.
Returns:
bool: True if capture was started successfully, False otherwise
"""
if self.is_running:
return True
self.cap = cv2.VideoCapture(self.source)
if not self.cap.isOpened():
return False
# Try to set properties if it's a webcam
if isinstance(self.source, int):
self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
self.is_running = True
self.frame_times = []
return True
def stop(self) -> None:
"""Stop the video capture and release resources."""
if self.is_running and self.cap is not None:
self.cap.release()
self.is_running = False
def read_frame(self) -> Tuple[bool, Optional[np.ndarray]]:
"""
Read a single frame from the video source.
Returns:
Tuple containing:
- Boolean indicating if frame was successfully read
- Image as numpy array (or None if no frame was read)
"""
if not self.is_running or self.cap is None:
return False, None
# Record time for FPS calculation
current_time = time.time()
# Read frame
ret, frame = self.cap.read()
if ret:
# Update FPS buffer
self.frame_times.append(current_time)
if len(self.frame_times) > self.fps_buffer_size:
self.frame_times.pop(0)
return ret, frame
def get_fps(self) -> float:
"""
Calculate the current FPS based on actual frame timings.
Returns:
float: Current frames per second
"""
if len(self.frame_times) < 2:
return 0.0
# Calculate FPS from time differences
time_diff = self.frame_times[-1] - self.frame_times[0]
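# e.g. 30 timestamps spanning 1.16 s give (30 - 1) / 1.16 ≈ 25 FPS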
if time_diff > 0:
return (len(self.frame_times) - 1) / time_diff
return 0.0
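# Illustrative usage of VideoSource on its own (not executed here; camera
# index 0 and the frame size are assumptions):
#
#   src = VideoSource(source=0, width=640, height=480)
#   if src.start():
#       ok, frame = src.read_frame()
#       print(f"got frame: {ok}, measured FPS: {src.get_fps():.1f}")
#       src.stop()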
def draw_detections(
image: np.ndarray,
detections: List[Dict[str, Any]],
color: Tuple[int, int, int] = (0, 255, 0),
thickness: int = 2,
font_scale: float = 0.5,
) -> np.ndarray:
"""
Draw bounding boxes and labels for detected people.
Args:
image: Input image to draw on
detections: List of detection results from PeopleDetector
color: BGR color tuple for bounding boxes
thickness: Line thickness for bounding boxes
font_scale: Font scale for text labels
Returns:
np.ndarray: Image with drawn detections
"""
annotated_image = image.copy()
for detection in detections:
# Extract bounding box coordinates
x_min, y_min, x_max, y_max = detection['box']
# Draw bounding box
cv2.rectangle(
annotated_image,
(x_min, y_min),
(x_max, y_max),
color,
thickness
)
# Create label with confidence score
label = f"Person: {detection['score']:.2f}"
# Calculate text size and position
(text_width, text_height), _ = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
)
# Draw label background
cv2.rectangle(
annotated_image,
(x_min, y_min - text_height - 5),
(x_min + text_width, y_min),
color,
-1 # Filled rectangle
)
# Draw text
cv2.putText(
annotated_image,
label,
(x_min, y_min - 5),
cv2.FONT_HERSHEY_SIMPLEX,
font_scale,
(0, 0, 0), # Black text
thickness
)
return annotated_image
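# Illustrative call (not executed here): drawing one hand-made detection onto a
# blank test frame, using the dict layout produced by PeopleDetector.detect():
#
#   blank = np.zeros((480, 640, 3), dtype=np.uint8)
#   boxed = draw_detections(blank, [{'box': (100, 80, 220, 360), 'score': 0.92, 'label': 'person'}])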
def add_performance_stats(
image: np.ndarray,
fps: float,
inference_time: float,
people_count: int,
inference_fps: float = 0.0,
bg_color: Tuple[int, int, int] = (0, 0, 0),
text_color: Tuple[int, int, int] = (255, 255, 255),
font_scale: float = 0.5,
thickness: int = 1,
) -> np.ndarray:
"""
Add performance statistics to the image.
Args:
image: Input image to add stats to
fps: Current FPS value
inference_time: Model inference time in seconds
people_count: Number of people detected
inference_fps: Inference FPS (model predictions per second)
bg_color: Background color for stats box
text_color: Text color for stats
font_scale: Font scale for text
thickness: Line thickness for text
Returns:
np.ndarray: Image with added performance stats
"""
stats_image = image.copy()
# Create stats text
fps_text = f"FPS: {fps:.1f}"
inference_text = f"Inference: {inference_time*1000:.1f}ms"
count_text = f"People: {people_count}"
inf_fps_text = f"Inference FPS: {inference_fps:.1f}"
# Get text sizes
(fps_width, fps_height), _ = cv2.getTextSize(
fps_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
)
(inf_width, inf_height), _ = cv2.getTextSize(
inference_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
)
(count_width, count_height), _ = cv2.getTextSize(
count_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
)
(inf_fps_width, inf_fps_height), _ = cv2.getTextSize(
inf_fps_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
)
# Calculate background box dimensions
box_width = max(fps_width, inf_width, count_width, inf_fps_width) + 20
box_height = fps_height + inf_height + count_height + inf_fps_height + 30
# Draw background box
cv2.rectangle(
stats_image,
(10, 10),
(10 + box_width, 10 + box_height),
bg_color,
-1 # Filled rectangle
)
# Draw text
y_offset = 10 + fps_height + 5
cv2.putText(
stats_image,
fps_text,
(20, y_offset),
cv2.FONT_HERSHEY_SIMPLEX,
font_scale,
text_color,
thickness
)
y_offset += inf_height + 5
cv2.putText(
stats_image,
inference_text,
(20, y_offset),
cv2.FONT_HERSHEY_SIMPLEX,
font_scale,
text_color,
thickness
)
y_offset += count_height + 5
cv2.putText(
stats_image,
count_text,
(20, y_offset),
cv2.FONT_HERSHEY_SIMPLEX,
font_scale,
text_color,
thickness
)
y_offset += inf_fps_height + 5
cv2.putText(
stats_image,
inf_fps_text,
(20, y_offset),
cv2.FONT_HERSHEY_SIMPLEX,
font_scale,
text_color,
thickness
)
return stats_image
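# Illustrative call (not executed here), continuing the sketch above:
#
#   overlay = add_performance_stats(boxed, fps=25.0, inference_time=0.034,
#                                   people_count=1, inference_fps=10.0)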
class PeopleDetectionApp:
"""
Streamlit application for real-time people detection.
This class handles the Streamlit UI components and orchestrates
the video capture and detection processes.
"""
def __init__(self):
"""Initialize the Streamlit application components."""
# Set page config
st.set_page_config(
page_title="Real-time People Detection",
page_icon="👁️",
layout="wide",
)
# Initialize session state
if "video_source" not in st.session_state:
st.session_state.video_source = None
if "detector" not in st.session_state:
st.session_state.detector = None
if "is_running" not in st.session_state:
st.session_state.is_running = False
if "frame_placeholder" not in st.session_state:
st.session_state.frame_placeholder = None
if "last_inference_time" not in st.session_state:
st.session_state.last_inference_time = 0.0
if "last_inference_timestamp" not in st.session_state:
st.session_state.last_inference_timestamp = 0.0
if "frame_count" not in st.session_state:
st.session_state.frame_count = 0
if "last_frame" not in st.session_state:
st.session_state.last_frame = None
if "last_detections" not in st.session_state:
st.session_state.last_detections = []
def create_ui(self):
"""Create the Streamlit UI components."""
# Page header
st.title("Real-time People Detection")
st.markdown(
"This application detects people in video streams using YOLOv8."
)
# Sidebar for controls
with st.sidebar:
st.header("Settings")
# Model selection
model_name = st.selectbox(
"Select detection model",
options=[
"yolov8n.pt", # Nano model (smallest)
],
index=0,
)
# Detection threshold
detection_threshold = st.slider(
"Detection threshold",
min_value=0.1,
max_value=1.0,
value=0.5,
step=0.05,
)
# Target inference FPS
target_fps = st.slider(
"Target inference FPS",
min_value=1,
max_value=30,
value=10,
step=1,
help="Control how many frames per second are sent to the model for inference. Lower values use less resources but may appear less smooth."
)
# On the Hugging Face Space only demo videos are available (no webcam access)
source_type = "Demo Video"
# Let user select which demo video to use
demo_selection = st.selectbox(
"Select demo video",
options=list(DEMO_VIDEOS.keys()),
index=0,
)
video_path = str(DEMO_VIDEOS[demo_selection])
source = video_path
# Control buttons
col1, col2 = st.columns(2)
with col1:
start_button = st.button(
"Start" if not st.session_state.is_running else "Restart",
use_container_width=True,
)
with col2:
stop_button = st.button(
"Stop",
use_container_width=True,
disabled=not st.session_state.is_running,
)
# Main area for video display
video_column, stats_column = st.columns([3, 1])
with video_column:
st.subheader("Detection Feed")
# Create a placeholder for the video frame
frame_placeholder = st.empty()
st.session_state.frame_placeholder = frame_placeholder
with stats_column:
st.subheader("Performance Stats")
# Create placeholders for stats
fps_text = st.empty()
inference_text = st.empty()
people_count = st.empty()
inference_fps_text = st.empty()
# Handle button actions
if start_button:
self.start_detection(source, model_name, detection_threshold, target_fps)
if stop_button:
self.stop_detection()
# Return stats placeholders for updating
return fps_text, inference_text, people_count, inference_fps_text
def start_detection(self, source, model_name, threshold, target_fps):
"""
Start the detection process.
Args:
source: Video source (camera ID or file path)
model_name: YOLOv8 model to use
threshold: Detection confidence threshold
target_fps: Target frames per second for inference
"""
# Stop existing detection if running
self.stop_detection()
# Initialize video source
video_source = VideoSource(
source=source,
width=FRAME_WIDTH,
height=FRAME_HEIGHT,
)
# Initialize detector
detector = PeopleDetector(
model_name=model_name,
threshold=threshold,
)
# Start video capture
if not video_source.start():
st.error(f"Failed to open video source: {source}")
return
# Store objects in session state
st.session_state.video_source = video_source
st.session_state.detector = detector
st.session_state.is_running = True
st.session_state.target_fps = target_fps
st.session_state.last_inference_timestamp = time.time()
st.session_state.frame_count = 0
st.session_state.last_frame = None
st.session_state.last_detections = []
def stop_detection(self):
"""Stop the detection process and release resources."""
if st.session_state.video_source is not None:
st.session_state.video_source.stop()
st.session_state.video_source = None
st.session_state.detector = None
st.session_state.is_running = False
st.session_state.last_frame = None
st.session_state.last_detections = []
def update_frame(self, fps_text, inference_text, people_count, inference_fps_text):
"""
Update the video frame and stats.
Args:
fps_text: Streamlit element for FPS display
inference_text: Streamlit element for inference time display
people_count: Streamlit element for people count display
inference_fps_text: Streamlit element for inference FPS display
"""
if not st.session_state.is_running:
return
video_source = st.session_state.video_source
detector = st.session_state.detector
target_fps = st.session_state.target_fps
if video_source is None or detector is None:
return
# Read a new frame
ret, frame = video_source.read_frame()
if not ret:
# If we've reached the end of a video file, restart it
if not isinstance(video_source.source, int):
# Restart video
video_source.stop()
if video_source.start():
ret, frame = video_source.read_frame()
if not ret:
st.error("Failed to restart video")
self.stop_detection()
return
else:
st.error("Failed to restart video source")
self.stop_detection()
return
else:
st.error("Failed to read frame from camera")
self.stop_detection()
return
# Calculate current FPS
fps = video_source.get_fps()
# Determine if we should run inference on this frame
current_time = time.time()
time_since_last_inference = current_time - st.session_state.last_inference_timestamp
inference_interval = 1.0 / target_fps
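# e.g. target_fps = 10 -> run the model at most once every 0.1 s;
# frames read in between reuse the cached detections below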
# Use cached detections or run new detection
detections = []
inference_time = 0
# Run a new detection if enough time has passed
if time_since_last_inference >= inference_interval:
detections, inference_time = detector.detect(frame)
# Update cache
st.session_state.last_frame = frame.copy()
st.session_state.last_detections = detections
st.session_state.last_inference_time = inference_time
st.session_state.last_inference_timestamp = current_time
else:
# Use cached detections
detections = st.session_state.last_detections
inference_time = st.session_state.last_inference_time
# Draw detections on the frame
frame_with_detections = draw_detections(frame, detections)
# Approximate the inference FPS from the time since the last model run
# (frames that reuse cached detections can read above the target rate)
if time_since_last_inference > 0:
inference_fps = 1.0 / time_since_last_inference
else:
inference_fps = 0.0
# Add performance stats to the frame
frame_with_stats = add_performance_stats(
frame_with_detections,
fps,
inference_time,
len(detections),
inference_fps
)
# Display the frame
st.session_state.frame_placeholder.image(
frame_with_stats,
channels="BGR",
use_column_width=True
)
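# Note: newer Streamlit releases deprecate use_column_width in favor of
# use_container_width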
# Update stats
fps_text.metric("FPS", f"{fps:.1f}")
inference_text.metric("Inference Time", f"{inference_time*1000:.1f} ms")
people_count.metric("People Detected", len(detections))
inference_fps_text.metric("Inference FPS", f"{inference_fps:.1f}")
# Increment frame counter
st.session_state.frame_count += 1
def main():
"""Main entry point for the application."""
app = PeopleDetectionApp()
fps_text, inference_text, people_count, inference_fps_text = app.create_ui()
# Infinite loop for updating the video frame
while st.session_state.is_running:
app.update_frame(fps_text, inference_text, people_count, inference_fps_text)
time.sleep(0.01) # Small delay to prevent overloading the CPU
if __name__ == "__main__":
main()
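# To try this locally (assuming the file is saved as app.py and the demo videos
# exist under assets/): streamlit run app.py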