Spaces:
Runtime error
Runtime error
| import base64 | |
| import hashlib | |
| import os | |
| import subprocess | |
| from dataclasses import dataclass | |
| from typing import Final | |
| import faiss | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from streamlit.logger import get_logger | |
| from pipeline import clip_wrapper | |
| from pipeline.process_videos import DATAFRAME_PATH | |
# Number of nearest neighbours requested from the faiss index for each query
# (passed as `k` to `index.search` in SemanticSearcher.search).
NUM_FRAMES_TO_RETURN: Final[int] = 21
# Module-level logger obtained from Streamlit's logging helper.
logger = get_logger(__name__)
class SemanticSearcher:
    """Text-to-frame search over a table of precomputed frame embeddings.

    The dataset's columns whose names start with ``dim_`` are treated as the
    embedding vector of each row; every other column is kept as metadata and
    used to build the search results.
    """

    def __init__(self, dataset: pd.DataFrame):
        """Build an inner-product faiss index from the ``dim_*`` columns.

        Args:
            dataset: One row per frame. Must contain the ``dim_*`` embedding
                columns plus the ``video_id``, ``frame_idx``, ``timestamp``
                and ``base64_image`` columns read by :meth:`search`.
        """
        dim_columns = dataset.filter(regex="^dim_").columns
        # NOTE(review): texts2vec presumably returns a torch tensor (search()
        # calls .detach().numpy() on it) -- confirm against clip_wrapper.
        self.embedder = clip_wrapper.ClipWrapper().texts2vec
        self.metadata = dataset.drop(columns=dim_columns)
        self.index = faiss.IndexFlatIP(len(dim_columns))
        # faiss expects a C-contiguous float32 matrix.
        self.index.add(
            np.ascontiguousarray(dataset[dim_columns].to_numpy(np.float32))
        )

    def search(self, query: str) -> list["SearchResult"]:
        """Return up to ``NUM_FRAMES_TO_RETURN`` frames matching *query*.

        Args:
            query: Free-text description of what to look for.

        Returns:
            Search results in the order faiss returned them, each carrying
            the inner-product score reported by the index. Fewer results are
            returned when the index holds fewer vectors than requested.
        """
        query_vec = np.ascontiguousarray(
            self.embedder([query]).detach().numpy(), dtype=np.float32
        )
        scores, indices = self.index.search(query_vec, NUM_FRAMES_TO_RETURN)
        # When the index holds fewer than k vectors faiss pads the labels
        # with -1; without this mask `iloc[-1]` would silently return the
        # last metadata row as a bogus hit.
        found = indices[0] >= 0
        hits = self.metadata.iloc[indices[0][found]]
        return [
            SearchResult(
                video_id=row["video_id"],
                frame_idx=row["frame_idx"],
                timestamp=row["timestamp"],
                base64_image=row["base64_image"],
                score=score,
            )
            for score, (_, row) in zip(scores[0][found], hits.iterrows())
        ]
# Streamlit re-executes the script on every user interaction; caching keeps a
# single searcher (dataset + CLIP wrapper + faiss index) alive across reruns
# instead of rebuilding it for every query.
@st.cache_resource
def get_semantic_searcher():
    """Load the frame dataset from ``DATAFRAME_PATH`` and wrap it in a searcher.

    Returns:
        A ``SemanticSearcher`` built from the parquet file. The instance is
        shared between reruns, so callers should treat it as read-only.
    """
    return SemanticSearcher(pd.read_parquet(DATAFRAME_PATH))
def get_git_hash() -> str:
    """Return the hash of the checked-out commit, or ``"unknown"``.

    Runs ``git rev-parse HEAD`` in the current working directory.

    Returns:
        The full commit hash as printed by git, stripped of whitespace. If
        git is not installed, the directory is not a repository, or the
        command otherwise fails, the string ``"unknown"`` is returned so a
        purely informational build label cannot crash the caller.
    """
    try:
        raw = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            # Keep git's "fatal: not a git repository" noise out of the logs.
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.CalledProcessError):
        return "unknown"
    return raw.decode().strip()
