import os
import io
import math
import random

import av
import cv2
import decord
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoConfig, AutoModel

decord.bridge.set_bridge("torch")

# Local checkpoint directory loaded with trust_remote_code; adjust the path to your own copy.
config = AutoConfig.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True)
model = AutoModel.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True).to(config.device)


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1,
                      max_num_frames=-1, start=None, end=None):
    """Select `num_frames` frame indices from a video of `vlen` frames."""
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))
    # Ensure start_frame is less than end_frame
    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")

    # Length of the clip in frames
    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:  # uniform sampling
        acc_samples = min(num_frames, clip_length)
        # Split the clip into `acc_samples` intervals and sample one frame from each interval.
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]
            except Exception:
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError
        if len(frame_indices) < num_frames:  # pad with the last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # e.g. "fps0.5": sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps  # gap between frames; also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError(f"Unknown sampling strategy: {sample}")
    return frame_indices


def read_frames_decord(video_path, num_frames, sample='middle', fix_start=None,
                       max_num_frames=-1, client=None, trimmed30=False, start=None, end=None):
    """Decode `num_frames` frames from `video_path` with decord."""
    num_threads = 1 if video_path.endswith('.webm') else 0  # make ssv2 happy
    video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames, start=start, end=end
    )
    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, duration
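# Illustrative example (not part of the original demo; values assume the defaults above):
# sampling 8 frames from a 100-frame clip with sample='middle' splits the clip into
# 8 equal intervals and takes the centre frame of each, i.e.
#   get_frame_indices(8, 100, sample='middle')  ->  [5, 18, 30, 43, 55, 68, 80, 93]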
def get_text_feature(model, texts):
    """Tokenize `texts` and encode them with the model's text encoder."""
    text_input = model.tokenizer(texts).to(model.device)
    text_features = model.encode_text(text_input)
    return text_features


def get_similarity(video_feature, text_feature):
    """Cosine similarity between text and video embeddings."""
    video_feature = F.normalize(video_feature, dim=-1)
    text_feature = F.normalize(text_feature, dim=-1)
    sim_matrix = text_feature @ video_feature.T
    return sim_matrix


def get_top_videos(model, text_features, video_features, video_paths, texts):
    """Rank `video_paths` for each query text and return the top-k matches."""
    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    sim_matrix = text_features @ video_features.T  # (num_texts, num_videos)
    top_k = min(5, sim_matrix.shape[1])  # cannot retrieve more videos than the gallery holds
    sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]
    retrieval_infos = {}
    for i in range(len(sim_matrix_top_k)):
        print("\n", texts[i])
        retrieval_infos[texts[i]] = []
        for j in range(top_k):
            # Note: the reported score is the raw cosine similarity, not a softmax probability.
            print("top", j + 1, ":", video_paths[sim_matrix_top_k[i][j]],
                  "~prob:", sim_matrix[i][sim_matrix_top_k[i][j]].item())
            retrieval_infos[texts[i]].append({
                "video": video_paths[sim_matrix_top_k[i][j]],
                "prob": sim_matrix[i][sim_matrix_top_k[i][j]].item(),
                "rank": j + 1,
            })
    return retrieval_infos


if __name__ == "__main__":
    video_features = []
    demo_videos = ["video1.mp4", "video2.mp4"]
    texts = ['a person talking', 'a logo', 'a building']
    for video_path in demo_videos:
        # Sample 8 frames per video, preprocess, and encode with the vision tower.
        frames, frame_indices, video_duration = read_frames_decord(video_path, 8)
        frames = model.transform(frames).unsqueeze(0).to(model.device)
        with torch.no_grad():
            video_feature = model.encode_vision(frames, test=True)
        video_features.append(video_feature)
    text_features = get_text_feature(model, texts)
    video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)
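# Expected structure of `results` (illustrative values only, assuming the two demo videos above):
# a dict keyed by query text, each entry listing the top-ranked videos with their raw similarity
# score and rank, e.g.
#   {'a person talking': [{'video': 'video1.mp4', 'prob': 0.31, 'rank': 1},
#                         {'video': 'video2.mp4', 'prob': 0.12, 'rank': 2}], ...}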