Commit 9a6d3df (verified) by bornet · 1 parent: 3785820

Reformatting by editor

Files changed (1): app.py (+74, -38)

app.py CHANGED
@@ -223,51 +223,66 @@ FULL_CSS = CSS + title_css(TEXT_DARK, PRIMARY, PRIMARY_DARK, TEXT_LIGHT)
 # ───────────────────────────────
 # Torch / transforms
 # ───────────────────────────────
-_tx = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]),
-])
-def get_edge_model(name:str)->torch.nn.Module:
     if name not in get_edge_model.cache:
-        mdl=torch.hub.load("otroshi/edgeface",name,source="github",pretrained=True).eval()
         mdl.to("cuda" if torch.cuda.is_available() else "cpu")
-        get_edge_model.cache[name]=mdl
     return get_edge_model.cache[name]
-get_edge_model.cache={}

 # ───────────────────────────────
 # Helpers
 # ───────────────────────────────
-def _as_rgb(path:Path)->np.ndarray:
-    return cv2.cvtColor(cv2.imread(str(path)),cv2.COLOR_BGR2RGB)

-def badge(text:str,colour:str)->str:
     return f'<div class="match-badge" style="background:{colour}22;color:{colour}">{text}</div>'

 # ───────────────────────────────
 # Face comparison
 # ───────────────────────────────
-def compare(img_left,img_right,variant):
-    crop_a,crop_b=align_crop(img_left),align_crop(img_right)
     if crop_a is None and crop_b is None:
-        return None,None,badge("No face detected","#DC2626")
     if crop_a is None:
-        return None,None,badge("No face in A","#DC2626")
     if crop_b is None:
-        return None,None,badge("No face in B","#DC2626")
-    mdl=get_edge_model(variant);dev=next(mdl.parameters()).device
     with torch.no_grad():
-        ea=mdl(_tx(cv2.cvtColor(crop_a,cv2.COLOR_RGB2BGR))[None].to(dev))[0]
-        eb=mdl(_tx(cv2.cvtColor(crop_b,cv2.COLOR_RGB2BGR))[None].to(dev))[0]
-    pct=float(F.cosine_similarity(ea[None],eb[None]).item()*100)
-    pct=max(0,min(100,pct))
-    colour="#15803D" if pct>=80 else "#CA8A04" if pct>=50 else "#DC2626"
-    return crop_a,crop_b,badge(f"{pct:.2f}% match",colour)

 # ───────────────────────────────
 # Static HTML
 # ───────────────────────────────
-TITLE_HTML = title_with_logo("""<span class="brand">EdgeFace:</span> Efficient Face Recognition Model for Edge Devices""")

 # <div id="hero-badge">
 #     🏆 Winner of IJCB 2023 Efficient Face Recognition Competition
@@ -308,12 +323,22 @@ with gr.Blocks(css=FULL_CSS, title="EdgeFace Demo") as demo:
     gr.HTML(HERO_HTML)

     with gr.Row():
-        gal_a = gr.Gallery(PRELOADED, columns=[5], height=120,
-                           label="Image A", object_fit="contain",
-                           elem_classes="card")
-        gal_b = gr.Gallery(PRELOADED, columns=[5], height=120,
-                           label="Image B", object_fit="contain",
-                           elem_classes="card")

     with gr.Row():
         # img_a = gr.Image(type="numpy", height=300, label="Image A",
@@ -321,18 +346,30 @@ with gr.Blocks(css=FULL_CSS, title="EdgeFace Demo") as demo:
         # img_b = gr.Image(type="numpy", height=300, label="Image B",
         #                  elem_classes="preview")

-        img_a = gr.Image(type="numpy", height=300, label="Image A (click or drag-drop)",
-                         interactive=True, elem_classes="preview card")
-        img_b = gr.Image(type="numpy", height=300, label="Image B (click or drag-drop)",
-                         interactive=True, elem_classes="preview card")

     def _fill(evt: gr.SelectData):
         return _as_rgb(PRELOADED[evt.index]) if evt.index is not None else None
     gal_a.select(_fill, outputs=img_a)
     gal_b.select(_fill, outputs=img_b)

-    variant_dd = gr.Dropdown(EDGE_MODELS, value="edgeface_base",
-                             label="Model variant", elem_classes="card")
     btn = gr.Button("Compare", variant="primary")

     with gr.Row():
@@ -340,8 +377,7 @@ with gr.Blocks(css=FULL_CSS, title="EdgeFace Demo") as demo:
         out_b = gr.Image(label="Aligned B (112×112)", elem_classes="card")
    score_html = gr.HTML(elem_id="score-area")

-    btn.click(compare, [img_a, img_b, variant_dd],
-              [out_a, out_b, score_html])

    gr.HTML(CITATION_HTML)
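The reformatted version of each hunk follows; the commit changes layout only, not logic. One detail worth keeping in mind when reading it is the model loader, which stores each downloaded EdgeFace variant in a function attribute so repeated comparisons reuse the same weights. A minimal standalone sketch of that caching pattern (it assumes the otroshi/edgeface torch.hub entry point is reachable and that "edgeface_base", the demo's default, is a valid variant name):

import torch

def get_edge_model(name: str) -> torch.nn.Module:
    # Load the requested EdgeFace variant once, then serve it from the cache.
    if name not in get_edge_model.cache:
        mdl = torch.hub.load(
            "otroshi/edgeface", name, source="github", pretrained=True
        ).eval()
        mdl.to("cuda" if torch.cuda.is_available() else "cpu")
        get_edge_model.cache[name] = mdl
    return get_edge_model.cache[name]

get_edge_model.cache = {}

model = get_edge_model("edgeface_base")           # first call downloads the weights
assert get_edge_model("edgeface_base") is model   # later calls hit the cache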
 
 
 # ───────────────────────────────
 # Torch / transforms
 # ───────────────────────────────
+_tx = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+    ]
+)
+
+
+def get_edge_model(name: str) -> torch.nn.Module:
     if name not in get_edge_model.cache:
+        mdl = torch.hub.load(
+            "otroshi/edgeface", name, source="github", pretrained=True
+        ).eval()
         mdl.to("cuda" if torch.cuda.is_available() else "cpu")
+        get_edge_model.cache[name] = mdl
     return get_edge_model.cache[name]
+
+
+get_edge_model.cache = {}
+

 # ───────────────────────────────
 # Helpers
 # ───────────────────────────────
+def _as_rgb(path: Path) -> np.ndarray:
+    return cv2.cvtColor(cv2.imread(str(path)), cv2.COLOR_BGR2RGB)

+
+def badge(text: str, colour: str) -> str:
     return f'<div class="match-badge" style="background:{colour}22;color:{colour}">{text}</div>'

+
 # ───────────────────────────────
 # Face comparison
 # ───────────────────────────────
+def compare(img_left, img_right, variant):
+    crop_a, crop_b = align_crop(img_left), align_crop(img_right)
     if crop_a is None and crop_b is None:
+        return None, None, badge("No face detected", "#DC2626")
     if crop_a is None:
+        return None, None, badge("No face in A", "#DC2626")
     if crop_b is None:
+        return None, None, badge("No face in B", "#DC2626")
+    mdl = get_edge_model(variant)
+    dev = next(mdl.parameters()).device
     with torch.no_grad():
+        ea = mdl(_tx(cv2.cvtColor(crop_a, cv2.COLOR_RGB2BGR))[None].to(dev))[0]
+        eb = mdl(_tx(cv2.cvtColor(crop_b, cv2.COLOR_RGB2BGR))[None].to(dev))[0]
+    pct = float(F.cosine_similarity(ea[None], eb[None]).item() * 100)
+    pct = max(0, min(100, pct))
+    colour = "#15803D" if pct >= 80 else "#CA8A04" if pct >= 50 else "#DC2626"
+    return crop_a, crop_b, badge(f"{pct:.2f}% match", colour)
+

 # ───────────────────────────────
 # Static HTML
 # ───────────────────────────────
+TITLE_HTML = title_with_logo(
+    """<span class="brand">EdgeFace:</span> Efficient Face Recognition Model for Edge Devices"""
+)

 # <div id="hero-badge">
 #     🏆 Winner of IJCB 2023 Efficient Face Recognition Competition
     gr.HTML(HERO_HTML)

     with gr.Row():
+        gal_a = gr.Gallery(
+            PRELOADED,
+            columns=[5],
+            height=120,
+            label="Image A",
+            object_fit="contain",
+            elem_classes="card",
+        )
+        gal_b = gr.Gallery(
+            PRELOADED,
+            columns=[5],
+            height=120,
+            label="Image B",
+            object_fit="contain",
+            elem_classes="card",
+        )

     with gr.Row():
         # img_a = gr.Image(type="numpy", height=300, label="Image A",

         # img_b = gr.Image(type="numpy", height=300, label="Image B",
         #                  elem_classes="preview")

+        img_a = gr.Image(
+            type="numpy",
+            height=300,
+            label="Image A (click or drag-drop)",
+            interactive=True,
+            elem_classes="preview card",
+        )
+        img_b = gr.Image(
+            type="numpy",
+            height=300,
+            label="Image B (click or drag-drop)",
+            interactive=True,
+            elem_classes="preview card",
+        )

     def _fill(evt: gr.SelectData):
         return _as_rgb(PRELOADED[evt.index]) if evt.index is not None else None
+
     gal_a.select(_fill, outputs=img_a)
     gal_b.select(_fill, outputs=img_b)

+    variant_dd = gr.Dropdown(
+        EDGE_MODELS, value="edgeface_base", label="Model variant", elem_classes="card"
+    )
     btn = gr.Button("Compare", variant="primary")

     with gr.Row():

         out_b = gr.Image(label="Aligned B (112×112)", elem_classes="card")
    score_html = gr.HTML(elem_id="score-area")

+    btn.click(compare, [img_a, img_b, variant_dd], [out_a, out_b, score_html])

    gr.HTML(CITATION_HTML)
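Taken together, the comparison path in compare() is: align and crop each input face, embed both 112×112 crops with the selected EdgeFace variant, and report the cosine similarity of the two embeddings as a percentage clamped to [0, 100]. A minimal sketch of just that scoring step, assuming two aligned RGB crops are already in hand (align_crop, the face detector and the Gradio wiring are omitted; mdl would come from the cached loader shown in the diff):

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms

# Same preprocessing as app.py: [0, 1] tensor, then normalise to roughly [-1, 1].
_tx = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
)

def similarity_pct(crop_a: np.ndarray, crop_b: np.ndarray, mdl: torch.nn.Module) -> float:
    """Cosine similarity of two aligned RGB face crops, scaled to a 0-100 score."""
    dev = next(mdl.parameters()).device
    with torch.no_grad():
        # The model is fed BGR, mirroring the RGB2BGR conversion in compare().
        ea = mdl(_tx(cv2.cvtColor(crop_a, cv2.COLOR_RGB2BGR))[None].to(dev))[0]
        eb = mdl(_tx(cv2.cvtColor(crop_b, cv2.COLOR_RGB2BGR))[None].to(dev))[0]
    pct = float(F.cosine_similarity(ea[None], eb[None]).item() * 100)
    return max(0.0, min(100.0, pct))  # a negative cosine is reported as a 0 % match

With a score in hand, the badge colour follows the same thresholds as compare(): green at 80 or above, amber at 50 or above, red otherwise.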