Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python | |
import gradio as gr | |
import polars as pl | |
from app_mcp import demo as demo_mcp | |
from search import search | |
from table import df_orig | |
# Markdown heading rendered at the top of the page.
DESCRIPTION = "# CVPR 2025"

# TODO: remove this once https://github.com/gradio-app/gradio/issues/10916 https://github.com/gradio-app/gradio/issues/11001 https://github.com/gradio-app/gradio/issues/11002 are fixed  # noqa: TD002, FIX002
# User-facing caveat about a known Gradio sorting bug; the trailing newline is part of the text.
NOTE = "Note: Sorting by upvotes or comments may not work correctly due to a known bug in Gradio.\n"
# Table backing the UI: the subset of `df_orig` columns the app uses, with the
# two count columns normalized and the visible columns renamed to their display
# headers. "paper_id" is retained because search results are joined on it.
# TODO: Fix this once https://github.com/gradio-app/gradio/issues/10916 is fixed  # noqa: FIX002, TD002
# Missing upvote/comment counts become 0 and both columns are cast to Int64.
df_main = (
    df_orig.select(
        "title",
        "authors_str",
        "cvf_md",
        "paper_page_md",
        "upvotes",
        "num_comments",
        "project_page_md",
        "github_md",
        "Spaces",
        "Models",
        "Datasets",
        "claimed",
        "abstract",
        "paper_id",
    )
    .with_columns(pl.col("upvotes", "num_comments").fill_null(0).cast(pl.Int64))
    .rename(
        {
            "title": "Title",
            "authors_str": "Authors",
            "cvf_md": "CVF",
            "paper_page_md": "Paper page",
            "upvotes": "👍",
            "num_comments": "💬",
            "project_page_md": "Project page",
            "github_md": "GitHub",
        }
    )
)
# Columns rendered as markdown with no fixed width, in display order.
_MARKDOWN_AUTO_WIDTH_COLUMNS = ("CVF", "Project page", "GitHub", "Spaces", "Models", "Datasets", "claimed")

# Every selectable table column, in display order:
# column header -> (Gradio datatype, column width or None for no explicit width).
COLUMN_INFO = {
    "Title": ("str", "40%"),
    "Authors": ("str", "20%"),
    "Paper page": ("markdown", "135px"),
    "👍": ("number", "50px"),
    "💬": ("number", "50px"),
    **dict.fromkeys(_MARKDOWN_AUTO_WIDTH_COLUMNS, ("markdown", None)),
}

# Columns shown before the user changes the selection: everything except the
# author list and the "claimed" marker, in COLUMN_INFO order.
DEFAULT_COLUMNS = [name for name in COLUMN_INFO if name not in {"Authors", "claimed"}]
def update_num_papers(df: pl.DataFrame) -> str:
    """Summarize how many papers the table currently shows.

    Args:
        df: The table currently displayed.

    Returns:
        ``"<shown> / <total>"``, where ``<total>`` is the row count of ``df_main``.
        When ``df`` has a "claimed" column, ``" (<n> claimed)"`` is appended, with
        ``<n>`` the number of rows whose "claimed" cell contains "✅".
    """
    shown_vs_total = f"{len(df)} / {len(df_main)}"
    if "claimed" not in df.columns:
        return shown_vs_total
    claimed_count = df.select(pl.col("claimed").str.contains("✅").sum()).item()
    return f"{shown_vs_total} ({claimed_count} claimed)"
def update_df(
    search_query: str,
    candidate_pool_size: int,
    num_results: int,
    column_names: list[str],
) -> gr.Dataframe:
    """Build the table for a search query and a set of selected columns.

    Args:
        search_query: Free-text query. When empty, ``None``, or whitespace-only,
            every paper in ``df_main`` is listed in its existing order.
        candidate_pool_size: Candidate count forwarded to ``search``.
        num_results: Result count forwarded to ``search``; must not exceed
            ``candidate_pool_size``.
        column_names: Columns to show in addition to "Title", which is always shown.

    Returns:
        A ``gr.Dataframe`` update whose columns follow ``COLUMN_INFO`` order, with
        the matching datatypes and widths.

    Raises:
        gr.Error: If ``num_results`` is greater than ``candidate_pool_size``.
    """
    if num_results > candidate_pool_size:
        raise gr.Error("Number of results must be less than or equal to candidate pool size", print_exception=False)

    df = df_main
    # A whitespace-only query is treated like an empty one (list everything)
    # instead of running a search on blank text.
    query = (search_query or "").strip()
    if query:
        results = search(query, candidate_pool_size, num_results)
        if not results:
            df = df.head(0)
        else:
            # Keep only papers returned by the search, highest "ce_score" first.
            df = pl.DataFrame(results).join(df, on="paper_id", how="inner")
            df = df.sort("ce_score", descending=True).drop("ce_score")

    selected = {"Title", *column_names}
    sorted_column_names = [col for col in COLUMN_INFO if col in selected]
    df = df.select(sorted_column_names)
    return gr.Dataframe(
        value=df,
        datatype=[COLUMN_INFO[col][0] for col in sorted_column_names],
        column_widths=[COLUMN_INFO[col][1] for col in sorted_column_names],
    )
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Search...")
    with gr.Accordion(label="Advanced Search Options", open=False) as advanced_search_options:
        with gr.Row():
            candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=600, step=1, value=200)
            num_results = gr.Slider(label="Number of Results", minimum=1, maximum=400, step=1, value=100)
        # "Title" is always shown, so it is not offered as a toggle.
        column_names = gr.CheckboxGroup(
            label="Columns",
            choices=[col for col in COLUMN_INFO if col != "Title"],
            value=[col for col in DEFAULT_COLUMNS if col != "Title"],
        )
    num_papers = gr.Textbox(label="Number of papers", value=update_num_papers(df_orig), interactive=False)
    gr.Markdown(NOTE)
    # The initial value is restricted to the COLUMN_INFO columns so that it lines
    # up one-to-one with the datatype and width lists. `datatype` takes one
    # datatype string per column, not the (datatype, width) tuples stored in
    # COLUMN_INFO.
    df = gr.Dataframe(
        value=df_main.select(list(COLUMN_INFO)),
        datatype=[dtype for dtype, _ in COLUMN_INFO.values()],
        type="polars",
        row_count=(0, "dynamic"),
        show_row_numbers=True,
        interactive=False,
        max_height=1000,
        elem_id="table",
        column_widths=[width for _, width in COLUMN_INFO.values()],
    )

    inputs = [
        search_query,
        candidate_pool_size,
        num_results,
        column_names,
    ]
    # Rebuild the table on query submit or column toggle, then refresh the paper count.
    gr.on(
        triggers=[
            search_query.submit,
            column_names.input,
        ],
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )
    # Same pipeline on page load, so the table starts with the default columns.
    demo.load(
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )

    # The MCP demo is rendered inside this app but kept hidden from the UI.
    with gr.Row(visible=False):
        demo_mcp.render()
if __name__ == "__main__":
    # Launch the app with Gradio's MCP server option enabled.
    demo.launch(mcp_server=True)