CVPR2025 / app.py
hysts's picture
hysts HF Staff
Update app.py (#1)
b3707ca verified
#!/usr/bin/env python
import gradio as gr
import polars as pl
from app_mcp import demo as demo_mcp
from search import search
from table import df_orig
# Markdown heading rendered at the top of the page.
DESCRIPTION = "# CVPR 2025"

# TODO: remove this once https://github.com/gradio-app/gradio/issues/10916 https://github.com/gradio-app/gradio/issues/11001 https://github.com/gradio-app/gradio/issues/11002 are fixed # noqa: TD002, FIX002
# User-facing caveat shown above the results table.
NOTE = "Note: Sorting by upvotes or comments may not work correctly due to a known bug in Gradio.\n"
# Table backing the UI:
#   1. keep only the columns the app uses (plus "abstract" and "paper_id", which are
#      carried along; "paper_id" is the key update_df joins search results on),
#   2. normalise the two count columns,
#   3. rename the displayed columns to their UI headers.
# TODO: Fix this once https://github.com/gradio-app/gradio/issues/10916 is fixed # noqa: FIX002, TD002
# Missing upvote/comment counts become 0 and both columns are stored as Int64, so they
# are numeric (matching the "number" datatype declared for them in COLUMN_INFO).
df_main = (
    df_orig.select(
        "title",
        "authors_str",
        "cvf_md",
        "paper_page_md",
        "upvotes",
        "num_comments",
        "project_page_md",
        "github_md",
        "Spaces",
        "Models",
        "Datasets",
        "claimed",
        "abstract",
        "paper_id",
    )
    .with_columns(pl.col("upvotes", "num_comments").fill_null(0).cast(pl.Int64))
    .rename(
        {
            "title": "Title",
            "authors_str": "Authors",
            "cvf_md": "CVF",
            "paper_page_md": "Paper page",
            "upvotes": "👍",
            "num_comments": "💬",
            "project_page_md": "Project page",
            "github_md": "GitHub",
        }
    )
)
# Display metadata for every selectable column: header -> (Gradio datatype, width).
# A width of None leaves the column width unset. The dict order is the order in which
# columns are displayed.
COLUMN_INFO = {
    "Title": ("str", "40%"),
    "Authors": ("str", "20%"),
    "Paper page": ("markdown", "135px"),
    "👍": ("number", "50px"),
    "💬": ("number", "50px"),
    **dict.fromkeys(
        ["CVF", "Project page", "GitHub", "Spaces", "Models", "Datasets", "claimed"],
        ("markdown", None),
    ),
}

# Columns shown on first load: everything in COLUMN_INFO except "Authors" and
# "claimed", kept in COLUMN_INFO order.
DEFAULT_COLUMNS = [col for col in COLUMN_INFO if col not in ("Authors", "claimed")]
def update_num_papers(df: pl.DataFrame) -> str:
    """Summarise how many papers are currently shown.

    Returns ``"<shown> / <total>"``, where ``<total>`` is the row count of ``df_main``.
    When ``df`` has a ``claimed`` column, appends ``" (<k> claimed)"``, where ``<k>`` is
    the number of rows whose ``claimed`` cell contains ``✅``.
    """
    summary = f"{len(df)} / {len(df_main)}"
    if "claimed" not in df.columns:
        return summary
    num_claimed = df.select(pl.col("claimed").str.contains("✅").sum()).item()
    return f"{summary} ({num_claimed} claimed)"
def update_df(
    search_query: str,
    candidate_pool_size: int,
    num_results: int,
    column_names: list[str],
) -> gr.Dataframe:
    """Build the results table for the current query and column selection.

    Args:
        search_query: Free-text query; when empty, every paper in ``df_main`` is listed.
        candidate_pool_size: Passed to ``search`` as its candidate pool size.
        num_results: Passed to ``search``; must not exceed ``candidate_pool_size``.
        column_names: Columns chosen by the user; "Title" is always added.

    Returns:
        A ``gr.Dataframe`` update whose columns follow ``COLUMN_INFO`` order, with
        matching datatypes and widths.

    Raises:
        gr.Error: If ``num_results`` is larger than ``candidate_pool_size``.
    """
    if candidate_pool_size < num_results:
        raise gr.Error("Number of results must be less than or equal to candidate pool size", print_exception=False)

    wanted = {"Title", *column_names}
    visible_columns = [name for name in COLUMN_INFO if name in wanted]

    table = df_main.clone()
    if search_query:
        hits = search(search_query, candidate_pool_size, num_results)
        if hits:
            # Attach paper metadata to each hit, best cross-encoder score first.
            table = (
                pl.DataFrame(hits)
                .join(table, on="paper_id", how="inner")
                .sort("ce_score", descending=True)
                .drop("ce_score")
            )
        else:
            # Keep the schema but show no rows.
            table = table.head(0)

    table = table.select(visible_columns)
    return gr.Dataframe(
        value=table,
        datatype=[COLUMN_INFO[name][0] for name in visible_columns],
        column_widths=[COLUMN_INFO[name][1] for name in visible_columns],
    )
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Search...")
    with gr.Accordion(label="Advanced Search Options", open=False) as advanced_search_options:
        with gr.Row():
            candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=600, step=1, value=200)
            num_results = gr.Slider(label="Number of Results", minimum=1, maximum=400, step=1, value=100)
    # "Title" is always displayed, so it is not offered as a toggle.
    column_names = gr.CheckboxGroup(
        label="Columns",
        choices=[col for col in COLUMN_INFO if col != "Title"],
        value=[col for col in DEFAULT_COLUMNS if col != "Title"],
    )
    num_papers = gr.Textbox(label="Number of papers", value=update_num_papers(df_orig), interactive=False)
    gr.Markdown(NOTE)

    # Seed the table with the default columns in COLUMN_INFO order. This is the same
    # selection update_df makes on load. It keeps `value`, `datatype` and
    # `column_widths` describing the same columns in the same order; passing all of
    # df_main would misalign them. `datatype` must be a list of type strings, not the
    # (type, width) tuples stored in COLUMN_INFO.
    initial_columns = [col for col in COLUMN_INFO if col in DEFAULT_COLUMNS]
    df = gr.Dataframe(
        value=df_main.select(initial_columns),
        datatype=[COLUMN_INFO[col][0] for col in initial_columns],
        type="polars",
        row_count=(0, "dynamic"),
        show_row_numbers=True,
        interactive=False,
        max_height=1000,
        elem_id="table",
        column_widths=[COLUMN_INFO[col][1] for col in initial_columns],
    )

    inputs = [
        search_query,
        candidate_pool_size,
        num_results,
        column_names,
    ]
    # Re-run the query when the search box is submitted or the column selection
    # changes, then refresh the paper count from the table that was just rendered.
    gr.on(
        triggers=[
            search_query.submit,
            column_names.input,
        ],
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )
    # Same pipeline once on page load, using the widgets' initial values.
    demo.load(
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )

    # The MCP demo is rendered inside a hidden row, so it is part of the app but not visible.
    with gr.Row(visible=False):
        demo_mcp.render()
if __name__ == "__main__":
    # Start the web app with Gradio's MCP server option enabled.
    demo.launch(mcp_server=True)