Samuel Stevens
commited on
Commit
·
7b4abf1
1
Parent(s):
484209d
Add zero-shot example.
Browse files- README.md +2 -0
- examples/README.md +18 -0
- examples/zero_shot.py +298 -0
README.md
CHANGED
|
@@ -36,6 +36,8 @@ It is trained on [TreeOfLife-10M](https://huggingface.co/datasets/imageomics/Tre
|
|
| 36 |
Through rigorous benchmarking on a diverse set of fine-grained biological classification tasks, BioCLIP consistently outperformed existing baselines by 17% to 20% absolute.
|
| 37 |
Through intrinsic evaluation, we found that BioCLIP learned a hierarchical representation aligned to the tree of life, which demonstrates its potential for robust generalizability.
|
| 38 |
|
|
|
|
|
|
|
| 39 |
## Model Details
|
| 40 |
|
| 41 |
### Model Description
|
|
|
|
| 36 |
Through rigorous benchmarking on a diverse set of fine-grained biological classification tasks, BioCLIP consistently outperformed existing baselines by 17% to 20% absolute.
|
| 37 |
Through intrinsic evaluation, we found that BioCLIP learned a hierarchical representation aligned to the tree of life, which demonstrates its potential for robust generalizability.
|
| 38 |
|
| 39 |
+
**See the `examples/` directory for examples of how to use BioCLIP in zero-shot and few-shot settings.**
|
| 40 |
+
|
| 41 |
## Model Details
|
| 42 |
|
| 43 |
### Model Description
|
examples/README.md
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Examples
|
| 2 |
+
|
| 3 |
+
## Zero-Shot Classification
|
| 4 |
+
|
| 5 |
+
```sh
|
| 6 |
+
pip install torch # whatever version you want
|
| 7 |
+
pip install open_clip_torch numpy tqdm torchvision
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
Suppose you want to evaluate BioCLIP on zero-shot classification on two tasks, `<DATASET-NAME>` and `<DATASET2-NAME>`.
|
| 11 |
+
You can use `examples/zero_shot.py` to get top1 and top5 accuracy assuming your tasks are arranged as `torchvision`'s [`ImageFolder`](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html) wants.
|
| 12 |
+
|
| 13 |
+
```sh
|
| 14 |
+
python examples/zero_shot.py \
|
| 15 |
+
--datasets <DATASET-NAME>=<DATASET-FOLDER> <DATASET2-NAME>=<DATASET2-FOLDER>
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
This will write to `logs/bioclip-zero-shot/results.json` with your results.
|
examples/zero_shot.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Do zero-shot image classification.
|
| 3 |
+
|
| 4 |
+
Writes the output to a plaintext and JSON format in the logs directory.
|
| 5 |
+
"""
|
| 6 |
+
import argparse
|
| 7 |
+
import ast
|
| 8 |
+
import contextlib
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import random
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import open_clip
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torchvision import datasets
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
|
| 23 |
+
logging.basicConfig(level=logging.INFO, format=log_format)
|
| 24 |
+
logger = logging.getLogger("main")
|
| 25 |
+
|
| 26 |
+
openai_templates = [
|
| 27 |
+
lambda c: f"a bad photo of a {c}.",
|
| 28 |
+
lambda c: f"a photo of many {c}.",
|
| 29 |
+
lambda c: f"a sculpture of a {c}.",
|
| 30 |
+
lambda c: f"a photo of the hard to see {c}.",
|
| 31 |
+
lambda c: f"a low resolution photo of the {c}.",
|
| 32 |
+
lambda c: f"a rendering of a {c}.",
|
| 33 |
+
lambda c: f"graffiti of a {c}.",
|
| 34 |
+
lambda c: f"a bad photo of the {c}.",
|
| 35 |
+
lambda c: f"a cropped photo of the {c}.",
|
| 36 |
+
lambda c: f"a tattoo of a {c}.",
|
| 37 |
+
lambda c: f"the embroidered {c}.",
|
| 38 |
+
lambda c: f"a photo of a hard to see {c}.",
|
| 39 |
+
lambda c: f"a bright photo of a {c}.",
|
| 40 |
+
lambda c: f"a photo of a clean {c}.",
|
| 41 |
+
lambda c: f"a photo of a dirty {c}.",
|
| 42 |
+
lambda c: f"a dark photo of the {c}.",
|
| 43 |
+
lambda c: f"a drawing of a {c}.",
|
| 44 |
+
lambda c: f"a photo of my {c}.",
|
| 45 |
+
lambda c: f"the plastic {c}.",
|
| 46 |
+
lambda c: f"a photo of the cool {c}.",
|
| 47 |
+
lambda c: f"a close-up photo of a {c}.",
|
| 48 |
+
lambda c: f"a black and white photo of the {c}.",
|
| 49 |
+
lambda c: f"a painting of the {c}.",
|
| 50 |
+
lambda c: f"a painting of a {c}.",
|
| 51 |
+
lambda c: f"a pixelated photo of the {c}.",
|
| 52 |
+
lambda c: f"a sculpture of the {c}.",
|
| 53 |
+
lambda c: f"a bright photo of the {c}.",
|
| 54 |
+
lambda c: f"a cropped photo of a {c}.",
|
| 55 |
+
lambda c: f"a plastic {c}.",
|
| 56 |
+
lambda c: f"a photo of the dirty {c}.",
|
| 57 |
+
lambda c: f"a jpeg corrupted photo of a {c}.",
|
| 58 |
+
lambda c: f"a blurry photo of the {c}.",
|
| 59 |
+
lambda c: f"a photo of the {c}.",
|
| 60 |
+
lambda c: f"a good photo of the {c}.",
|
| 61 |
+
lambda c: f"a rendering of the {c}.",
|
| 62 |
+
lambda c: f"a {c} in a video game.",
|
| 63 |
+
lambda c: f"a photo of one {c}.",
|
| 64 |
+
lambda c: f"a doodle of a {c}.",
|
| 65 |
+
lambda c: f"a close-up photo of the {c}.",
|
| 66 |
+
lambda c: f"a photo of a {c}.",
|
| 67 |
+
lambda c: f"the origami {c}.",
|
| 68 |
+
lambda c: f"the {c} in a video game.",
|
| 69 |
+
lambda c: f"a sketch of a {c}.",
|
| 70 |
+
lambda c: f"a doodle of the {c}.",
|
| 71 |
+
lambda c: f"a origami {c}.",
|
| 72 |
+
lambda c: f"a low resolution photo of a {c}.",
|
| 73 |
+
lambda c: f"the toy {c}.",
|
| 74 |
+
lambda c: f"a rendition of the {c}.",
|
| 75 |
+
lambda c: f"a photo of the clean {c}.",
|
| 76 |
+
lambda c: f"a photo of a large {c}.",
|
| 77 |
+
lambda c: f"a rendition of a {c}.",
|
| 78 |
+
lambda c: f"a photo of a nice {c}.",
|
| 79 |
+
lambda c: f"a photo of a weird {c}.",
|
| 80 |
+
lambda c: f"a blurry photo of a {c}.",
|
| 81 |
+
lambda c: f"a cartoon {c}.",
|
| 82 |
+
lambda c: f"art of a {c}.",
|
| 83 |
+
lambda c: f"a sketch of the {c}.",
|
| 84 |
+
lambda c: f"a embroidered {c}.",
|
| 85 |
+
lambda c: f"a pixelated photo of a {c}.",
|
| 86 |
+
lambda c: f"itap of the {c}.",
|
| 87 |
+
lambda c: f"a jpeg corrupted photo of the {c}.",
|
| 88 |
+
lambda c: f"a good photo of a {c}.",
|
| 89 |
+
lambda c: f"a plushie {c}.",
|
| 90 |
+
lambda c: f"a photo of the nice {c}.",
|
| 91 |
+
lambda c: f"a photo of the small {c}.",
|
| 92 |
+
lambda c: f"a photo of the weird {c}.",
|
| 93 |
+
lambda c: f"the cartoon {c}.",
|
| 94 |
+
lambda c: f"art of the {c}.",
|
| 95 |
+
lambda c: f"a drawing of the {c}.",
|
| 96 |
+
lambda c: f"a photo of the large {c}.",
|
| 97 |
+
lambda c: f"a black and white photo of a {c}.",
|
| 98 |
+
lambda c: f"the plushie {c}.",
|
| 99 |
+
lambda c: f"a dark photo of a {c}.",
|
| 100 |
+
lambda c: f"itap of a {c}.",
|
| 101 |
+
lambda c: f"graffiti of the {c}.",
|
| 102 |
+
lambda c: f"a toy {c}.",
|
| 103 |
+
lambda c: f"itap of my {c}.",
|
| 104 |
+
lambda c: f"a photo of a cool {c}.",
|
| 105 |
+
lambda c: f"a photo of a small {c}.",
|
| 106 |
+
lambda c: f"a tattoo of the {c}.",
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def parse_args(args):
|
| 111 |
+
class ParseKwargs(argparse.Action):
|
| 112 |
+
def __call__(self, parser, namespace, values, option_string=None):
|
| 113 |
+
kw = {}
|
| 114 |
+
for value in values:
|
| 115 |
+
key, value = value.split("=")
|
| 116 |
+
try:
|
| 117 |
+
kw[key] = ast.literal_eval(value)
|
| 118 |
+
except (ValueError, SyntaxError):
|
| 119 |
+
# fallback to string (avoid need to escape on command line)
|
| 120 |
+
kw[key] = str(value)
|
| 121 |
+
setattr(namespace, self.dest, kw)
|
| 122 |
+
|
| 123 |
+
parser = argparse.ArgumentParser()
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--datasets",
|
| 126 |
+
type=str,
|
| 127 |
+
default=None,
|
| 128 |
+
nargs="+",
|
| 129 |
+
help="Path to dirs(s) with validation data. In the format NAME=PATH.",
|
| 130 |
+
action=ParseKwargs,
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--logs", type=str, default="./logs", help="Where to write logs"
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"--exp", type=str, default="bioclip-zero-shot", help="Experiment name."
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--workers", type=int, default=8, help="Number of dataloader workers per GPU."
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--batch-size", type=int, default=64, help="Batch size per GPU."
|
| 143 |
+
)
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--precision",
|
| 146 |
+
choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp32"],
|
| 147 |
+
default="amp",
|
| 148 |
+
help="Floating point precision.",
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument("--seed", type=int, default=0, help="Default random seed.")
|
| 151 |
+
args = parser.parse_args(args)
|
| 152 |
+
os.makedirs(os.path.join(args.logs, args.exp), exist_ok=True)
|
| 153 |
+
|
| 154 |
+
return args
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def make_txt_features(model, classnames, templates, args):
|
| 158 |
+
tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
txt_features = []
|
| 161 |
+
for classname in tqdm(classnames):
|
| 162 |
+
classname = " ".join(word for word in classname.split("_") if word)
|
| 163 |
+
texts = [template(classname) for template in templates] # format with class
|
| 164 |
+
texts = tokenizer(texts).to(args.device) # tokenize
|
| 165 |
+
class_embeddings = model.encode_text(texts)
|
| 166 |
+
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
|
| 167 |
+
class_embedding /= class_embedding.norm()
|
| 168 |
+
txt_features.append(class_embedding)
|
| 169 |
+
txt_features = torch.stack(txt_features, dim=1).to(args.device)
|
| 170 |
+
return txt_features
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def accuracy(output, target, topk=(1,)):
|
| 174 |
+
pred = output.topk(max(topk), 1, True, True)[1].t()
|
| 175 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
| 176 |
+
return [correct[:k].reshape(-1).float().sum(0, keepdim=True).item() for k in topk]
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_autocast(precision):
|
| 180 |
+
if precision == "amp":
|
| 181 |
+
return torch.cuda.amp.autocast
|
| 182 |
+
elif precision == "amp_bfloat16" or precision == "amp_bf16":
|
| 183 |
+
# amp_bfloat16 is more stable than amp float16 for clip training
|
| 184 |
+
return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
|
| 185 |
+
else:
|
| 186 |
+
return contextlib.suppress
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def run(model, txt_features, dataloader, args):
|
| 190 |
+
autocast = get_autocast(args.precision)
|
| 191 |
+
cast_dtype = open_clip.get_cast_dtype(args.precision)
|
| 192 |
+
|
| 193 |
+
top1, top5, n = 0.0, 0.0, 0.0
|
| 194 |
+
|
| 195 |
+
with torch.no_grad():
|
| 196 |
+
for images, targets in tqdm(dataloader, unit_scale=args.batch_size):
|
| 197 |
+
images = images.to(args.device)
|
| 198 |
+
if cast_dtype is not None:
|
| 199 |
+
images = images.to(dtype=cast_dtype)
|
| 200 |
+
targets = targets.to(args.device)
|
| 201 |
+
|
| 202 |
+
with autocast():
|
| 203 |
+
image_features = model.encode_image(images)
|
| 204 |
+
image_features = F.normalize(image_features, dim=-1)
|
| 205 |
+
logits = model.logit_scale.exp() * image_features @ txt_features
|
| 206 |
+
|
| 207 |
+
# Measure accuracy
|
| 208 |
+
acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
|
| 209 |
+
top1 += acc1
|
| 210 |
+
top5 += acc5
|
| 211 |
+
n += images.size(0)
|
| 212 |
+
|
| 213 |
+
top1 = top1 / n
|
| 214 |
+
top5 = top5 / n
|
| 215 |
+
return top1, top5
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def evaluate(model, data, args):
|
| 219 |
+
results = {}
|
| 220 |
+
|
| 221 |
+
logger.info("Starting zero-shot classification.")
|
| 222 |
+
|
| 223 |
+
for split in data:
|
| 224 |
+
logger.info("Building zero-shot %s classifier.", split)
|
| 225 |
+
|
| 226 |
+
classnames = data[split].dataset.classes
|
| 227 |
+
classnames = [name.replace("_", " ") for name in classnames]
|
| 228 |
+
|
| 229 |
+
txt_features = make_txt_features(model, classnames, openai_templates, args)
|
| 230 |
+
|
| 231 |
+
logger.info("Got text features.")
|
| 232 |
+
top1, top5 = run(model, txt_features, data[split], args)
|
| 233 |
+
|
| 234 |
+
logger.info("%s-top1: %.3f", split, top1 * 100)
|
| 235 |
+
logger.info("%s-top5: %.3f", split, top5 * 100)
|
| 236 |
+
|
| 237 |
+
results[f"{split}-top1"] = top1 * 100
|
| 238 |
+
results[f"{split}-top5"] = top5 * 100
|
| 239 |
+
|
| 240 |
+
logger.info("Finished zero-shot %s.", split)
|
| 241 |
+
|
| 242 |
+
logger.info("Finished zero-shot classification.")
|
| 243 |
+
|
| 244 |
+
return results
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
args = parse_args(sys.argv[1:])
|
| 249 |
+
|
| 250 |
+
if torch.cuda.is_available():
|
| 251 |
+
# This enables tf32 on Ampere GPUs which is only 8% slower than
|
| 252 |
+
# float16 and almost as accurate as float32
|
| 253 |
+
# This was a default in pytorch until 1.12
|
| 254 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 255 |
+
torch.backends.cudnn.benchmark = True
|
| 256 |
+
torch.backends.cudnn.deterministic = False
|
| 257 |
+
|
| 258 |
+
# Init torch device
|
| 259 |
+
if torch.cuda.is_available():
|
| 260 |
+
device = "cuda:0"
|
| 261 |
+
torch.cuda.set_device(device)
|
| 262 |
+
else:
|
| 263 |
+
device = "cpu"
|
| 264 |
+
args.device = device
|
| 265 |
+
|
| 266 |
+
# Random seeding
|
| 267 |
+
torch.manual_seed(args.seed)
|
| 268 |
+
np.random.seed(args.seed)
|
| 269 |
+
random.seed(args.seed)
|
| 270 |
+
|
| 271 |
+
# Load model.
|
| 272 |
+
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
|
| 273 |
+
"hf-hub:imageomics/bioclip"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
# Write datasets
|
| 277 |
+
params_file = os.path.join(args.logs, args.exp, "params.json")
|
| 278 |
+
with open(params_file, "w") as fd:
|
| 279 |
+
params = {name: getattr(args, name) for name in vars(args)}
|
| 280 |
+
json.dump(params, fd, sort_keys=True, indent=4)
|
| 281 |
+
|
| 282 |
+
# Initialize datasets.
|
| 283 |
+
data = {}
|
| 284 |
+
for split, path in args.datasets.items():
|
| 285 |
+
data[split] = torch.utils.data.DataLoader(
|
| 286 |
+
datasets.ImageFolder(path, transform=preprocess_val),
|
| 287 |
+
batch_size=args.batch_size,
|
| 288 |
+
num_workers=args.workers,
|
| 289 |
+
sampler=None,
|
| 290 |
+
shuffle=False,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
model.eval()
|
| 294 |
+
results = evaluate(model, data, args)
|
| 295 |
+
|
| 296 |
+
results_file = os.path.join(args.logs, args.exp, "results.json")
|
| 297 |
+
with open(results_file, "w") as fd:
|
| 298 |
+
json.dump(results, fd, indent=4, sort_keys=True)
|