Update
- app_lib/config.py +159 -0
- app_lib/defaults.py +15 -10
- app_lib/multimodal.py +186 -0
- app_lib/test.py +28 -56
- app_lib/user_input.py +12 -12
- app_lib/utils.py +2 -10
- assets/results/bowl_ace.npy +1 -1
- assets/results/gardener_ace.npy +1 -1
- assets/results/gentleman_ace.npy +1 -1
- assets/results/mathematician_ace.npy +1 -1
- precompute_results.py +7 -7
app_lib/config.py
ADDED
@@ -0,0 +1,159 @@
+# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/utils/config.py
+
+import os
+from dataclasses import dataclass
+from enum import Enum
+from itertools import product
+from typing import Any, Iterable, Mapping, Optional, Union
+
+import torch
+from ml_collections import ConfigDict
+from numpy import ndarray
+
+Array = Union[ndarray, torch.Tensor]
+
+
+class TestType(Enum):
+    GLOBAL = "global"
+    GLOBAL_COND = "global_cond"
+    LOCAL_COND = "local_cond"
+
+
+class ConceptType(Enum):
+    DATASET = "dataset"
+    CLASS = "class"
+    IMAGE = "image"
+
+
+@dataclass
+class Constants:
+    WORKDIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class DataConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.dataset: str = config_dict.get("dataset", None)
+        self.backbone: str = config_dict.get("backbone", None)
+        self.bottleneck: str = config_dict.get("bottleneck", None)
+        self.classifier: str = config_dict.get("classifier", None)
+        self.sampler: str = config_dict.get("sampler", None)
+        self.num_concepts: int = config_dict.get("num_concepts", None)
+
+
+class SpliceConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.vocab: str = config_dict.get("vocab", None)
+        self.vocab_size: int = config_dict.get("vocab_size", None)
+        self.l1_penalty: float = config_dict.get("l1_penalty", None)
+
+
+class PCBMConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.alpha: float = config_dict.get("alpha", None)
+        self.l1_ratio: float = config_dict.get("l1_ratio", None)
+
+
+class cKDEConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.metric: str = config_dict.get("metric", None)
+        self.scale_method: str = config_dict.get("scale_method", None)
+        self.scale: float = config_dict.get("scale", None)
+
+
+class TestingConfig(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.significance_level: float = config_dict.get("significance_level", None)
+        self.wealth: str = config_dict.get("wealth", None)
+        self.bet: str = config_dict.get("bet", None)
+        self.kernel: str = config_dict.get("kernel", None)
+        self.kernel_scale_method: str = config_dict.get("kernel_scale_method", None)
+        self.kernel_scale: float = config_dict.get("kernel_scale", None)
+        self.tau_max: int = config_dict.get("tau_max", None)
+        self.images_per_class: int = config_dict.get("images_per_class", None)
+        self.r: int = config_dict.get("r", None)
+        self.cardinality: Iterable[int] = config_dict.get("cardinality", None)
+
+
+class Config(ConfigDict):
+    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
+        super().__init__()
+        if config_dict is None:
+            config_dict = {}
+
+        self.name: str = config_dict.get("name", None)
+        self.data = DataConfig(config_dict.get("data", None))
+        self.splice = SpliceConfig(config_dict.get("splice", None))
+        self.pcbm = PCBMConfig(config_dict.get("pcbm", None))
+        self.ckde = cKDEConfig(config_dict.get("ckde", None))
+        self.testing = TestingConfig(config_dict.get("testing", None))
+
+    def backbone_name(self):
+        backbone = self.data.backbone.strip().lower()
+        return backbone.replace("/", "_").replace(":", "_")
+
+    def sweep(self, keys: Iterable[str]):
+        def _get(dict, key):
+            keys = key.split(".")
+            if len(keys) == 1:
+                return dict[keys[0]]
+            else:
+                return _get(dict[keys[0]], ".".join(keys[1:]))
+
+        def _set(dict, key, value):
+            keys = key.split(".")
+            if len(keys) == 1:
+                dict[keys[0]] = value
+            else:
+                _set(dict[keys[0]], ".".join(keys[1:]), value)
+
+        to_iterable = lambda v: v if isinstance(v, list) else [v]
+
+        config_dict = self.to_dict()
+        sweep_values = [_get(config_dict, key) for key in keys]
+        sweep = list(product(*map(to_iterable, sweep_values)))
+
+        configs: Iterable[Config] = []
+        for _sweep in sweep:
+            _config_dict = config_dict.copy()
+            for key, value in zip(keys, _sweep):
+                _set(_config_dict, key, value)
+
+            configs.append(Config(_config_dict))
+        return configs
+
+
+def register_config(name: str):
+    def register(cls: Config):
+        if name in configs:
+            raise ValueError(f"Config {name} is already registered")
+        configs[name] = cls
+
+    return register
+
+
+def get_config(name: str) -> Config:
+    return configs[name]()
+
+
+configs: Mapping[str, Config] = {}
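Note (not part of the commit): a minimal sketch of how the Config.sweep helper above can be used. The nested values and swept keys here are illustrative, not taken from the repository.

# Hypothetical usage of Config.sweep; field values are made up for illustration.
from app_lib.config import Config

config = Config(
    {
        "data": {"dataset": "imagenette", "backbone": "open_clip:ViT-B-32"},
        "testing": {"kernel_scale": [0.5, 1.0], "tau_max": [100, 200]},
    }
)

# sweep() returns one Config per combination of the dot-separated keys (2 x 2 = 4 here).
for cfg in config.sweep(["testing.kernel_scale", "testing.tau_max"]):
    print(cfg.testing.kernel_scale, cfg.testing.tau_max, cfg.backbone_name())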
app_lib/defaults.py
CHANGED
@@ -1,14 +1,19 @@
-MODEL_NAME = "open_clip:ViT-B-32"
-SIGNIFICANCE_LEVEL_VALUE = 0.05
-SIGNIFICANCE_LEVEL_STEP = 0.01
+from dataclasses import dataclass


+@dataclass
+class Defaults:
+    DATASET_NAME = "imagenette"
+    MODEL_NAME = "open_clip:ViT-B-32"

+    SIGNIFICANCE_LEVEL_VALUE = 0.05
+    SIGNIFICANCE_LEVEL_STEP = 0.01

+    TAU_MAX_VALUE = 200
+    TAU_MAX_STEP = 50
+
+    R_VALUE = 20
+    R_STEP = 5
+
+    CARDINALITY_VALUE = 1
+    CARDINALITY_STEP = 1
app_lib/multimodal.py
ADDED
@@ -0,0 +1,186 @@
+# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/multimodal.py
+
+from abc import abstractmethod
+from typing import Mapping, Optional
+
+import clip
+import open_clip
+from transformers import (
+    AlignModel,
+    AlignProcessor,
+    BlipForImageTextRetrieval,
+    BlipProcessor,
+    FlavaModel,
+    FlavaProcessor,
+)
+
+from app_lib.config import Config
+from app_lib.config import Constants as c
+
+
+class VisionLanguageModel:
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        pass
+
+    @abstractmethod
+    def encode_text(self, text):
+        pass
+
+    @abstractmethod
+    def encode_image(self, image):
+        pass
+
+
+models: Mapping[str, VisionLanguageModel] = {}
+
+
+def register_model(name):
+    def register(cls: VisionLanguageModel):
+        if name in models:
+            raise ValueError(f"Model {name} is already registered")
+        models[name] = cls
+
+    return register
+
+
+def get_model_name_and_backbone(config: Config):
+    backbone = config.data.backbone.split(":")
+    if len(backbone) == 1:
+        backbone.append(None)
+    return backbone
+
+
+def get_model(config: Config, device=c.DEVICE) -> VisionLanguageModel:
+    model_name, backbone = get_model_name_and_backbone(config)
+    return models[model_name](backbone, device=device)
+
+
+def get_text_encoder(config: Config, device=c.DEVICE):
+    model = get_model(config, device=device)
+    return model.encode_text
+
+
+def get_image_encoder(config: Config, device=c.DEVICE):
+    model = get_model(config, device=device)
+    return model.encode_image
+
+
+@register_model(name="clip")
+class CLIPModel(VisionLanguageModel):
+    def __init__(self, backbone: str, device=c.DEVICE):
+        self.model, self.preprocess = clip.load(backbone, device=device)
+        self.tokenize = clip.tokenize
+
+        self.device = device
+
+    def encode_text(self, text):
+        text = self.tokenize(text).to(self.device)
+        return self.model.encode_text(text)
+
+    def encode_image(self, image):
+        image = self.preprocess(image).unsqueeze(0).to(self.device)
+        return self.model.encode_image(image)
+
+
+@register_model(name="open_clip")
+class OpenClipModel(VisionLanguageModel):
+    OPENCLIP_WEIGHTS = {
+        "ViT-B-32": "laion2b_s34b_b79k",
+        "ViT-L-14": "laion2b_s32b_b82k",
+    }
+
+    def __init__(self, backbone: str, device=c.DEVICE):
+        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+            backbone, pretrained=self.OPENCLIP_WEIGHTS[backbone], device=device
+        )
+        self.tokenize = open_clip.get_tokenizer(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text = self.tokenize(text).to(self.device)
+        return self.model.encode_text(text)
+
+    def encode_image(self, image):
+        image = self.preprocess(image).unsqueeze(0).to(self.device)
+        return self.model.encode_image(image)
+
+
+@register_model(name="flava")
+class FLAVAModel(VisionLanguageModel):
+    HF_MODEL = "facebook/flava-full"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = FlavaModel.from_pretrained(backbone).to(device)
+        self.processor = FlavaProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        return self.model.get_text_features(**text_inputs)[:, 0, :]
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        return self.model.get_image_features(**image_inputs)[:, 0, :]
+
+
+@register_model(name="align")
+class ALIGNModel(VisionLanguageModel):
+    HF_MODEL = "kakaobrain/align-base"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = AlignModel.from_pretrained(backbone).to(device)
+        self.processor = AlignProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        return self.model.get_text_features(**text_inputs)
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        return self.model.get_image_features(**image_inputs)
+
+
+@register_model(name="blip")
+class BLIPModel(VisionLanguageModel):
+    HF_MODEL = "Salesforce/blip-itm-base-coco"
+
+    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
+        if backbone is None:
+            backbone = self.HF_MODEL
+
+        self.model = BlipForImageTextRetrieval.from_pretrained(backbone).to(device)
+        self.processor = BlipProcessor.from_pretrained(backbone)
+
+        self.device = device
+
+    def encode_text(self, text):
+        text_inputs = self.processor(
+            text=text, return_tensors="pt", padding="max_length", max_length=77
+        )
+        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
+        question_embeds = self.model.text_encoder(**text_inputs)[0]
+        return self.model.text_proj(question_embeds[:, 0, :])
+
+    def encode_image(self, image):
+        image_inputs = self.processor(images=image, return_tensors="pt")
+        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
+        image_embeds = self.model.vision_model(**image_inputs)[0]
+        return self.model.vision_proj(image_embeds[:, 0, :])
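Note (not part of the commit): a small sketch of how the registry above is meant to be consumed, mirroring what app_lib/test.py does below. The image path is a placeholder, not a file in the repository.

# Hypothetical usage of the model registry; the image path is a placeholder.
from PIL import Image

import app_lib.multimodal as multimodal
from app_lib.config import Config

config = Config()
config.data.backbone = "open_clip:ViT-B-32"  # "<model_name>:<backbone>"

model = multimodal.get_model(config)  # looks up "open_clip" in the registry
text_features = model.encode_text(["a photo of a dog"])
image_features = model.encode_image(Image.open("assets/images/example.jpg"))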
app_lib/test.py
CHANGED
@@ -1,65 +1,33 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

-import clip
-import h5py
import ml_collections
import numpy as np
-import
+import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download

+import app_lib.multimodal as multimodal
from app_lib.ckde import cKDE
-from app_lib.
+from app_lib.config import Config
+from app_lib.config import Constants as c
from ibydmt.test import xSKIT

rng = np.random.default_rng()


-def _get_open_clip_model(model_name, device):
-    backbone = model_name.split(":")[-1]
-
-    model, _, preprocess = open_clip.create_model_and_transforms(
-        SUPPORTED_MODELS[model_name], device=device
-    )
-    model.eval()
-    tokenizer = open_clip.get_tokenizer(backbone)
-    return model, preprocess, tokenizer
-
-
-def _get_clip_model(model_name, device):
-    backbone = model_name.split(":")[-1]
-    model, preprocess = clip.load(backbone, device=device)
-    tokenizer = clip.tokenize
-    return model, preprocess, tokenizer
-
-
-def _load_model(model_name, device):
-    if "open_clip" in model_name:
-        model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
-    elif "clip" in model_name:
-        model, preprocess, tokenizer = _get_clip_model(model_name, device)
-    return model, preprocess, tokenizer
-
-
@torch.no_grad()
@torch.cuda.amp.autocast()
-def _encode_concepts(
-    concept_features = model.encode_text(concepts_text)
+def _encode_concepts(model, concepts):
+    concept_features = model.encode_text(concepts)
    concept_features /= torch.linalg.norm(concept_features, dim=-1, keepdim=True)
    return concept_features.cpu().numpy()


@torch.no_grad()
@torch.cuda.amp.autocast()
-def _encode_image(model,
-    image = preprocess(image)
-    image = image.unsqueeze(0)
-    image = image.to(device)
-
+def _encode_image(model, image):
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()
@@ -67,24 +35,24 @@ def _encode_image(model, preprocess, image, device):

@torch.no_grad()
@torch.cuda.amp.autocast()
-def _encode_class_name(
-    class_text =
+def _encode_class_name(model, class_name):
+    class_text = [f"A photo of a {class_name}"]
    class_features = model.encode_text(class_text)
    class_features /= torch.linalg.norm(class_features, dim=-1, keepdim=True)
    return class_features.cpu().numpy()


-def
+def _load_embedding(config):
    dataset_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
-        filename=
+        filename=(
+            f"{config.data.dataset.lower()}_train_{config.backbone_name()}.parquet"
+        ),
        repo_type="dataset",
    )

-    return embedding
+    dataset = pd.read_parquet(dataset_path)
+    return np.array(dataset["embedding"].values.tolist())


def _sample_random_subset(concept_idx, concepts, cardinality):
@@ -162,26 +130,30 @@ def test(
    cardinality,
    dataset_name,
    model_name,
-    device=
+    device=c.DEVICE,
    with_streamlit=True,
):
+    config = Config()
+    config.data.dataset = dataset_name
+    config.data.backbone = model_name
+
    if with_streamlit:
        with st.spinner("Loading model"):
-            model
+            model = multimodal.get_model(config, device=device)
    else:
-        model
+        model = multimodal.get_model(config, device=device)

    if with_streamlit:
        with st.spinner("Encoding concepts"):
-            cbm = _encode_concepts(
+            cbm = _encode_concepts(model, concepts)
    else:
-        cbm = _encode_concepts(
+        cbm = _encode_concepts(model, concepts)

    if with_streamlit:
        with st.spinner("Encoding image"):
-            h = _encode_image(model,
+            h = _encode_image(model, image)
    else:
-        h = _encode_image(model,
+        h = _encode_image(model, image)
    z = h @ cbm.T
    z = z.squeeze()

@@ -201,11 +173,11 @@
        ),
    )

-    embedding =
+    embedding = _load_embedding(config)
    semantics = embedding @ cbm.T
    sampler = cKDE(embedding, semantics)

-    classifier = _encode_class_name(
+    classifier = _encode_class_name(model, class_name)

    with ThreadPoolExecutor() as executor:
        futures = [
app_lib/user_input.py
CHANGED
@@ -5,7 +5,7 @@ import streamlit as st
from PIL import Image
from streamlit_image_select import image_select

+from app_lib.defaults import Defaults as d
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS

IMAGE_DIR = os.path.join("assets", "images")
@@ -31,8 +31,8 @@ def _validate_concepts(concepts):


def _get_significance_level():
-    default =
-    step =
+    default = d.SIGNIFICANCE_LEVEL_VALUE
+    step = d.SIGNIFICANCE_LEVEL_STEP
    return st.slider(
        "Significance level",
        help=f"The level of significance of the tests. Defaults to {default:.2F}.",
@@ -45,8 +45,8 @@ def _get_significance_level():


def _get_tau_max():
-    default =
-    step =
+    default = d.TAU_MAX_VALUE
+    step = d.TAU_MAX_STEP
    return int(
        st.slider(
            "Length of test",
@@ -61,8 +61,8 @@ def _get_tau_max():


def _get_number_of_tests():
-    default =
-    step =
+    default = d.R_VALUE
+    step = d.R_STEP
    return int(
        st.slider(
            "Number of tests per concept",
@@ -80,8 +80,8 @@ def _get_number_of_tests():


def _get_cardinality(concepts, concepts_ready):
-    default =
-    step =
+    default = d.CARDINALITY_VALUE
+    step = d.CARDINALITY_STEP
    return st.slider(
        "Size of conditioning set",
        help=(
@@ -98,7 +98,7 @@ def _get_cardinality(concepts, concepts_ready):

def _get_dataset_name():
    options = SUPPORTED_DATASETS
-    default_idx = options.index(
+    default_idx = options.index(d.DATASET_NAME)
    return st.selectbox(
        "Dataset",
        options=options,
@@ -112,8 +112,8 @@ def _get_dataset_name():


def get_model_name():
-    options = list(SUPPORTED_MODELS
-    default_idx = options.index(
+    options = list(SUPPORTED_MODELS)
+    default_idx = options.index(d.MODEL_NAME)
    return st.selectbox(
        "Model to test",
        options=options,
app_lib/utils.py
CHANGED
@@ -11,16 +11,8 @@ supported_datasets_path = hf_hub_download(
    repo_type="dataset",
)

-SUPPORTED_MODELS = {}
with open(supported_models_path, "r") as f:
-        line = line.strip()
-        model_name, model_url = line.split(",")
-        SUPPORTED_MODELS[model_name] = model_url
+    SUPPORTED_MODELS = f.read().splitlines()

-
-SUPPORTED_DATASETS = []
with open(supported_datasets_path, "r") as f:
-        dataset_name = line.strip()
-        SUPPORTED_DATASETS.append(dataset_name)
+    SUPPORTED_DATASETS = f.read().splitlines()
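Note (not part of the commit): an illustration of the file format the new parsing assumes. With the switch to splitlines(), the downloaded supported-models file is read as one model name per line, so SUPPORTED_MODELS becomes a plain list of names rather than the previous {model_name: model_url} dict parsed from "name,url" rows. The entries below are illustrative only.

# Illustrative sketch of the new list-style parsing (entries are made up).
from io import StringIO

example_file = StringIO("open_clip:ViT-B-32\nopen_clip:ViT-L-14\n")
SUPPORTED_MODELS = example_file.read().splitlines()
print(SUPPORTED_MODELS)  # ['open_clip:ViT-B-32', 'open_clip:ViT-L-14']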
assets/results/bowl_ace.npy
CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:adaeeda897451c5548b3119fb917214b82590becfe3138158cfcf1055bcb714d
size 226871
assets/results/gardener_ace.npy
CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:6a7e6d291b8e7226da6af5990094501a741a99c363f04647688da6f8e71746c6
size 226873
assets/results/gentleman_ace.npy
CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:fb55900dbb2870ffccc56aac8de6fde7960c3d817035bd9d461419ef0bee6b3b
size 226874
assets/results/mathematician_ace.npy
CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ab4807601def20a04c66a2ec993a90e929172aee047e66fb8d052d8eaee438b0
size 226873
precompute_results.py
CHANGED
@@ -4,7 +4,7 @@ import os
import numpy as np
from PIL import Image

+from app_lib.defaults import Defaults as d
from app_lib.test import get_testing_config, test

assets_dir = "assets"
@@ -13,9 +13,9 @@ results_dir = os.path.join(assets_dir, "results")
os.makedirs(results_dir, exist_ok=True)

testing_config = get_testing_config(
-    significance_level=
-    tau_max=
-    r=
+    significance_level=d.SIGNIFICANCE_LEVEL_VALUE,
+    tau_max=d.TAU_MAX_VALUE,
+    r=d.R_VALUE,
)

image_presets = json.load(open(os.path.join(assets_dir, "image_presets.json")))
@@ -26,7 +26,7 @@ for _image_name, _image_presets in image_presets.items():
    _image = Image.open(_image_path)
    _class_name = _image_presets["class_name"]
    _concepts = _image_presets["concepts"]
-    _cardinality =
+    _cardinality = d.CARDINALITY_VALUE

    _results = test(
        testing_config,
@@ -34,8 +34,8 @@ for _image_name, _image_presets in image_presets.items():
        _class_name,
        _concepts,
        _cardinality,
+        d.DATASET_NAME,
+        d.MODEL_NAME,
        with_streamlit=False,
    )