# Hugging Face Spaces: Running on Zero
import gradio as gr | |
import polars as pl | |
from search import search | |
from table import df_orig | |
# Every column kept in the table the tools serve, in output order.
# `cvf_page_url` and `row_index` are the source table's `cvf` and `paper_id`
# columns after renaming.
COLUMNS_MCP = [
    "title", "authors", "abstract",
    "cvf_page_url", "pdf_url", "supp_url",
    "arxiv_id", "paper_page", "bibtex",
    "space_ids", "model_ids", "dataset_ids",
    "upvotes", "num_comments",
    "project_page", "github",
    "row_index",
]
# Columns pre-ticked in the "Columns" checkbox group of the UI.
DEFAULT_COLUMNS_MCP = [
    "title", "authors", "abstract",
    "cvf_page_url", "pdf_url", "arxiv_id",
    "project_page", "github",
    "row_index",
]
# The table the tools serve. It is the source table with `cvf` exposed as
# `cvf_page_url` and `paper_id` exposed as `row_index`, restricted to COLUMNS_MCP.
df_mcp = (
    df_orig
    .rename({"cvf": "cvf_page_url", "paper_id": "row_index"})
    .select(COLUMNS_MCP)
)
def search_papers(
    search_query: str,
    candidate_pool_size: int,
    num_results: int,
    columns: list[str],
) -> list[dict]:
    """Searches CVPR 2025 papers relevant to a user query in English.

    This function performs a semantic search over CVPR 2025 papers.
    It uses a dual-stage retrieval process:
    - First, it retrieves `candidate_pool_size` papers using dense vector similarity.
    - Then, it re-ranks them with a cross-encoder model to select the top `num_results` most relevant papers.
    - The search results are returned as a list of dictionaries.

    Note:
        The search query must be written in English. Queries in other languages are not supported.
        A ValueError is raised if the query is empty or whitespace-only, if either size is
        smaller than 1, or if `num_results` is larger than `candidate_pool_size`.

    Args:
        search_query (str): The natural language query input by the user. Must be in English.
        candidate_pool_size (int): Number of candidate papers to retrieve using the dense vector model.
        num_results (int): Final number of top-ranked papers to return after re-ranking.
        columns (list[str]): The columns to select from the DataFrame.

    Returns:
        list[dict]: A list of dictionaries of the top-ranked papers matching the query, sorted by relevance.
    """
    if not search_query or not search_query.strip():
        raise ValueError("Search query cannot be empty")

    # gr.Slider hands its value to the handler as a float, while MCP clients
    # send ints. Normalise both to whole counts before validating and searching.
    candidate_pool_size = int(candidate_pool_size)
    num_results = int(num_results)
    if candidate_pool_size < 1 or num_results < 1:
        raise ValueError("Candidate pool size and number of results must be at least 1")
    if num_results > candidate_pool_size:
        raise ValueError("Number of results must be less than or equal to candidate pool size")

    results = pl.DataFrame(search(search_query, candidate_pool_size, num_results))
    if results.is_empty():
        # An empty result may carry no columns at all, which would make the
        # `paper_id` rename below fail. Nothing matched, so return nothing.
        return []

    # Search hits are keyed by `paper_id`; df_mcp exposes that column as `row_index`.
    df = results.rename({"paper_id": "row_index"}).join(df_mcp, on="row_index", how="inner")
    # Sort explicitly by the re-ranking score so the output order is by
    # relevance regardless of how the join orders rows.
    df = df.sort("ce_score", descending=True)
    return df.select(columns).to_dicts()
def get_metadata(row_index: int) -> dict:
    """Returns a dictionary of metadata for a CVPR 2025 paper at the given table row index.

    An IndexError is raised if no paper has the given row index.

    Args:
        row_index (int): The index of the paper in the internal paper list table.

    Returns:
        dict: A dictionary containing metadata for the corresponding paper.
    """
    # The UI slider passes a float; compare against the column as a whole number.
    row_index = int(row_index)
    matches = df_mcp.filter(pl.col("row_index") == row_index)
    if matches.is_empty():
        # Keep the exception type the bare `[0]` lookup raised, but explain why.
        raise IndexError(f"No CVPR 2025 paper found with row_index={row_index}")
    return matches.to_dicts()[0]
def get_table(columns: list[str]) -> list[dict]:
    """Returns a list of dictionaries of all CVPR 2025 papers.

    Each dictionary is one paper, restricted to the requested columns.

    Args:
        columns (list[str]): The columns to select from the DataFrame.

    Returns:
        list[dict]: A list of dictionaries of all CVPR 2025 papers.
    """
    projected = df_mcp.select(columns)
    rows = projected.to_dicts()
    return rows
# Gradio UI. Components render in creation order, and each event handler below
# is also exposed as a tool when the app is launched with mcp_server=True.
with gr.Blocks() as demo:
    search_query = gr.Textbox(label="Search", submit_btn=True)
    candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=500, step=1, value=200)
    num_results = gr.Slider(label="Number of Results", minimum=1, maximum=400, step=1, value=100)
    column_names = gr.CheckboxGroup(label="Columns", choices=COLUMNS_MCP, value=DEFAULT_COLUMNS_MCP)
    # `row_index` holds the source table's `paper_id` values, so take the slider
    # range from the data instead of assuming it spans 0..len(df_mcp) - 1.
    row_index = gr.Slider(
        label="Row Index",
        minimum=int(df_mcp["row_index"].min()),
        maximum=int(df_mcp["row_index"].max()),
        step=1,
        value=int(df_mcp["row_index"].min()),
    )
    out = gr.JSON()
    search_papers_btn = gr.Button("Search Papers")
    get_metadata_btn = gr.Button("Get Metadata")
    get_table_btn = gr.Button("Get Table")

    # The textbox shows a submit button (submit_btn=True), so pressing it or
    # Enter must run the search too. gr.on binds both triggers to a single event,
    # so only one `search_papers` endpoint is registered.
    gr.on(
        triggers=[search_papers_btn.click, search_query.submit],
        fn=search_papers,
        inputs=[search_query, candidate_pool_size, num_results, column_names],
        outputs=out,
    )
    get_metadata_btn.click(
        fn=get_metadata,
        inputs=row_index,
        outputs=out,
    )
    get_table_btn.click(
        fn=get_table,
        inputs=column_names,
        outputs=out,
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True)