""" | |
Real-time video classification using VJEPA2 model with streaming capabilities. | |
This module implements a real-time video classification system that: | |
1. Captures video frames from a webcam | |
2. Processes batches of frames using the V-JEPA 2 model | |
3. Displays predictions overlaid on the video stream | |
4. Maintains a history of recent predictions | |
The system uses FastRTC for video streaming and Gradio for the web interface. | |
""" | |
import os
import random
import time

import cv2
import gradio as gr
import numpy as np
import torch
from loguru import logger
from gradio.utils import get_space
from fastrtc import (
    Stream,
    VideoStreamHandler,
    AdditionalOutputs,
    get_cloudflare_turn_credentials_async,
    get_cloudflare_turn_credentials,
)
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
# Model configuration
CHECKPOINT = "facebook/vjepa2-vitl-fpc16-256-ssv2"  # Pre-trained V-JEPA 2 checkpoint (Something-Something-v2 labels)
TORCH_DTYPE = torch.float16  # Use half precision for faster inference
TORCH_DEVICE = "cuda"  # Use GPU for inference
UPDATE_EVERY_N_FRAMES = 64  # How often to update predictions (in incoming frames)
HF_TOKEN = os.getenv("HF_TOKEN")

# Load the model in the same dtype the inputs are cast to below
model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=TORCH_DTYPE)
model = model.to(TORCH_DEVICE)
video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT)
frames_per_clip = model.config.frames_per_clip
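
# Optional smoke test (hedged sketch; `smoke_test` is a hypothetical helper added
# for illustration, not part of the streaming app). It pushes one clip of random
# frames through the processor and model, mirroring the inference call in
# `process_frames` below, to verify the pipeline runs on the configured
# device/dtype before wiring up the webcam.
def smoke_test():
    # One clip of random uint8 frames: (frames, height, width, channels)
    dummy_clip = np.random.randint(0, 255, (frames_per_clip, 480, 640, 3), dtype=np.uint8)
    inputs = video_processor(dummy_clip, device=TORCH_DEVICE, return_tensors="pt")
    inputs = inputs.to(dtype=TORCH_DTYPE)
    with torch.no_grad():
        logits = model(**inputs).logits
    top_index = logits.argmax(dim=-1).item()
    print(f"logits: {tuple(logits.shape)}, top class: {model.config.id2label[top_index]}")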
def add_text_on_image(image, text):
    """
    Overlays text on an image with a black background bar at the top.

    Args:
        image (np.ndarray): Input image to add text to
        text (str): Text to overlay on the image

    Returns:
        np.ndarray: Image with text overlaid
    """
    # Black out the top 70 pixels to serve as a background bar for the text
    image[:70] = 0

    line_spacing = 10
    top_margin = 20
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    color = (255, 255, 255)  # White

    words = text.split()
    lines = []
    current_line = ""
    img_width = image.shape[1]

    # Build lines that fit within the image width
    for word in words:
        test_line = current_line + (" " if current_line else "") + word
        (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
        if test_width > img_width - 20:  # 20 px margin
            lines.append(current_line)
            current_line = word
        else:
            current_line = test_line
    if current_line:
        lines.append(current_line)

    # Draw each line, centered
    y = top_margin
    for line in lines:
        (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
        x = (img_width - line_width) // 2
        cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
        y += line_height + line_spacing
    return image
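
# Usage sketch (illustrative values): overlay the latest prediction on a frame.
#
#     frame = np.zeros((480, 640, 3), dtype=np.uint8)
#     frame = add_text_on_image(frame, "pushing something from left to right")
#
# Note that the function also mutates `image` in place (the top bar is blacked
# out) in addition to returning it.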
class RunningFramesCache:
    """
    Maintains a rolling buffer of video frames for model input.

    This class manages a fixed-size queue of frames, keeping only the most recent
    frames needed for model inference. It supports subsampling frames to reduce
    memory usage and processing requirements.

    Args:
        save_every_k_frame (int): Only save every k-th frame (for subsampling)
        max_frames (int): Maximum number of frames to keep in cache
    """

    def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
        # Cast to int so callers may pass a computed ratio (e.g. 128 / frames_per_clip)
        self.save_every_k_frame = max(int(save_every_k_frame), 1)
        self.max_frames = max_frames
        self._frames = []
        self.counter = 0

    def add_frame(self, frame: np.ndarray):
        self.counter += 1
        # Subsample: keep only every k-th incoming frame
        if self.counter % self.save_every_k_frame != 0:
            return
        self._frames.append(frame)
        if len(self._frames) > self.max_frames:
            self._frames.pop(0)

    def get_last_n_frames(self, n: int) -> list[np.ndarray]:
        return self._frames[-n:]

    def __len__(self) -> int:
        return len(self._frames)
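
# Usage sketch (illustrative numbers, assuming a ~30 fps webcam): with
# save_every_k_frame=8 and max_frames=16, the cache keeps every 8th incoming
# frame, so a full 16-frame clip spans roughly 128 raw frames (~4 s of video).
#
#     cache = RunningFramesCache(save_every_k_frame=8, max_frames=16)
#     for frame in incoming_frames:          # each an HxWx3 np.ndarray
#         cache.add_frame(frame)
#     clip = cache.get_last_n_frames(16)     # most recent subsampled frames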
class RunningResult:
    """
    Maintains a history of recent model predictions with timestamps.

    This class keeps track of the most recent predictions made by the model,
    including timestamps for each prediction. It provides formatted output
    for display in the UI.

    Args:
        max_predictions (int): Maximum number of predictions to keep in history
    """

    def __init__(self, max_predictions: int = 4):
        self.predictions = []
        self.max_predictions = max_predictions

    def add_prediction(self, prediction: str):
        # Timestamp in HH:MM:SS format (UTC)
        current_time_formatted = time.strftime("%H:%M:%S", time.gmtime())
        self.predictions.append((current_time_formatted, prediction))
        if len(self.predictions) > self.max_predictions:
            self.predictions.pop(0)

    def get_formatted_predictions(self) -> str:
        if not self.predictions:
            return "Starting..."
        # Newest prediction first, followed by the timestamped history
        current, *past = self.predictions[::-1]
        text = f">>> {current[1]}\n\n" + "\n".join(
            [f"[{time_formatted}] {prediction}" for time_formatted, prediction in past]
        )
        return text

    def get_last_prediction(self) -> str:
        return self.predictions[-1][1] if self.predictions else "Starting..."
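
# Example output of get_formatted_predictions() (illustrative values), as shown
# in the "Actions" text area below: the newest prediction is prefixed with ">>>"
# and older ones are listed newest-first with their timestamps.
#
#     >>> pushing something from left to right
#
#     [09:14:32] moving something up
#     [09:14:28] holding something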
def process_frames(image: np.ndarray, frames_state: list, result_state: list, session_cache: list):
    """
    Per-frame callback for the video stream: caches incoming frames, periodically
    runs the model on the most recent clip, and overlays the latest prediction.
    """
    if not session_cache:
        session_id = random.randint(1, 1000)
        session_cache.append(session_id)
    else:
        session_id = session_cache[0]

    # Initialize frames cache if it does not exist yet (and put it in gradio state)
    if not frames_state:
        logger.info(f"({session_id}) initialized frames cache")
        running_frames_cache = RunningFramesCache(
            save_every_k_frame=128 / frames_per_clip,
            max_frames=frames_per_clip,
        )
        frames_state.append(running_frames_cache)
    else:
        running_frames_cache = frames_state[0]

    # Initialize result cache if it does not exist yet (and put it in gradio state)
    if not result_state:
        logger.info(f"({session_id}) initialized result cache")
        running_result = RunningResult(4)
        result_state.append(running_result)
    else:
        running_result = result_state[0]

    # Mirror the frame horizontally and add it to the frames cache
    image = np.flip(image, axis=1).copy()
    running_frames_cache.add_frame(image)

    # Run the model every UPDATE_EVERY_N_FRAMES frames, once enough frames are cached
    if (
        running_frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0
        and len(running_frames_cache) >= frames_per_clip
    ):
        # Prepare frames for the model
        frames = running_frames_cache.get_last_n_frames(frames_per_clip)
        frames = np.array(frames)
        inputs = video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
        inputs = inputs.to(dtype=TORCH_DTYPE)

        # Run the model
        with torch.no_grad():
            logits = model(**inputs).logits

        # Get the top prediction
        top_index = logits.argmax(dim=-1).item()
        class_name = model.config.id2label[top_index]
        logger.info(f"({session_id}) action: '{class_name}'")
        running_result.add_prediction(class_name)

    # Overlay the latest prediction and return the formatted history
    formatted_predictions = running_result.get_formatted_predictions()
    last_prediction = running_result.get_last_prediction()
    image = add_text_on_image(image, last_prediction)
    return image, AdditionalOutputs(formatted_predictions)
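
# Note on the return value: the returned image goes back to the client as the
# processed video frame, while AdditionalOutputs(...) is routed by FastRTC,
# through `additional_outputs_handler`, to the components listed in
# `additional_outputs` below (here, the "Actions" text area).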
async def get_credentials():
    return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)


# Per-session state containers; the single-element-list pattern lets
# process_frames mutate them in place across calls
frames_cache = gr.State([])
result_cache = gr.State([])
session_id = gr.State([])

# Initialize the video stream with the processing callback
stream = Stream(
    handler=VideoStreamHandler(process_frames, skip_frames=True),
    modality="video",
    mode="send-receive",
    additional_inputs=[frames_cache, result_cache, session_id],
    additional_outputs=[gr.TextArea(label="Actions", value="", lines=5)],
    additional_outputs_handler=lambda _, output: output,
    # Use Cloudflare TURN servers for WebRTC NAT traversal when running on Spaces
    rtc_configuration=get_credentials if get_space() else None,
    server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000) if get_space() else None,
    concurrency_limit=3 if get_space() else None,
)
if __name__ == "__main__":
    stream.ui.launch()