# On Hugging Face ZeroGPU Spaces, `spaces` should be imported before torch.
import spaces

import csv
from io import BytesIO

import cv2
import gradio as gr
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; this app runs headless
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForVideoClassification

MODEL_NAME = "facebook/timesformer-base-finetuned-k400"

def load_kinetics_labels(csv_path="kinetics-400-class-names.csv"):
    """
    Load the Kinetics-400 labels from a CSV file.

    Expected CSV format:
        id,name
        0,abseiling
        1,air drumming
        ...
        399,zumba

    Returns a dictionary mapping string IDs to label names.
    """
    labels = {}
    try:
        with open(csv_path, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            # Skip the header row if there is one; otherwise rewind and
            # treat the first row as data.
            header = next(reader)
            if "id" not in header[0].lower():
                f.seek(0)
                reader = csv.reader(f)
            for row in reader:
                if len(row) >= 2:
                    labels[row[0].strip()] = row[1].strip()
    except Exception as e:
        print("Error reading CSV mapping:", e)
    return labels
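
# A quick sanity check for the mapping above, assuming the CSV described in
# the docstring sits next to this script:
#   labels = load_kinetics_labels()
#   labels["0"]    -> "abseiling"
#   labels["399"]  -> "zumba"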

def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
    """
    Extract up to `num_frames` uniformly sampled frames from the video.
    If the video has fewer frames, all of its frames are returned.
    """
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total_frames <= 0:
        cap.release()
        return frames
    # Uniformly spaced frame indices; a set makes the membership test O(1)
    # and collapses duplicate indices for very short videos.
    indices = set(np.linspace(0, total_frames - 1, num_frames, dtype=int))
    current_frame = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if current_frame in indices:
            # OpenCV decodes to BGR; the model expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, target_size)
            frames.append(Image.fromarray(frame))
        current_frame += 1
    cap.release()
    return frames
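
# For example, a 100-frame clip sampled with num_frames=16 yields the indices
# 0, 6, 13, 19, 26, 33, 39, 46, 52, 59, 66, 72, 79, 85, 92, 99
# (np.linspace truncates toward zero when dtype=int).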

@spaces.GPU
def classify_video(video_path):
    """
    Load the TimeSformer model and feature extractor inside the GPU context,
    extract frames from the video, run inference, and return:
      1. A text string of the top-5 predicted actions (class ID and
         descriptive label) along with their probabilities.
      2. A bar chart (as a PIL Image) showing the prediction distribution.
    """
    # Instantiate the model inside the function so it is created within the
    # GPU context that @spaces.GPU provides.
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
    model = AutoModelForVideoClassification.from_pretrained(MODEL_NAME)
    model.eval()

    kinetics_id2label = load_kinetics_labels("kinetics-400-class-names.csv")
    if kinetics_id2label:
        print("Loaded complete Kinetics-400 mapping from CSV.")
    else:
        print("Warning: could not load Kinetics-400 mapping; using the model's default labels.")
        # Fall back to the labels shipped with the model config, re-keyed as
        # strings to match the CSV-based mapping.
        kinetics_id2label = {str(k): v for k, v in model.config.id2label.items()}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    frames = extract_frames(video_path, num_frames=16, target_size=(224, 224))
    if len(frames) == 0:
        return "No frames extracted from video.", None

    # The feature extractor normalizes the frames and stacks them into a
    # single video tensor (a batch of one clip).
    inputs = feature_extractor(frames, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
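
    # For this checkpoint, outputs.logits has shape (1, 400): one raw score
    # per Kinetics-400 class for the single input clip.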
    logits = outputs.logits
    # Softmax turns the raw scores into a probability distribution; [0]
    # selects the only clip in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    top_probs, top_indices = torch.topk(probs, k=5)
    top_probs = top_probs.cpu().numpy()
    top_indices = top_indices.cpu().numpy()

    results = []
    x_labels = []
    for idx, prob in zip(top_indices, top_probs):
        label = kinetics_id2label.get(str(idx), f"Class {idx}")
        results.append(f"ID {idx} - {label}: {prob:.3f}")
        x_labels.append(f"ID {idx}\n{label}")
    results_text = "\n".join(results)
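    # results_text ends up looking like (illustrative values only):
    #   ID 0 - abseiling: 0.412
    #   ID 1 - air drumming: 0.201
    #   ...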

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(x_labels, top_probs, color="skyblue")
    ax.set_ylabel("Probability")
    ax.set_title("Top 5 Prediction Distribution")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()

    # Render the chart to an in-memory PNG instead of writing to disk.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)

    chart_image = Image.open(buf)

    return results_text, chart_image

def process_video(video_file):
    """Gradio wrapper: validate the upload, then delegate to classify_video."""
    if video_file is None:
        return "No video provided.", None
    result_text, plot_img = classify_video(video_file)
    return result_text, plot_img

demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(sources=["upload"], label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="Predicted Actions"),
        gr.Image(label="Prediction Distribution"),
    ],
    title="Video Action Recognition Demo using TimeSformer",
    description=(
        "Upload a video clip to see the top predicted human action labels from the TimeSformer "
        "model (fine-tuned on Kinetics-400). The output shows each prediction's class ID and "
        "descriptive label, along with a bar chart of the top-5 probability distribution. The "
        "complete Kinetics-400 mapping is loaded from a CSV file."
    ),
)

if __name__ == "__main__":
    demo.launch()