bornet committed on
Commit 528e972 · verified · 1 Parent(s): 9a6d3df

Refactor to use hf_hub_download instead of torch.hub.load

Files changed (2)
  1. app.py +15 -5
  2. timmfrv2.py +84 -0
app.py CHANGED
@@ -12,9 +12,12 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torchvision import transforms
+from huggingface_hub import hf_hub_download
 
 from utils import align_crop
 from title import title_css, title_with_logo
+from timmfrv2 import TimmFRWrapperV2, model_configs
+
 
 # ───────────────────────────────
 # Data & models
@@ -233,11 +236,18 @@ _tx = transforms.Compose(
 
 def get_edge_model(name: str) -> torch.nn.Module:
     if name not in get_edge_model.cache:
-        mdl = torch.hub.load(
-            "otroshi/edgeface", name, source="github", pretrained=True
-        ).eval()
-        mdl.to("cuda" if torch.cuda.is_available() else "cpu")
-        get_edge_model.cache[name] = mdl
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model_path = hf_hub_download(
+            repo_id=model_configs[name]["repo"],
+            filename=model_configs[name]["filename"],
+            local_dir="models",
+        )
+        model = TimmFRWrapperV2(model_configs[name]["timm_model"], batchnorm=False)
+        model = model_configs[name]["post_setup"](model)
+        model.load_state_dict(torch.load(model_path, map_location="cpu"))
+        model = model.eval()
+        model.to(device)
+        get_edge_model.cache[name] = model
     return get_edge_model.cache[name]
 
 
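For reference, the new code path inside get_edge_model can be exercised on its own. The sketch below is only illustrative: it assumes huggingface_hub, timm and torch are installed, that timmfrv2.py from this commit is importable, and that the model takes 112×112 aligned face crops (the crop size is not specified by this diff).

import torch
from huggingface_hub import hf_hub_download

from timmfrv2 import TimmFRWrapperV2, model_configs

name = "edgeface_xs_gamma_06"  # any key of model_configs works the same way
cfg = model_configs[name]

# Download the checkpoint from the Hub (cached locally after the first call).
ckpt = hf_hub_download(repo_id=cfg["repo"], filename=cfg["filename"])

# Rebuild the backbone, apply the per-variant post-setup (e.g. the low-rank
# linear replacement for the GAMMA variants), then load the weights.
model = TimmFRWrapperV2(cfg["timm_model"])
model = cfg["post_setup"](model)
model.load_state_dict(torch.load(ckpt, map_location="cpu"))
model.eval()

# One aligned face crop in, one 512-d embedding out (featdim defaults to 512).
with torch.no_grad():
    emb = model(torch.randn(1, 3, 112, 112))
print(emb.shape)  # torch.Size([1, 512])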
 
timmfrv2.py ADDED
@@ -0,0 +1,84 @@
+import torch.nn as nn
+import timm
+
+
+class TimmFRWrapperV2(nn.Module):
+    """
+    Wraps timm model
+    """
+
+    def __init__(self, model_name="edgenext_x_small", featdim=512, batchnorm=False):
+        super().__init__()
+        self.featdim = featdim
+        self.model_name = model_name
+
+        self.model = timm.create_model(self.model_name)
+        self.model.reset_classifier(self.featdim)
+
+    def forward(self, x):
+        x = self.model(x)
+        return x
+
+
+class LoRaLin(nn.Module):
+    def __init__(self, in_features, out_features, rank, bias=True):
+        super(LoRaLin, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.rank = rank
+        self.linear1 = nn.Linear(in_features, rank, bias=False)
+        self.linear2 = nn.Linear(rank, out_features, bias=bias)
+
+    def forward(self, input):
+        x = self.linear1(input)
+        x = self.linear2(x)
+        return x
+
+
+def replace_linear_with_lowrank_recursive_2(model, rank_ratio=0.2):
+    for name, module in model.named_children():
+        if isinstance(module, nn.Linear) and "head" not in name:
+            in_features = module.in_features
+            out_features = module.out_features
+            rank = max(2, int(min(in_features, out_features) * rank_ratio))
+            bias = False
+            if module.bias is not None:
+                bias = True
+            lowrank_module = LoRaLin(in_features, out_features, rank, bias)
+
+            setattr(model, name, lowrank_module)
+        else:
+            replace_linear_with_lowrank_recursive_2(module, rank_ratio)
+
+
+def replace_linear_with_lowrank_2(model, rank_ratio=0.2):
+    replace_linear_with_lowrank_recursive_2(model, rank_ratio)
+    return model
+
+
+model_configs = {
+    "edgeface_base": {
+        "repo": "idiap/EdgeFace-Base",
+        "filename": "edgeface_base.pt",
+        "timm_model": "edgenext_base",
+        "post_setup": lambda x: x,
+    },
+    "edgeface_s_gamma_05": {
+        "repo": "idiap/EdgeFace-S-GAMMA",
+        "filename": "edgeface_s_gamma_05.pt",
+        "timm_model": "edgenext_small",
+        "post_setup": lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.5),
+    },
+    "edgeface_xs_gamma_06": {
+        "repo": "idiap/EdgeFace-XS-GAMMA",
+        "filename": "edgeface_xs_gamma_06.pt",
+        "timm_model": "edgenext_x_small",
+        "post_setup": lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.6),
+    },
+    "edgeface_xxs": {
+        "repo": "idiap/EdgeFace-XXS",
+        "filename": "edgeface_xxs.pt",
+        "timm_model": "edgenext_xx_small",
+        "post_setup": lambda x: x,
+    },
+}
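The low-rank replacement used by the GAMMA configs swaps every nn.Linear outside a "head" module for a LoRaLin pair Linear(in, rank) → Linear(rank, out). A minimal sanity-check sketch, assuming only torch and the timmfrv2 module above; the toy nn.Sequential is purely illustrative and not part of the app:

import torch
import torch.nn as nn

from timmfrv2 import LoRaLin, replace_linear_with_lowrank_2

# Two square linear layers: 2 * (512*512 + 512) = 525,312 parameters.
toy = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
before = sum(p.numel() for p in toy.parameters())

# With rank_ratio=0.2 each layer gets rank = int(512 * 0.2) = 102,
# i.e. 512*102 + 102*512 + 512 = 104,960 parameters per layer.
toy = replace_linear_with_lowrank_2(toy, rank_ratio=0.2)
after = sum(p.numel() for p in toy.parameters())

print(isinstance(toy[0], LoRaLin))      # True: the Linear was replaced in place
print(before, "->", after)              # 525312 -> 209920
print(toy(torch.randn(1, 512)).shape)   # torch.Size([1, 512]); interface unchanged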