""" Real-time video classification using VJEPA2 model with streaming capabilities. This module implements a real-time video classification system that: 1. Captures video frames from a webcam 2. Processes batches of frames using the V-JEPA 2 model 3. Displays predictions overlaid on the video stream 4. Maintains a history of recent predictions The system uses FastRTC for video streaming and Gradio for the web interface. """ import os import cv2 import time import torch import random import gradio as gr import numpy as np from loguru import logger from gradio.utils import get_space from fastrtc import ( Stream, VideoStreamHandler, AdditionalOutputs, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials, ) from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor # Model configuration CHECKPOINT = "facebook/vjepa2-vitl-fpc16-256-ssv2" # Pre-trained VJEPA2 model checkpoint TORCH_DTYPE = torch.float16 # Use half precision for faster inference TORCH_DEVICE = "cuda" # Use GPU for inference UPDATE_EVERY_N_FRAMES = 64 # How often to update predictions (in frames) HF_TOKEN = os.getenv("HF_TOKEN") model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16) model = model.to(TORCH_DEVICE) video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT) frames_per_clip = model.config.frames_per_clip def add_text_on_image(image, text): """ Overlays text on an image with a black background bar at the top. 
Args: image (np.ndarray): Input image to add text to text (str): Text to overlay on the image Returns: np.ndarray: Image with text overlaid """ # Add a black background to the text image[:70] = 0 line_spacing = 10 top_margin = 20 font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.5 thickness = 1 color = (255, 255, 255) # White words = text.split() lines = [] current_line = "" img_width = image.shape[1] # Build lines that fit within the image width for word in words: test_line = current_line + (" " if current_line else "") + word (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness) if test_width > img_width - 20: # 20 px margin lines.append(current_line) current_line = word else: current_line = test_line if current_line: lines.append(current_line) # Draw each line, centered y = top_margin for line in lines: (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness) x = (img_width - line_width) // 2 cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA) y += line_height + line_spacing return image class RunningFramesCache: """ Maintains a rolling buffer of video frames for model input. This class manages a fixed-size queue of frames, keeping only the most recent frames needed for model inference. It supports subsampling frames to reduce memory usage and processing requirements. 
Args: save_every_k_frame (int): Only save every k-th frame (for subsampling) max_frames (int): Maximum number of frames to keep in cache """ def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16): self.save_every_k_frame = save_every_k_frame self.max_frames = max_frames self._frames = [] self.counter = 0 def add_frame(self, frame: np.ndarray): self.counter += 1 self._frames.append(frame) if len(self._frames) > self.max_frames: self._frames.pop(0) def get_last_n_frames(self, n: int) -> list[np.ndarray]: return self._frames[-n:] def __len__(self) -> int: return len(self._frames) class RunningResult: """ Maintains a history of recent model predictions with timestamps. This class keeps track of the most recent predictions made by the model, including timestamps for each prediction. It provides formatted output for display in the UI. Args: max_predictions (int): Maximum number of predictions to keep in history """ def __init__(self, max_predictions: int = 4): self.predictions = [] self.max_predictions = max_predictions def add_prediction(self, prediction: str): # add time in a format of HH:MM:SS current_time_formatted = time.strftime("%H:%M:%S", time.gmtime(time.time())) self.predictions.append((current_time_formatted, prediction)) if len(self.predictions) > self.max_predictions: self.predictions.pop(0) def get_formatted_predictions(self) -> str: if not self.predictions: return "Starting..." current, *past = self.predictions[::-1] text = f">>> {current[1]}\n\n" + "\n".join( [f"[{time_formatted}] {prediction}" for time_formatted, prediction in past] ) return text def get_last_prediction(self) -> str: return self.predictions[-1][1] if self.predictions else "Starting..." 
def process_frames(image: np.ndarray, frames_state: list, result_state: list, session_cache: list):
    """
    Per-frame stream callback: cache the frame, classify periodically, and annotate it.

    The three list arguments act as per-session mutable state. Each holds at most
    one object, created lazily on the first frame of a session.

    Args:
        image (np.ndarray): Incoming video frame; it is mirrored horizontally before use.
        frames_state (list): Holds the session's ``RunningFramesCache``.
        result_state (list): Holds the session's ``RunningResult``.
        session_cache (list): Holds a random integer id used only to tag log lines.

    Returns:
        tuple: The mirrored frame with the latest prediction drawn on top, and an
        ``AdditionalOutputs`` carrying the formatted prediction history.
    """
    if not session_cache:
        session_cache.append(random.randint(1, 1000))
    session_id = session_cache[0]

    # Initialize frames cache if not exists (and put in gradio state)
    if not frames_state:
        logger.info(f"({session_id}) initialized frames cache")
        frames_state.append(
            RunningFramesCache(
                # Integer stride so that frames_per_clip stored frames span ~128 incoming frames.
                save_every_k_frame=max(1, 128 // frames_per_clip),
                max_frames=frames_per_clip,
            )
        )
    running_frames_cache = frames_state[0]

    # Initialize result cache if not exists (and put in gradio state)
    if not result_state:
        logger.info(f"({session_id}) initialized result cache")
        result_state.append(RunningResult(4))
    running_result = result_state[0]

    # Mirror horizontally; .copy() gives an owned, contiguous array for OpenCV.
    image = np.flip(image, axis=1).copy()
    running_frames_cache.add_frame(image)

    # Run the model every UPDATE_EVERY_N_FRAMES frames once a full clip is buffered.
    if (
        running_frames_cache.counter % UPDATE_EVERY_N_FRAMES == 0
        and len(running_frames_cache) >= frames_per_clip
    ):
        frames = np.array(running_frames_cache.get_last_n_frames(frames_per_clip))
        inputs = video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
        # Match the weights' dtype exactly, whatever dtype the model was loaded in.
        inputs = inputs.to(dtype=model.dtype)

        with torch.inference_mode():
            logits = model(**inputs).logits

        top_index = logits.argmax(dim=-1).item()
        class_name = model.config.id2label[top_index]
        logger.info(f"({session_id}) action: '{class_name}'")
        running_result.add_prediction(class_name)

    formatted_predictions = running_result.get_formatted_predictions()
    last_prediction = running_result.get_last_prediction()
    # Draw on a copy. The array stored in the frames cache must stay clean;
    # otherwise the black bar and text would be fed to the model in later clips.
    annotated = add_text_on_image(image.copy(), last_prediction)
    return annotated, AdditionalOutputs(formatted_predictions)


async def get_credentials():
    """Fetch Cloudflare TURN credentials (authenticated with ``HF_TOKEN``) for WebRTC."""
    return await get_cloudflare_turn_credentials_async(hf_token=HF_TOKEN)
# Per-session Gradio state. Each starts as an empty list that process_frames
# fills lazily (frames cache, result history, session id respectively).
frames_cache = gr.State([])
result_cache = gr.State([])
session_id = gr.State([])

# TURN credentials and the concurrency cap are only applied when running on a
# Hugging Face Space; otherwise these options are left as None.
_on_space = get_space()


def _show_latest_output(_, output):
    """Ignore the previously displayed value and show the incoming one as-is."""
    return output


# Initialize the video stream with processing callback
stream = Stream(
    handler=VideoStreamHandler(process_frames, skip_frames=True),
    modality="video",
    mode="send-receive",
    additional_inputs=[frames_cache, result_cache, session_id],
    additional_outputs=[gr.TextArea(label="Actions", value="", lines=5)],
    additional_outputs_handler=_show_latest_output,
    rtc_configuration=get_credentials if _on_space else None,
    server_rtc_configuration=(
        get_cloudflare_turn_credentials(ttl=360_000) if _on_space else None
    ),
    concurrency_limit=3 if _on_space else None,
)

if __name__ == "__main__":
    stream.ui.launch()