# EdgeFace / app.py
# Last commit 528e972 (verified) by bornet:
# "Refactor to use hf_hub_download instead of torch.hub.load"
# SPDX-FileCopyrightText: 2025 Idiap Research Institute
# SPDX-FileContributor: Anjith George
# SPDX-License-Identifier: BSD-3-Clause
"""EdgeFace demo"""
from __future__ import annotations
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from huggingface_hub import hf_hub_download
from utils import align_crop
from title import title_css, title_with_logo
from timmfrv2 import TimmFRWrapperV2, model_configs
# ───────────────────────────────
# Data & models
# ───────────────────────────────
# Sample images shipped with the demo, offered as clickable gallery thumbnails.
# NOTE: relative path, resolved against the current working directory.
DATA_DIR = Path("data")
# Accepted image suffixes (compared against the lower-cased file suffix).
EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
# Sorted for a deterministic gallery order. A missing ``data/`` directory
# yields an empty gallery instead of crashing the app at import time, and
# non-file entries (e.g. a directory named "x.jpg") are skipped because
# ``_as_rgb`` could not read them.
PRELOADED = (
    sorted(
        p
        for p in DATA_DIR.iterdir()
        if p.is_file() and p.suffix.lower() in EXTS
    )
    if DATA_DIR.is_dir()
    else []
)
# Model variants offered in the dropdown; each is looked up in
# ``model_configs`` by ``get_edge_model``.
EDGE_MODELS = [
    "edgeface_base",
    "edgeface_s_gamma_05",
    "edgeface_xs_gamma_06",
    "edgeface_xxs",
]
# ───────────────────────────────
# Styling (orange palette)
# ───────────────────────────────
# Orange palette, interpolated into CSS below and passed to ``title_css``.
PRIMARY = "#F97316"  # accent: links, primary/copy buttons, dropdown borders
PRIMARY_DARK = "#C2410C"  # link hover colour and button-gradient end
ACCENT_LIGHT = "#FFEAD2"  # light-mode hero badge background
BG_LIGHT = "#FFFBF7"  # light-mode page background
CARD_BG_DARK = "#473f38"  # dark-mode citation card and ``.card`` background
BG_DARK = "#332a22"  # dark-mode page / block / dropdown background
TEXT_DARK = "#0F172A"  # light-mode body text and citation code colour
TEXT_LIGHT = "#f8fafc"  # dark-mode BibTeX text colour
# Page stylesheet. This is an f-string: literal CSS braces are doubled
# ({{ }}) and {NAME} placeholders interpolate the palette above. Suffixes
# such as ``{PRIMARY}99`` append an alpha byte to the 6-digit hex colour,
# giving an #RRGGBBAA value. Rules prefixed with ``.dark`` are meant for
# Gradio's dark theme.
# NOTE(review): ``.dark body`` only matches if the ``dark`` class sits on an
# ancestor of <body>. If Gradio sets it on <body> itself, ``body.dark`` is
# the selector that would apply. Confirm against the Gradio version in use.
CSS = f"""
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');
/* ─── palette ───────────────────────────────────────────── */
body, .gradio-container {{
font-family: 'Inter', sans-serif;
background: {BG_LIGHT};
color: {TEXT_DARK};
}}
a {{
color: {PRIMARY};
text-decoration: none;
font-weight: 600;
}}
a:hover {{ color: {PRIMARY_DARK}; }}
/* ─── headline ──────────────────────────────────────────── */
#titlebar {{
text-align: center;
margin-top: 2.4rem;
margin-bottom: .9rem;
}}
/* ─── card look ─────────────────────────────────────────── */
.gr-block,
.gr-box,
.gr-row,
#cite-wrapper {{
border: 1px solid #F8C89B;
border-radius: 10px;
background: #fff;
box-shadow: 0 3px 6px rgba(0, 0, 0, .05);
}}
.gr-gallery-item {{ background: #fff; }}
/* ─── controls / inputs ─────────────────────────────────── */
.gr-button-primary,
#copy-btn {{
background: linear-gradient(90deg, {PRIMARY} 0%, {PRIMARY_DARK} 100%);
border: none;
color: #fff;
border-radius: 6px;
font-weight: 600;
transition: transform .12s ease, box-shadow .12s ease;
}}
.gr-button-primary:hover,
#copy-btn:hover {{
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(249, 115, 22, .35);
}}
.gr-dropdown input {{
border: 1px solid {PRIMARY}99;
}}
.preview img,
.preview canvas {{ object-fit: contain !important; }}
/* ─── hero section ─────────────────────────────────────── */
#hero-wrapper {{ text-align: center; }}
#hero-badge {{
display: inline-block;
padding: .85rem 1.2rem;
border-radius: 8px;
background: {ACCENT_LIGHT};
border: 1px solid {PRIMARY}55;
font-size: .95rem;
font-weight: 600;
margin-bottom: .5rem;
}}
#hero-links {{
font-size: .95rem;
font-weight: 600;
margin-bottom: 1.6rem;
}}
#hero-links img {{
height: 22px;
vertical-align: middle;
margin-left: .55rem;
}}
/* ─── score area ───────────────────────────────────────── */
#score-area {{
text-align: center;
}}
.title-container {{
display: flex;
align-items: center;
gap: 12px;
justify-content: center;
margin-bottom: 10px;
text-align: center;
}}
.match-badge {{
display: inline-block;
padding: .35rem .9rem;
border-radius: 9999px;
font-weight: 600;
font-size: 1.25rem;
}}
/* ─── citation card ────────────────────────────────────── */
#cite-wrapper {{
position: relative;
padding: .9rem 1rem;
margin-top: 2rem;
}}
#cite-wrapper code {{
font-family: SFMono-Regular, Consolas, monospace;
font-size: .84rem;
white-space: pre-wrap;
color: {TEXT_DARK};
}}
#copy-btn {{
position: absolute;
top: .55rem;
right: .6rem;
padding: .18rem .7rem;
font-size: .72rem;
line-height: 1;
}}
/* ─── dark mode ────────────────────────────────────── */
.dark body,
.dark .gradio-container {{
background-color: {BG_DARK};
color: #e5e7eb;
}}
.dark .gr-block,
.dark .gr-box,
.dark .gr-row {{
background-color: {BG_DARK};
border: 1px solid #4b5563;
}}
.dark .gr-dropdown input {{
background-color: {BG_DARK};
color: #f1f5f9;
border: 1px solid {PRIMARY}aa;
}}
.dark #hero-badge {{
background: #334155;
border: 1px solid {PRIMARY}55;
color: #fefefe;
}}
.dark #cite-wrapper {{
background-color: {CARD_BG_DARK};
}}
.dark #bibtex {{
color: {TEXT_LIGHT} !important;
}}
.dark .card {{
background-color: {CARD_BG_DARK};
}}
/* ─── switch logo for light/dark theme ─────────────── */
.logo-dark {{ display: none; }}
.dark .logo-light {{ display: none; }}
.dark .logo-dark {{ display: inline; }}
"""
# Full stylesheet handed to gr.Blocks: the page CSS plus the title/logo rules
# produced by ``title.title_css`` (defined in another module).
FULL_CSS = CSS + title_css(TEXT_DARK, PRIMARY, PRIMARY_DARK, TEXT_LIGHT)
# ───────────────────────────────
# Torch / transforms
# ───────────────────────────────
# Preprocessing for an aligned face crop before it enters the network:
# ToTensor turns the H x W x C array into a C x H x W float tensor, then
# Normalize with mean = std = 0.5 per channel computes (x - 0.5) / 0.5,
# which maps values in [0, 1] to [-1, 1].
_tx = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)]
)
def get_edge_model(name: str) -> torch.nn.Module:
    """Return the EdgeFace model called *name*, building it on first use.

    On the first request for *name*, this function:

    * downloads the checkpoint listed in ``model_configs[name]`` from the
      Hugging Face Hub into ``models/``;
    * wraps the configured timm backbone and applies the config's
      ``post_setup`` hook;
    * loads the weights (mapped to CPU first);
    * switches the model to eval mode and moves it to CUDA when available,
      otherwise CPU.

    The resulting module is memoised in ``get_edge_model.cache``, so later
    calls return the very same instance.

    Raises:
        Whatever ``model_configs[name]`` raises for an unknown *name*
        (presumably ``KeyError``).
    """
    cache = get_edge_model.cache
    if name in cache:
        return cache[name]

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cfg = model_configs[name]
    ckpt = hf_hub_download(
        repo_id=cfg["repo"],
        filename=cfg["filename"],
        local_dir="models",
    )
    net = cfg["post_setup"](TimmFRWrapperV2(cfg["timm_model"], batchnorm=False))
    # NOTE(review): torch.load unpickles the downloaded file. Since the result
    # goes straight into load_state_dict, weights_only=True should suffice
    # and avoids executing arbitrary pickled code. Confirm the torch version
    # supports it before switching.
    net.load_state_dict(torch.load(ckpt, map_location="cpu"))
    net = net.eval()
    net.to(target)
    cache[name] = net
    return net


