"""Gradio app for vector (semantic) search over any Hugging Face Hub dataset.

The values of a chosen split/column are embedded with a model2vec StaticModel
and ranked against a free-text query by cosine distance inside DuckDB.
"""

import html
from functools import lru_cache

import duckdb
import gradio as gr
import polars as pl
from datasets import load_dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from model2vec import StaticModel

# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)


def get_iframe(hub_repo_id):
    """Return HTML that embeds the Hub dataset viewer for ``hub_repo_id``.

    Raises:
        ValueError: if ``hub_repo_id`` is empty.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    # The repo id comes from user input; escape it before putting it into an
    # HTML attribute.
    iframe = f"""
    <iframe
      src="{html.escape(url, quote=True)}"
      frameborder="0"
      width="100%"
      height="560px"
    ></iframe>
    """
    return iframe


def load_dataset_from_hub(hub_repo_id: str):
    """Load ``hub_repo_id`` once up front; the returned dataset is discarded.

    Called only for its side effect. Presumably this warms the ``datasets``
    cache so the later ``load_dataset`` calls are fast.
    """
    gr.Info(message="Loading dataset...")
    load_dataset(hub_repo_id)


def get_columns(hub_repo_id: str, split: str):
    """Return a visible dropdown of the column names in ``split``.

    The first column is preselected.
    """
    ds = load_dataset(hub_repo_id)
    column_names = ds[split].column_names
    return gr.Dropdown(
        choices=column_names,
        value=column_names[0],
        label="Select a column",
        visible=True,
    )


def get_splits(hub_repo_id: str):
    """Return a visible dropdown of the dataset's splits.

    The first split is preselected.
    """
    ds = load_dataset(hub_repo_id)
    splits = list(ds.keys())
    return gr.Dropdown(
        choices=splits, value=splits[0], label="Select a split", visible=True
    )


# Each cache entry holds a full (rows x dim) embedding matrix. Repo ids are
# arbitrary user input, so keep only a few recent entries to bound memory.
@lru_cache(maxsize=8)
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
    """Embed every value of ``column`` in ``split`` and return the matrix.

    Values are cast to strings. Nulls become "" so the tokenizer never
    receives ``None``.
    """
    gr.Info("Vectorizing dataset...")
    ds = load_dataset(hub_repo_id)
    texts = ds[split].to_polars()[column].cast(str).fill_null("").to_list()
    return model.encode(texts, max_length=512)


def run_query(hub_repo_id: str, query: str, split: str, column: str):
    """Return the 5 rows of ``split`` whose ``column`` is closest to ``query``.

    Rows are ranked by cosine distance between the row embeddings and the
    query embedding. The internal ``embeddings`` column is not returned.

    Raises:
        gr.Error: if the query is empty, or if embedding or querying fails.
    """
    if not query:
        raise gr.Error("Please enter a query")
    try:
        embeddings = vectorize_dataset(hub_repo_id, split, column)
        ds = load_dataset(hub_repo_id)
        df = ds[split].to_polars()
        df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
        vector = model.encode(query)
        # DuckDB resolves `df` in the SQL to the local polars frame above. The
        # array width is taken from the query vector rather than hard-coded,
        # so it always matches the model's embedding size.
        df_results = duckdb.sql(
            query=f"""
            SELECT * EXCLUDE (embeddings)
            FROM df
            ORDER BY array_cosine_distance(
                embeddings, {vector.tolist()}::FLOAT[{len(vector)}]
            )
            LIMIT 5
            """
        ).to_df()
    except Exception as e:
        raise gr.Error(f"Error running query: {e}") from e
    return gr.Dataframe(df_results, visible=True)


def hide_components():
    """Hide the split/column pickers, query box, search button and results."""
    return [
        gr.Dropdown(visible=False),
        gr.Dropdown(visible=False),
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]


def partial_hide_components():
    """Hide the query box, search button and results (keep the pickers)."""
    return [
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]


def show_components():
    """Reveal the query box and the search button."""
    return [
        gr.Textbox(visible=True, label="Query"),
        gr.Button(visible=True, value="Search"),
    ]


with gr.Blocks() as demo:
    gr.HTML(
        """
This app allows you to vector search any Hugging Face dataset. You can search for the nearest neighbors of a query vector, or perform a similarity search on a dataframe.
"""
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): spelled `sumbit_on_select` to match the
                # component's keyword as used here; verify against the
                # gradio_huggingfacehub_search signature before "fixing".
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)
        btn_run = gr.Button("Search", visible=False)
    results_output = gr.Dataframe(label="Results", visible=False)

    # Picking a dataset: show its viewer, load it, reset the UI, then
    # populate the split and column pickers.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    split_dropdown.change(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    # A new column invalidates old results; re-show an empty query form.
    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=show_components, outputs=[query_input, btn_run])

    btn_run.click(
        fn=run_query,
        inputs=[search_in, query_input, split_dropdown, column_dropdown],
        outputs=results_output,
    )

demo.launch()