import spaces  # Import spaces immediately for HF ZeroGPU support.
import os
import cv2
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import csv
from transformers import AutoImageProcessor, AutoModelForVideoClassification

# Specify the model checkpoint for TimeSformer.
MODEL_NAME = "facebook/timesformer-base-finetuned-k400"

def load_kinetics_labels(csv_path="kinetics-400-class-names.csv"):
    """
    Loads the Kinetics-400 labels from a CSV file.
    Expected CSV format:
        id,name
        0,abseiling
        1,air drumming
        ...
        399,zumba
    Returns a dictionary mapping string IDs to label names.
    """
    labels = {}
    try:
        with open(csv_path, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            # Skip header if present
            header = next(reader)
            if "id" not in header[0].lower():
                f.seek(0)
                reader = csv.reader(f)
            for row in reader:
                if len(row) >= 2:
                    labels[row[0].strip()] = row[1].strip()
    except Exception as e:
        print("Error reading CSV mapping:", e)
    return labels
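
# Illustration of the expected return value (assuming the CSV layout shown in
# the docstring): {"0": "abseiling", "1": "air drumming", ..., "399": "zumba"}.
# Keys stay strings so they match the str(class_index) lookups in classify_video.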

def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
    """
    Extract up to `num_frames` uniformly-sampled frames from the video.
    If the video has fewer frames, all frames are returned.
    """
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total_frames <= 0:
        cap.release()
        return frames
    # Use a set for O(1) membership checks; duplicate indices from very short
    # videos collapse, so at most one copy of each frame is kept.
    indices = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())
    current_frame = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if current_frame in indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, target_size)
            frames.append(Image.fromarray(frame))
        current_frame += 1
    cap.release()
    return frames
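
# Sampling sketch: for a 100-frame video with num_frames=16, np.linspace
# picks indices 0, 6, 13, 19, 26, 33, 39, 46, 52, 59, 66, 72, 79, 85, 92, 99,
# so the sampled frames cover the whole clip at roughly even spacing.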

@spaces.GPU
def classify_video(video_path):
    """
    Loads the TimeSformer model and image processor inside the GPU context,
    extracts frames from the video, runs inference, and returns:
      1. A text string of the top 5 predicted actions (with class ID and descriptive label)
         along with their probabilities.
      2. A bar chart (as a PIL Image) showing the prediction distribution.
    """
    # Load the image processor and model (AutoImageProcessor supersedes the
    # deprecated AutoFeatureExtractor for vision inputs).
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForVideoClassification.from_pretrained(MODEL_NAME)
    model.eval()
    
    # Load the complete Kinetics-400 mapping from CSV; if the file is missing
    # or unreadable, fall back to the labels bundled with the model config so
    # the lookups below still resolve to real class names.
    kinetics_id2label = load_kinetics_labels("kinetics-400-class-names.csv")
    if kinetics_id2label:
        print("Loaded complete Kinetics-400 mapping from CSV.")
    else:
        print("Warning: Could not load Kinetics-400 mapping; using the model's default labels.")
        kinetics_id2label = {str(i): label for i, label in model.config.id2label.items()}

    # Determine the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Extract frames from the video. This checkpoint was trained on 8-frame
    # clips (config.num_frames == 8); the HF TimeSformer implementation
    # resizes its temporal embeddings when a different count, such as the
    # 16 used here, is supplied.
    frames = extract_frames(video_path, num_frames=16, target_size=(224, 224))
    if len(frames) == 0:
        return "No frames extracted from video.", None
    
    # Preprocess the frames.
    inputs = processor(frames, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}
    
    # Run inference.
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get logits and compute probabilities.
    logits = outputs.logits  # shape: [batch_size, num_classes] with batch_size=1.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    
    # Get the top 5 predictions.
    top_probs, top_indices = torch.topk(probs, k=5)
    top_probs = top_probs.cpu().numpy()
    top_indices = top_indices.cpu().numpy()
    
    # Prepare textual results including both ID and label.
    results = []
    x_labels = []
    for idx, prob in zip(top_indices, top_probs):
        label = kinetics_id2label.get(str(idx), f"Class {idx}")
        results.append(f"ID {idx} - {label}: {prob:.3f}")
        x_labels.append(f"ID {idx}\n{label}")
    results_text = "\n".join(results)
    
    # Create a bar chart for the distribution.
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(x_labels, top_probs, color="skyblue")
    ax.set_ylabel("Probability")
    ax.set_title("Top 5 Prediction Distribution")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    
    buf = BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)
    
    # Convert the BytesIO plot to a PIL Image.
    chart_image = Image.open(buf)
    
    return results_text, chart_image
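
# Illustrative direct call (hypothetical file name; @spaces.GPU is expected to
# be a no-op outside a ZeroGPU Space, so this should also run locally):
#   text, chart = classify_video("sample_clip.mp4")
#   print(text)                # five lines of the form "ID <idx> - <label>: <prob>"
#   chart.save("top5_distribution.png")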

def process_video(video_file):
    if video_file is None:
        return "No video provided.", None
    result_text, plot_img = classify_video(video_file)
    return result_text, plot_img

# Gradio interface definition.
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(sources=["upload"], label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="Predicted Actions"),
        gr.Image(label="Prediction Distribution")
    ],
    title="Video Human Detection Demo using TimeSformer",
    description=(
        "Upload a video clip to see the top predicted human action labels using the TimeSformer model "
        "(fine-tuned on Kinetics-400). The output displays each prediction's class ID and label, along with "
        "a bar chart distribution of the top 5 predictions. A complete Kinetics-400 mapping is loaded from a CSV file."
    )
)

if __name__ == "__main__":
    demo.launch()