change small model
- app.py +28 -4
- clip_vitb_imagenet_zeroweights.pt +3 -0
app.py
CHANGED
@@ -17,6 +17,8 @@ from sklearn import metrics
 import torch
 from torchvision import transforms
 
+from tqdm import tqdm
+
 from models.submodular_vit_efficient_plus import MultiModalSubModularExplanationEfficientPlus
 
 data_transform = transforms.Compose(
@@ -42,7 +44,7 @@ class CLIPModel_Super(torch.nn.Module):
         self.device = device
         self.model, _ = clip.load(type, device=self.device, download_root=download_root)
 
-        self.model = self.model.
+        self.model = self.model.type(torch.float32)
 
     def forward(self, vision_inputs):
         """
@@ -70,18 +72,40 @@ def transform_vision_data(image):
     image = data_transform(image)
     return image
 
-
+def zeroshot_classifier(model, classnames, templates, device):
+    with torch.no_grad():
+        zeroshot_weights = []
+        for classname in tqdm(classnames):
+            texts = [template.format(classname) for template in templates] #format with class
+            texts = clip.tokenize(texts).to(device) #tokenize
+
+            with torch.no_grad():
+                class_embeddings = model.model.encode_text(texts)
+
+            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+            class_embedding = class_embeddings.mean(dim=0)
+            class_embedding /= class_embedding.norm()
+            zeroshot_weights.append(class_embedding)
+        zeroshot_weights = torch.stack(zeroshot_weights).cuda()
+    return zeroshot_weights*100
+
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cuda"
 # Instantiate model
-vis_model = CLIPModel_Super("ViT-
+vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
 vis_model.eval()
 vis_model.to(device)
 print("load clip model")
 
-semantic_path = "./
+semantic_path = "./clip_vitb_imagenet_zeroweights.pt"
 if os.path.exists(semantic_path):
     semantic_feature = torch.load(semantic_path, map_location="cpu")
     semantic_feature = semantic_feature.to(device)
     semantic_feature = semantic_feature.type(torch.float32)
+else:
+    semantic_feature = zeroshot_classifier(vis_model, imagenet_classes, imagenet_templates, device)
+    torch.save(semantic_feature, semantic_path)
+
 
 explainer = MultiModalSubModularExplanationEfficientPlus(
     vis_model, semantic_feature, transform_vision_data, device=device,
clip_vitb_imagenet_zeroweights.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c552bb4a3eebecf3162e53861a8368417a2d0b5c3af5454041369c89160ac34e
+size 2048880
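The Git LFS pointer records the object's SHA-256 and byte size, so a local copy of the weights can be sanity-checked after cloning or downloading. A minimal check, assuming the file sits next to app.py as in the repository:

import hashlib
import os

# Verify a local copy of the LFS object against the pointer above.
# The path mirrors semantic_path in app.py; adjust if the file lives elsewhere.
path = "./clip_vitb_imagenet_zeroweights.pt"
expected_oid = "c552bb4a3eebecf3162e53861a8368417a2d0b5c3af5454041369c89160ac34e"
expected_size = 2048880

assert os.path.getsize(path) == expected_size, "unexpected file size"

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

assert sha256.hexdigest() == expected_oid, "SHA-256 does not match the LFS pointer"
print("clip_vitb_imagenet_zeroweights.pt matches the LFS pointer")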