|
|
import cv2 |
|
|
import torch |
|
|
from config import SCORE_THRESHOLD |
|
|
from services.model_loader import load_model |
|
|
import subprocess |
|
|
import os |
|
|
import numpy as np |
|
|
|
|
|
# Pick the compute device once: the GPU when torch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the checkpoint a single time at import, place it on `device`, and put it
# in evaluation mode. Presumably `load_model` returns a torch.nn.Module (it is
# used with .to()/.eval()) — confirm in services.model_loader.
model = load_model("Model/epoch-199.pkl").to(device).eval()
|
|
|
|
|
def get_scores(features):
    """Run the module-level ``model`` on ``features`` and return its scores.

    Args:
        features: model input. Presumably a ``torch.Tensor`` of per-frame
            features (TODO confirm expected shape with the feature extractor).
            Tensors are moved to ``device`` before the forward pass.

    Returns:
        numpy.ndarray: the first of the two values returned by ``model``,
        squeezed, moved to the CPU and converted to NumPy. The result is
        always at least 1-D so callers can iterate over it even when the
        model emits a single score.
    """
    # The model was moved to `device` at import time; its input must live on
    # the same device or the forward pass fails with a device mismatch.
    if isinstance(features, torch.Tensor):
        features = features.to(device)

    with torch.no_grad():  # inference only: no autograd graph needed
        print("Features before model inference:", features.shape)
        scores, _ = model(features)
        # squeeze() turns a single-score output into a 0-d array, which
        # cannot be iterated; atleast_1d keeps it a 1-element vector.
        scores = np.atleast_1d(scores.squeeze().cpu().numpy())

    # Log the output's shape. The input's shape is unchanged by inference and
    # was already logged above.
    print("Scores after model inference:", scores.shape)
    return scores
|
|
|
|
|
def get_selected_indices(scores, picks, threshold=SCORE_THRESHOLD):
    """Return the ``picks`` entries whose paired score reaches ``threshold``.

    ``scores[i]`` is paired with ``picks[i]`` (presumably the source-frame
    index sampled at position i — confirm with the caller). An entry is kept
    when ``score >= threshold``. The output follows the order of ``scores``.
    Debug information is printed after the selection is made.

    Args:
        scores: sized iterable of numeric scores.
        picks: indexable sequence. A passing score at a position beyond its
            end raises ``IndexError``.
        threshold: inclusive cut-off. Defaults to ``SCORE_THRESHOLD`` from
            ``config``.

    Returns:
        list: the selected ``picks`` entries.
    """
    chosen = []
    for position, value in enumerate(scores):
        if value >= threshold:
            chosen.append(picks[position])

    print("Threshold for selection:", threshold)
    print("Scores:", len(scores), scores)
    print("Picks:", len(picks), picks)
    print("Selected indices:", len(chosen), chosen)

    return chosen
|
|
|
|
|
def save_summary_video(video_path, selected_indices, output_path, fps=15):
    """Copy the selected frames of ``video_path`` into a new MP4 file.

    Frames are numbered from 0 in decode order. A frame is written when its
    number is in ``selected_indices``. Frames are streamed straight to the
    writer rather than buffered in memory, and decoding stops once the
    highest selected index has been passed. If no frame is selected, or the
    source cannot be read, no output file is created and
    "No frames selected." is printed.

    Args:
        video_path: source video, opened with ``cv2.VideoCapture``.
        selected_indices: iterable of frame numbers to keep. Order and
            duplicates do not matter.
        output_path: destination file, written with the ``mp4v`` FourCC.
        fps: frame rate recorded in the output (default 15).
    """
    selected = set(selected_indices)
    # Nothing after the last selected frame is needed, so stop decoding there.
    last_needed = max(selected) if selected else -1

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        frame_id = 0
        while frame_id <= last_needed and cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_id in selected:
                if out is None:
                    # The writer needs the frame size up front, so take it
                    # from the first selected frame. The array shape is
                    # (rows, cols[, channels]).
                    h, w = frame.shape[:2]
                    print("Video dimensions:", w, h)
                    # OpenCV's frameSize argument is (width, height).
                    out = cv2.VideoWriter(
                        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)
                    )
                out.write(frame)
            frame_id += 1
    finally:
        # Release both handles even if decoding or writing raises.
        cap.release()
        if out is not None:
            out.release()

    if out is None:
        print("No frames selected.")
        return

    # NOTE(review): this only logs. The re-encode is done by
    # fix_video_with_ffmpeg, which is not called here — confirm the caller
    # runs it on output_path.
    print("Fixing the video with ffmpeg")
|
|
|
|
|
|
|
|
def fix_video_with_ffmpeg(path):
    """Re-encode the video at ``path`` in place with libx264 / yuv420p.

    ffmpeg writes to a sibling temporary file (``<path>.fixed.mp4``), which
    then replaces ``path`` via ``os.replace``. The codec and pixel format are
    presumably chosen for broad player compatibility — confirm intent.

    Args:
        path: str or path-like pointing at an existing video file.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero
            status. ``path`` is left untouched and the temporary file is
            removed.
    """
    src = os.fspath(path)  # accept pathlib.Path as well as str
    temp_path = src + ".fixed.mp4"
    cmd = [
        "ffmpeg", "-y", "-i", src,
        "-vcodec", "libx264", "-pix_fmt", "yuv420p", temp_path,
    ]
    try:
        # check=True: without it a failed encode went unnoticed, and
        # os.replace either raised a confusing error or moved a truncated
        # file over the original.
        subprocess.run(cmd, check=True)
    except BaseException:
        # Don't leave a half-written temp file behind (also on Ctrl-C).
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise
    os.replace(temp_path, src)
|
|
|