# Per-process memo of built models, keyed by variant name.
get_edge_model.cache = {}
# ───────────────────────────────
# Helpers
# ───────────────────────────────
def _as_rgb(path: Path) -> np.ndarray:
    """Read the image file at *path* and return it in RGB channel order.

    OpenCV decodes images as BGR, so the array is converted before it is
    handed to Gradio.

    Raises:
        OSError: if OpenCV cannot read or decode the file. ``cv2.imread``
            signals this by returning ``None`` rather than raising, which
            would otherwise surface as an opaque error inside ``cvtColor``.
    """
    bgr = cv2.imread(str(path))
    if bgr is None:
        raise OSError(f"Could not read image file: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def badge(text: str, colour: str) -> str:
    """Return *text* wrapped in a ``match-badge`` div tinted with *colour*.

    *colour* is used as-is for the text colour. The same value with ``22``
    appended is used as the background; callers in this file pass 6-digit
    ``#RRGGBB`` colours, so this becomes a faint ``#RRGGBBAA`` tint.
    *text* is inserted verbatim, without HTML escaping.
    """
    tint = f"{colour}22"
    return f'<div class="match-badge" style="background:{tint};color:{colour}">{text}</div>'
# ───────────────────────────────
# Face comparison
# ───────────────────────────────
def _embed(mdl: torch.nn.Module, crop_rgb: np.ndarray, dev: torch.device) -> torch.Tensor:
    """Run one aligned RGB face crop through *mdl* and return its embedding.

    The crop is converted RGB -> BGR before ``_tx``, exactly as the original
    pipeline did.
    NOTE(review): presumably the checkpoints expect BGR input; confirm
    against the training preprocessing.
    """
    bgr = cv2.cvtColor(crop_rgb, cv2.COLOR_RGB2BGR)
    # [None] adds a batch dimension of 1; [0] strips it from the output.
    return mdl(_tx(bgr)[None].to(dev))[0]


def compare(img_left, img_right, variant):
    """Align both faces, embed them with *variant* and score their similarity.

    Args:
        img_left: image from the "Image A" input, or ``None`` when empty.
        img_right: image from the "Image B" input, or ``None`` when empty.
        variant: EdgeFace model name passed to ``get_edge_model``.

    Returns:
        ``(crop_a, crop_b, html)``: the two aligned crops (both ``None`` on
        any failure) and an HTML badge holding either an error message or
        the match percentage.
    """
    # An empty gr.Image input arrives as None. Report it instead of handing
    # None to the face aligner.
    if img_left is None or img_right is None:
        return None, None, badge("Please provide both images", "#DC2626")

    crop_a, crop_b = align_crop(img_left), align_crop(img_right)
    if crop_a is None and crop_b is None:
        return None, None, badge("No face detected", "#DC2626")
    if crop_a is None:
        return None, None, badge("No face in A", "#DC2626")
    if crop_b is None:
        return None, None, badge("No face in B", "#DC2626")

    mdl = get_edge_model(variant)
    # Send inputs to whichever device the cached model lives on.
    dev = next(mdl.parameters()).device
    with torch.no_grad():
        ea = _embed(mdl, crop_a, dev)
        eb = _embed(mdl, crop_b, dev)

    # Cosine similarity lies in [-1, 1]. Show it as a percentage, clamping
    # negative values (and any float overshoot past 100) for display.
    pct = float(F.cosine_similarity(ea[None], eb[None]).item() * 100)
    pct = max(0.0, min(100.0, pct))
    # Traffic-light colouring: green >= 80 %, amber >= 50 %, red otherwise.
    colour = "#15803D" if pct >= 80 else "#CA8A04" if pct >= 50 else "#DC2626"
    return crop_a, crop_b, badge(f"{pct:.2f}% match", colour)
# ───────────────────────────────
# Static HTML
# ───────────────────────────────
# Page heading markup, produced by ``title.title_with_logo`` (another module)
# from the headline below; the ``brand`` span wraps the model name.
TITLE_HTML = title_with_logo(
    '<span class="brand">EdgeFace:</span> '
    "Efficient Face Recognition Model for Edge Devices"
)
# Disabled hero badge, kept for reference:
# <div id="hero-badge">
# 🏆 Winner of IJCB 2023 Efficient Face Recognition Competition
# </div><br/>
# Link row shown under the title: project page, paper, arXiv, code, plus a
# visitor-counter image. The counter's ``url`` query value must be fully
# percent-encoded: every "/" is "%2F". The previous "%2idiap" was a malformed
# escape, so the counter saw a garbled page URL.
HERO_HTML = """
<div id="hero-wrapper">
<div id="hero-links">
<a href="https://www.idiap.ch/paper/edgeface/">Project</a>&nbsp;•&nbsp;
<a href="https://publications.idiap.ch/attachments/papers/2024/George_IEEETBIOM_2024.pdf">Paper</a>&nbsp;•&nbsp;
<a href="https://arxiv.org/abs/2307.01838">arXiv</a>&nbsp;•&nbsp;
<a href="https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface">Code</a>&nbsp;•&nbsp;
<img src="https://hitscounter.dev/api/hit?url=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fidiap%2FEdgeFace&label=Visitors&icon=award-fill&color=%23dc3545" alt="Visitors">
</div>
</div>
"""
# Citation card: a BibTeX entry plus a button that copies the entry's
# rendered text (``innerText``) to the clipboard. A raw string is used so the
# LaTeX escape ``\&`` reaches the page untouched. The HTML entity ``&amp;``
# renders as "&", so the copied text reads ``\&``; a bare "&" in a BibTeX
# field breaks LaTeX compilation ("Misplaced alignment tab character").
CITATION_HTML = r"""
<div id="cite-wrapper">
<button id="copy-btn" onclick="
navigator.clipboard.writeText(document.getElementById('bibtex').innerText)
.then(()=>{this.textContent='✔︎';setTimeout(()=>this.textContent='Copy',1500);});
">Copy</button>
<code id="bibtex">@article{edgeface,
title = {{EdgeFace: Efficient Face Recognition Model for Edge Devices}},
author = {{George, A. and Ecabert, C. and Otroshi, H. and Kotwal, K. and Marcel, S.}},
journal= {{IEEE Trans. Biometrics, Behavior, \&amp; Identity Science}},
year = {{2024}}
}</code>
</div>
"""
# ───────────────────────────────
# Gradio UI
# ───────────────────────────────
with gr.Blocks(css=FULL_CSS, title="EdgeFace Demo") as demo:
    gr.HTML(TITLE_HTML, elem_id="titlebar")
    gr.HTML(HERO_HTML)

    # Row 1: thumbnail galleries of the bundled samples (A left, B right).
    with gr.Row():
        gal_a, gal_b = (
            gr.Gallery(
                PRELOADED,
                columns=[5],
                height=120,
                label=lbl,
                object_fit="contain",
                elem_classes="card",
            )
            for lbl in ("Image A", "Image B")
        )

    # Row 2: the editable inputs actually passed to ``compare``. Users can
    # upload or drop their own images, or click a thumbnail above.
    with gr.Row():
        img_a, img_b = (
            gr.Image(
                type="numpy",
                height=300,
                label=lbl,
                interactive=True,
                elem_classes="preview card",
            )
            for lbl in ("Image A (click or drag-drop)", "Image B (click or drag-drop)")
        )

    def _fill(evt: gr.SelectData):
        """Return the clicked sample as an RGB array, or None if no index."""
        if evt.index is None:
            return None
        return _as_rgb(PRELOADED[evt.index])

    # A gallery click loads the chosen sample into the matching input.
    gal_a.select(_fill, outputs=img_a)
    gal_b.select(_fill, outputs=img_b)

    variant_dd = gr.Dropdown(
        EDGE_MODELS, value="edgeface_base", label="Model variant", elem_classes="card"
    )
    btn = gr.Button("Compare", variant="primary")

    # Row 3: aligned crops returned by ``compare``; the score badge follows.
    with gr.Row():
        out_a, out_b = (
            gr.Image(label=lbl, elem_classes="card")
            for lbl in ("Aligned A (112×112)", "Aligned B (112×112)")
        )
    score_html = gr.HTML(elem_id="score-area")

    btn.click(compare, [img_a, img_b, variant_dd], [out_a, out_b, score_html])
    gr.HTML(CITATION_HTML)

# ───────────────────────────────
if __name__ == "__main__":
    # share=True additionally asks Gradio for a public share link.
    demo.launch(share=True)