# Without @dataclass the annotations below generate no __init__, and the
# keyword construction used by SemanticSearcher.search raises TypeError.
@dataclass
class SearchResult:
    """A single frame returned for a search query."""

    # Identifier interpolated into the YouTube watch URL by get_video_url.
    video_id: str
    # Index of the frame; shown in the thumbnail's alt text.
    frame_idx: int
    # Position of the frame in the video; used as the URL's `t=` value.
    timestamp: float
    # Base64 payload placed in a `data:image/jpeg;base64,` URI.
    # NOTE(review): display_search_results calls .decode() on this value, so
    # the stored column may actually hold bytes rather than str -- confirm.
    base64_image: str
    # Score reported by the faiss index for this frame.
    score: float
def get_video_url(video_id: str, timestamp: float) -> str:
    """Build a YouTube watch link that starts one second before *timestamp*.

    Args:
        video_id: Value placed in the URL's ``v`` parameter.
        timestamp: Position in the video; the link starts one unit earlier,
            clamped so it never goes below zero.

    Returns:
        The watch URL with a whole-number ``t`` parameter.
    """
    rewound = timestamp - 1
    # Clamp at the start of the video, then truncate to whole seconds.
    start = int(rewound) if rewound > 0 else 0
    return f"https://www.youtube.com/watch?v={video_id}&t={start}"
def display_search_results(results: list[SearchResult]) -> None:
    """Render the results as a grid of clickable frame thumbnails.

    Results are laid out three per row. Each cell shows the frame image,
    wrapped in a link to the video URL produced by ``get_video_url``.

    Args:
        results: Frames to show, in display order. Nothing is rendered when
            the list is empty.
    """
    if not results:
        return

    col_count = 3  # Number of thumbnails per row

    # The stylesheet does not depend on the result, so emit it once instead
    # of once per thumbnail.
    # NOTE(review): nothing rendered below uses the `video-container` class;
    # the rules are kept in case other markup relies on them -- confirm.
    st.markdown(
        """
        <style>
        .video-container {
        position: relative;
        padding-bottom: 56.25%;
        padding-top: 30px;
        height: 0;
        overflow: hidden;
        }
        .video-container iframe,
        .video-container object,
        .video-container embed {
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    for position, result in enumerate(results):
        col_num = position % col_count
        if col_num == 0:
            row = st.columns(col_count)  # Start a new row of columns

        # The payload is annotated as str but may arrive as raw bytes;
        # accept both rather than failing on `.decode()`.
        image = result.base64_image
        if isinstance(image, (bytes, bytearray)):
            image = image.decode()

        with row[col_num]:
            st.markdown(
                f"""
                <a href="{get_video_url(result.video_id, result.timestamp)}">
                <img src="data:image/jpeg;base64,{image}" alt="frame {result.frame_idx} timestamp {int(result.timestamp)}" width="100%">
                </a>
                """,
                unsafe_allow_html=True,
            )
def main():
    """Render the search page.

    Sets up the page chrome, shows a query box, and -- when a query is
    present -- logs an MD5 digest of it (not the text itself) and displays
    the matching frames. A short build identifier is shown at the bottom.
    """
    st.set_page_config(page_title="video-semantic-search", layout="wide")
    st.header("Visual content search over music videos")
    st.markdown("_App by Ben Tenmann and Sidney Radcliffe_")

    searcher = get_semantic_searcher()
    num_videos = len(searcher.metadata["video_id"].unique())
    prompt = f"What are you looking for? Search over {num_videos} music videos."
    st.text_input(prompt, key="query")

    # The text input stores its value in session state under its key.
    query = st.session_state["query"]
    if query:
        query_digest = hashlib.md5(query.encode()).hexdigest()
        logger.info(f"Recieved query... {query_digest}")
        st.text("Click image to open video")
        matches = searcher.search(query)
        display_search_results(matches)

    short_hash = get_git_hash()[0:7]
    st.text(f"Build: {short_hash}")
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()