RuoyuChen committed on
Commit 00d944c · 1 Parent(s): 9821382

change small model

Files changed (2)
  1. app.py +28 -4
  2. clip_vitb_imagenet_zeroweights.pt +3 -0
app.py CHANGED
@@ -17,6 +17,8 @@ from sklearn import metrics
 import torch
 from torchvision import transforms
 
+from tqdm import tqdm
+
 from models.submodular_vit_efficient_plus import MultiModalSubModularExplanationEfficientPlus
 
 data_transform = transforms.Compose(
@@ -42,7 +44,7 @@ class CLIPModel_Super(torch.nn.Module):
         self.device = device
         self.model, _ = clip.load(type, device=self.device, download_root=download_root)
 
-        self.model = self.model.float()
+        self.model = self.model.type(torch.float32)
 
     def forward(self, vision_inputs):
         """
@@ -70,18 +72,40 @@ def transform_vision_data(image):
     image = data_transform(image)
     return image
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+def zeroshot_classifier(model, classnames, templates, device):
+    with torch.no_grad():
+        zeroshot_weights = []
+        for classname in tqdm(classnames):
+            texts = [template.format(classname) for template in templates]  # format with class
+            texts = clip.tokenize(texts).to(device)  # tokenize
+
+            with torch.no_grad():
+                class_embeddings = model.model.encode_text(texts)
+
+            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+            class_embedding = class_embeddings.mean(dim=0)
+            class_embedding /= class_embedding.norm()
+            zeroshot_weights.append(class_embedding)
+        zeroshot_weights = torch.stack(zeroshot_weights).cuda()
+    return zeroshot_weights * 100
+
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cuda"
 # Instantiate model
-vis_model = CLIPModel_Super("ViT-L/14", device=device, download_root="./ckpt")
+vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
 vis_model.eval()
 vis_model.to(device)
 print("load clip model")
 
-semantic_path = "./clip_vitl_imagenet_zeroweights.pt"
+semantic_path = "./clip_vitb_imagenet_zeroweights.pt"
 if os.path.exists(semantic_path):
     semantic_feature = torch.load(semantic_path, map_location="cpu")
     semantic_feature = semantic_feature.to(device)
     semantic_feature = semantic_feature.type(torch.float32)
+else:
+    semantic_feature = zeroshot_classifier(vis_model, imagenet_classes, imagenet_templates, device)
+    torch.save(semantic_feature, semantic_path)
+
 
 explainer = MultiModalSubModularExplanationEfficientPlus(
     vis_model, semantic_feature, transform_vision_data, device=device,
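
For context, a minimal sketch of how zero-shot text weights like semantic_feature are typically applied to a CLIP image embedding to get per-class scores. This is not part of the commit; the image path, the extra normalization step, and the variable names introduced here are assumptions.

import torch
from PIL import Image

# Assumed to be available from app.py above: vis_model, semantic_feature,
# transform_vision_data, device. "example.jpg" is a hypothetical input image.
image = Image.open("example.jpg").convert("RGB")
pixels = transform_vision_data(image).unsqueeze(0).to(device)

with torch.no_grad():
    image_feature = vis_model(pixels)          # CLIP image embedding, shape (1, dim)
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

# semantic_feature stacks one (scaled) text embedding per ImageNet class,
# so a matrix product gives per-class logits.
logits = image_feature @ semantic_feature.t()  # shape (1, num_classes)
print("predicted class index:", logits.argmax(dim=-1).item())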
clip_vitb_imagenet_zeroweights.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c552bb4a3eebecf3162e53861a8368417a2d0b5c3af5454041369c89160ac34e
+size 2048880
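
The new .pt file is committed as a Git LFS pointer, so only the three pointer lines above live in the repository; the actual tensor must be fetched before it can be loaded. A small sketch of checking it locally follows; the expected shape (roughly 1000 ImageNet classes by ViT-B/16's 512-dim embedding) is an inference from the 2,048,880-byte size, not something stated in the commit.

# Fetch the real contents from LFS first (otherwise torch.load only sees the pointer text):
#   git lfs pull --include="clip_vitb_imagenet_zeroweights.pt"
import torch

weights = torch.load("clip_vitb_imagenet_zeroweights.pt", map_location="cpu")
# Expected (assumed): float32, shape (1000, 512) -- 1000 classes x ViT-B/16 embedding dim.
print(weights.shape, weights.dtype)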