antonioloison's picture
Fix filtering crash (#9)
4de1a2b verified
raw
history blame
3.47 kB
from data.deprecated_model_handler import DeprecatedModelHandler
def make_clickable_model(model_name, link=None):
    """Render an encoded model name as an HTML link.

    The encoded name is decoded as follows: ``"__"`` and then ``"_"`` become
    ``"/"``, ``"-thisisapoint-"`` becomes ``"."``, and any ``"/captioning"``
    or ``"/ocr"`` substring is removed.

    Args:
        model_name: Encoded model name.
        link: Optional link target. When ``None``, the link points to
            ``"https://huggingface.co/" + <decoded name>``.

    Returns:
        An ``<a>`` tag string whose text is the decoded model name.
    """
    # The name is decoded unconditionally: the display text needs it even when
    # the caller supplies its own link (previously an UnboundLocalError).
    # "__" is handled before "_" so a double underscore yields a single "/".
    desanitized_model_name = (
        model_name.replace("__", "/")
        .replace("_", "/")
        .replace("-thisisapoint-", ".")
        .replace("/captioning", "")
        .replace("/ocr", "")
    )
    if link is None:
        link = "https://huggingface.co/" + desanitized_model_name
    return f'<a target="_blank" style="text-decoration: underline" href="{link}">{desanitized_model_name}</a>'
def add_rank(df, benchmark_version=1, selected_columns=None):
    """Sort models by score, prepend a ``Rank`` column and scale scores by 100.

    Missing values are filled with ``0.0`` before ranking. With a single
    ranking column, rows are sorted by it directly. Otherwise an ``Average``
    column (row mean of the ranking columns) is inserted and used for sorting.
    Float64 score columns are then multiplied by 100 and rounded to one
    decimal place; metadata columns are left unscaled.

    NOTE: ``df`` is modified in place (fill, insert, sort, column rewrites)
    and also returned.

    Args:
        df: DataFrame with a ``Model`` column and per-task score columns.
        benchmark_version: Currently unused; kept for interface compatibility.
        selected_columns: Iterable of column names to rank on. ``None`` ranks
            on every column that is not a known metadata column.

    Returns:
        The same DataFrame, sorted, with ``Rank`` as its first column.
    """
    # Columns that describe a model rather than score it: excluded both from
    # ranking and from the x100 percentage scaling below.
    metadata_columns = (
        "Model",
        "Model Size (Million Parameters)",
        "Memory Usage (GB, fp32)",
        "Embedding Dimensions",
        "Max Tokens",
    )
    df.fillna(0.0, inplace=True)
    if selected_columns is None:
        cols_to_rank = [col for col in df.columns if col not in metadata_columns]
    else:
        # list() so tuples/Index objects select multiple columns, not one key.
        cols_to_rank = list(selected_columns)
    if len(cols_to_rank) == 1:
        df.sort_values(cols_to_rank[0], ascending=False, inplace=True)
    else:
        df.insert(len(df.columns) - len(cols_to_rank), "Average", df[cols_to_rank].mean(axis=1, skipna=False))
        df.sort_values("Average", ascending=False, inplace=True)
    df.insert(0, "Rank", list(range(1, len(df) + 1)))
    # Multiply float scores by 100 and round to 1 decimal place. Metadata
    # columns are skipped: integer metadata with missing values becomes
    # float64 after fillna and must not be rescaled.
    for col in df.columns:
        if df[col].dtype == "float64" and col not in metadata_columns:
            df[col] = df[col].apply(lambda x: round(x * 100, 1))
    return df
def add_rank_and_format(df, benchmark_version=1, selected_columns=None):
    """Turn a model-indexed score table into a ranked, display-ready table.

    Args:
        df: DataFrame whose index holds the encoded model names.
        benchmark_version: Forwarded to ``add_rank``.
        selected_columns: Forwarded to ``add_rank``; ``None`` ranks on every
            non-metadata column.

    Returns:
        A new DataFrame ranked by ``add_rank`` whose ``Model`` column holds
        HTML links produced by ``make_clickable_model``.
    """
    # Name the index explicitly so the resulting column is "Model" even when
    # the incoming index already carries some other name (reset_index would
    # otherwise use that name and the "index" -> "Model" rename would miss).
    df = df.rename_axis("Model").reset_index()
    df = add_rank(df, benchmark_version, selected_columns)
    df["Model"] = df["Model"].apply(make_clickable_model)
    # Deduplication via remove_duplicates is currently disabled here.
    return df
def remove_duplicates(df):
    """Keep only the best-ranked row for each model name.

    Names are compared after replacing every ``"_"`` with ``"/"``, so
    ``org_model`` and ``org/model`` count as the same model. Among duplicates,
    the row with the lowest ``Rank`` is kept.

    Args:
        df: DataFrame with ``Model`` and ``Rank`` columns. It is not modified.

    Returns:
        A new DataFrame sorted by ``Rank`` with duplicate models removed.
    """
    ranked = df.sort_values("Rank")
    # Compute the comparison key as a standalone Series instead of a temporary
    # column, so no helper column leaks into the caller's DataFrame.
    dedup_key = ranked["Model"].str.replace("_", "/", regex=False)
    return ranked[~dedup_key.duplicated(keep="first")]
def get_refresh_function(model_handler, benchmark_version):
    """Create a callback that reloads and re-ranks leaderboard data for a metric.

    Args:
        model_handler: Object providing ``get_vidore_data(metric)`` and
            ``render_df(metric, benchmark_version)``.
        benchmark_version: Passed to both ``render_df`` and
            ``add_rank_and_format``.

    Returns:
        A one-argument function mapping ``metric`` to the result of
        ``add_rank_and_format`` on the rendered table.
    """

    def _refresh(metric):
        # get_vidore_data runs before render_df; presumably it loads the data
        # render_df reads — confirm against the handler implementation.
        model_handler.get_vidore_data(metric)
        rendered = model_handler.render_df(metric, benchmark_version)
        return add_rank_and_format(rendered, benchmark_version)

    return _refresh
def deprecated_get_refresh_function(model_handler, benchmark_version):
    """Create the refresh callback used by the deprecated leaderboard.

    The refresh logic is identical to ``get_refresh_function``: fetch data for
    the metric via ``model_handler.get_vidore_data``, render it with
    ``model_handler.render_df`` and pass it through ``add_rank_and_format``.
    It therefore delegates instead of keeping a second copy of the closure.

    Args:
        model_handler: Object providing ``get_vidore_data(metric)`` and
            ``render_df(metric, benchmark_version)``.
        benchmark_version: Passed through to the refresh pipeline.

    Returns:
        A one-argument function mapping ``metric`` to the formatted table.
    """
    return get_refresh_function(model_handler, benchmark_version)
def filter_models(data, search_term):
    """Return the rows whose ``Model`` column contains ``search_term``.

    Matching is a case-insensitive, literal substring test. The search text is
    user input, so it is deliberately not treated as a regular expression:
    terms like ``"c++"`` or ``"("`` would otherwise raise ``re.error``, and
    ``"."`` would match any character. Surrounding whitespace is ignored, and
    missing ``Model`` values never match.

    Args:
        data: DataFrame with a ``Model`` column of strings.
        search_term: Text to search for; falsy or whitespace-only means no
            filtering.

    Returns:
        ``data`` itself when there is nothing to filter, else the matching rows.
    """
    if not search_term:
        return data
    needle = search_term.strip()
    if not needle:
        return data
    mask = data["Model"].str.contains(needle, case=False, regex=False, na=False)
    return data[mask]