Add files using upload-large-folder tool
- .gitattributes +2 -0
- README.md +163 -0
- ViT-L-14-336_register.json +116 -0
- __init__.py +8 -0
- __pycache__/imagenet_classes.cpython-310.pyc +0 -0
- __pycache__/misc.cpython-310.pyc +0 -0
- __pycache__/model.cpython-310.pyc +0 -0
- __pycache__/modified_resnet.cpython-310.pyc +0 -0
- __pycache__/shared.cpython-310.pyc +0 -0
- __pycache__/timm_model.cpython-310.pyc +0 -0
- __pycache__/tokenizer.cpython-310.pyc +0 -0
- __pycache__/transformer.cpython-310.pyc +0 -0
- config.json +29 -0
- config_TTR_bak.json +145 -0
- config_bak.json +110 -0
- constants.py +2 -0
- factory.py +390 -0
- imagenet_classes.py +85 -0
- merges.txt +0 -0
- misc.py +114 -0
- model.py +431 -0
- model_sanity_check.ipynb +285 -0
- modeling_custom_clip.py +357 -0
- modified_resnet.py +181 -0
- neuron_indices.json +1 -0
- openai_models.py +91 -0
- openai_templates.py +84 -0
- preprocess.py +42 -0
- preprocessor_config.json +19 -0
- preprocessor_config_bak.json +16 -0
- pretrained.py +426 -0
- pytorch_model.bin +3 -0
- requirements.txt +5 -0
- shared.py +616 -0
- special_tokens_map.json +1 -0
- timm_model.py +149 -0
- tokenizer.py +214 -0
- tokenizer_config.json +34 -0
- tokenizer_config_bak.json +10 -0
- transform.py +133 -0
- transformer.py +872 -0
- utils.py +34 -0
- utils/utils.py +34 -0
- vitl14_attention.png +3 -0
- vitl14_patchnorms.png +3 -0
- vocab.json +0 -0
- vocab/bpe_simple_vocab_16e6.txt.gz +3 -0
- zeroshot_classifier.pt +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+vitl14_attention.png filter=lfs diff=lfs merge=lfs -text
+vitl14_patchnorms.png filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
---
library_name: transformers
license: mit
pipeline_tag: image-feature-extraction
tags:
- clip
---

# OpenCLIP ViT-L/14 with Test-Time Register

Register tokens in ViTs were introduced as learnable tokens in [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588) to mitigate artifacts in intermediate feature maps. In [Vision Transformers Don't Need *Trained* Registers](https://arxiv.org/abs/2506.08010), we introduced a training-free method to create registers. These *test-time registers* serve a similar purpose to the original trained registers, but can be added post hoc to any ViT to mitigate artifacts, enhance model interpretability, and modestly improve downstream performance on tasks such as segmentation and depth estimation.

## Model description

The base model is [OpenCLIP-ViT-L-14-laion2B-s32B-b82K](https://huggingface.co/laion/CLIP-ViT-L-14-laion2B-s32B-b82K). With test-time registers, the model's internal representations are cleaner (see below). Using the environment from [here](https://github.com/nickjiang2378/test-time-registers/blob/main/environment.yml) and evaluating in bfloat16 gives an ImageNet-1k zero-shot accuracy of 76.4 for both the original model and the variant with test-time registers.
This model is intended to be used with this [repo](https://github.com/nickjiang2378/test-time-registers); use transformers==4.45.1. The model can also be used for fine-tuning or other downstream tasks.

<img src="https://huggingface.co/amildravid4292/clip-vitl14-test-time-registers/resolve/main/vitl14_attention.png" alt="drawing" width="600"/>
<img src="https://huggingface.co/amildravid4292/clip-vitl14-test-time-registers/resolve/main/vitl14_patchnorms.png" alt="drawing" width="600"/>

## Quick Start

```python
from transformers import AutoModel
from PIL import Image
import torch

# Load the complete model with all components
model = AutoModel.from_pretrained(
    "amildravid4292/clip-vitl14-test-time-registers",
    trust_remote_code=True
)

# Check what was loaded
print(f"Register tokens: {model.num_register_tokens}")
print(f"Neuron dict: {model.neuron_dict}")
print(f"Tokenizer available: {model.tokenizer is not None}")
print(f"Preprocessor available: {model.preprocessor is not None}")
print(f"Zero-shot classifier available: {model.zeroshot_classifier is not None}")
```

## Usage Examples

### Image Processing
```python
from PIL import Image

# Load and preprocess image
image = Image.open("your_image.jpg")
image_tensor = model.preprocess_image(image).unsqueeze(0)

# Encode with test-time registers (the saved default configuration)
image_features = model.encode_image(image_tensor)

# To run inference with the original model, without test-time registers
image_features = model.encode_image(
    image_tensor,
    neuron_dict=None,
    num_register_tokens=0
)
```

### Text Processing
```python
# Tokenize text
text = ["a photo of a cat", "a photo of a dog"]
text_tokens = model.tokenize(text)

# Encode text
text_features = model.encode_text(text_tokens)
```
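
### Image-Text Similarity

The image and text features from the two snippets above can be compared directly. This is a minimal sketch that mirrors the scaled cosine similarity used in the pipeline below; the fixed scale of 100 is an assumption rather than the model's learned logit scale.

```python
# Cosine similarity between the image and each prompt, softmaxed over prompts
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
for prompt, score in zip(text, similarity[0].tolist()):
    print(f"{prompt}: {score:.3f}")
```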

### Complete Pipeline
```python
import torch
from torchvision.datasets import ImageNet
from tqdm import tqdm
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# load model
model = AutoModel.from_pretrained('amildravid4292/clip-vitl14-test-time-registers', trust_remote_code=True)
model = model.to(device).bfloat16()
classifier = model.zeroshot_classifier.to(device).bfloat16()

# load data
imagenet_dataset = ImageNet(root='/datasets/ilsvrc/current', split='val', transform=model.preprocessor)
ground_truth_labels = [imagenet_dataset.targets[i] for i in range(len(imagenet_dataset))]
loader = torch.utils.data.DataLoader(imagenet_dataset, batch_size=100, num_workers=4, pin_memory=True, shuffle=False)

# run zero-shot classification
with torch.no_grad():
    correct = [0, 0]
    for i, (images, target) in enumerate(tqdm(loader)):
        images = images.to(device).bfloat16()
        target = target.to(device)

        # predict
        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100. * image_features @ classifier

        pred = logits.argmax(dim=-1)
        correct[0] += (pred == target).sum().item()
        correct[1] += target.size(0)

    print(correct[0] / correct[1])
```

## Advanced Usage

### Custom Neuron Modifications
```python
# Override the saved neuron configuration
custom_neuron_dict = {0: [10, 20, 30]}  # modify neurons 10, 20, 30 in layer 0

image_features = model.encode_image(
    image_tensor,
    num_register_tokens=4,
    neuron_dict=custom_neuron_dict
)
```

### Different Register Token Counts
```python
# Use a different number of register tokens
image_features = model.encode_image(
    image_tensor,
    num_register_tokens=8  # override the default
)
```

## Model Details

- **Base Architecture**: ViT-L/14
- **Training Data**: LAION-2B subset

### BibTeX entry and citation info

```bibtex
@misc{jiang2025visiontransformersdontneed,
      title={Vision Transformers Don't Need Trained Registers},
      author={Nick Jiang and Amil Dravid and Alexei Efros and Yossi Gandelsman},
      year={2025},
      eprint={2506.08010},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2506.08010},
}
```
ViT-L-14-336_register.json
ADDED
{
  "9": [815, 4078, 3618, 2693, 3973, 1744, 1983, 1157, 1309, 1335, 2607, 2396, 3049, 1610, 2621, 2867, 2012, 1924, 2394, 3097, 3125, 3959, 3210, 2855, 3609, 526, 3362, 3395, 2626, 503, 2941, 3696, 1823, 2000, 129, 3667, 1372, 147, 1150, 852, 3222],
  "8": [745, 3249, 2585, 1537, 200, 1603, 1851, 3523, 3697, 3137, 2563, 2293, 730, 906, 1528, 3348, 2438, 1564, 1540, 3238, 3606],
  "10": [357, 1654, 3940, 2319, 2560, 2559, 4009, 3029, 951, 1903, 738, 1602, 1807, 2018, 1281, 267, 3539, 1015, 496, 693, 2278, 7, 856, 2785, 2690, 1367],
  "7": [3228, 2550, 2977, 3716, 2467],
  "0": [2890, 1779, 3761],
  "6": [1042, 2315, 1674],
  "3": [410]
}
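
The keys are vision-transformer layer indices and the values are MLP neuron indices, i.e. the same layer-to-neurons mapping that `neuron_dict` expects in the README's `encode_image` examples. Below is a minimal sketch of loading it; the integer-key conversion and `num_register_tokens=1` follow the Custom Neuron Modifications example and config_TTR_bak.json, and whether these particular indices match the shipped checkpoint is an assumption.

```python
import json

from PIL import Image
from transformers import AutoModel

# Sketch: use this file as the neuron_dict override described in the README.
with open("ViT-L-14-336_register.json") as f:
    neuron_dict = {int(layer): neurons for layer, neurons in json.load(f).items()}

model = AutoModel.from_pretrained(
    "amildravid4292/clip-vitl14-test-time-registers", trust_remote_code=True
)
image_tensor = model.preprocess_image(Image.open("your_image.jpg")).unsqueeze(0)
image_features = model.encode_image(
    image_tensor,
    num_register_tokens=1,  # assumption: one test-time register, as in config_TTR_bak.json
    neuron_dict=neuron_dict,
)
```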
__init__.py
ADDED
from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from factory import list_models, add_model_config, get_model_config, load_checkpoint
from pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from tokenizer import SimpleTokenizer, tokenize, decode
from transform import image_transform, AugmentationCfg
from openai_templates import OPENAI_IMAGENET_TEMPLATES
__pycache__/imagenet_classes.cpython-310.pyc
ADDED
Binary file (21.7 kB)

__pycache__/misc.cpython-310.pyc
ADDED
Binary file (4.05 kB)

__pycache__/model.cpython-310.pyc
ADDED
Binary file (12.1 kB)

__pycache__/modified_resnet.cpython-310.pyc
ADDED
Binary file (6.39 kB)

__pycache__/shared.cpython-310.pyc
ADDED
Binary file (15.9 kB)

__pycache__/timm_model.cpython-310.pyc
ADDED
Binary file (4.04 kB)

__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (8.63 kB)

__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (23.1 kB)
config.json
ADDED
{
  "model_type": "custom_clip_with_registers",

  "processor_class": "CLIPProcessor",
  "tokenizer_class": "CLIPTokenizerFast",

  "architectures": ["CustomCLIPModel"],
  "auto_map": {
    "AutoConfig": "modeling_custom_clip.CustomCLIPConfig",
    "AutoModel": "modeling_custom_clip.CustomCLIPModel"
  },
  "vision_config": {
    "hidden_size": 1024,
    "num_hidden_layers": 24,
    "num_attention_heads": 16,
    "image_size": 336,
    "patch_size": 14
  },
  "text_config": {
    "vocab_size": 49408,
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "max_position_embeddings": 77
  },
  "neuron_dict": {},
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.21.0"
}
config_TTR_bak.json
ADDED
{
  "model_type": "custom_clip_with_registers",

  "processor_class": "CLIPProcessor",
  "tokenizer_class": "CLIPTokenizerFast",

  "architectures": ["CustomCLIPModel"],
  "auto_map": {
    "AutoConfig": "modeling_custom_clip.CustomCLIPConfig",
    "AutoModel": "modeling_custom_clip.CustomCLIPModel"
  },
  "vision_config": {
    "hidden_size": 1024,
    "num_hidden_layers": 24,
    "num_attention_heads": 16,
    "image_size": 336,
    "patch_size": 14
  },
  "text_config": {
    "vocab_size": 49408,
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "max_position_embeddings": 77
  },
  "num_register_tokens": 1,
  "neuron_dict": {
    "9": [815, 4078, 3618, 2693, 3973, 1744, 1983, 1157, 1309, 1335, 2607, 2396, 3049, 1610, 2621, 2867, 2012, 1924, 2394, 3097, 3125, 3959, 3210, 2855, 3609, 526, 3362, 3395, 2626, 503, 2941, 3696, 1823, 2000, 129, 3667, 1372, 147, 1150, 852, 3222],
    "8": [745, 3249, 2585, 1537, 200, 1603, 1851, 3523, 3697, 3137, 2563, 2293, 730, 906, 1528, 3348, 2438, 1564, 1540, 3238, 3606],
    "10": [357, 1654, 3940, 2319, 2560, 2559, 4009, 3029, 951, 1903, 738, 1602, 1807, 2018, 1281, 267, 3539, 1015, 496, 693, 2278, 7, 856, 2785, 2690, 1367],
    "7": [3228, 2550, 2977, 3716, 2467],
    "0": [2890, 1779, 3761],
    "6": [1042, 2315, 1674],
    "3": [410]
  },
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.21.0"
}
config_bak.json
ADDED
{
  "model_type": "custom_clip_with_registers",
  "architectures": ["CustomCLIPModel"],
  "auto_map": {
    "AutoConfig": "modeling_custom_clip.CustomCLIPConfig",
    "AutoModel": "modeling_custom_clip.CustomCLIPModel"
  },
  "vision_config": {
    "hidden_size": 1024,
    "num_hidden_layers": 24,
    "num_attention_heads": 16,
    "image_size": 336,
    "patch_size": 14
  },
  "text_config": {
    "vocab_size": 49408,
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "max_position_embeddings": 77
  },
  "num_register_tokens": 1,
  "neuron_dict": {
    "10": [2924, 2520, 2936, 675, 517, 1610, 88, 1950, 3098, 4082, 1237, 857, 3020, 1321, 1128, 3561, 4091, 69, 3378, 2304, 977, 1762, 3598, 371, 1097],
    "9": [1253, 3658, 1827, 2600, 4000, 711, 2726, 615, 2654, 831, 1, 1387, 2178, 1967, 2413, 901, 481, 1514, 292, 692, 3094, 3470, 932, 2129],
    "8": [3189, 1491, 2159, 1196, 1913, 1340, 2515, 2163, 955, 1496, 1891, 1410, 3725, 632, 188, 726, 1592, 1017, 1267, 995, 3465, 3510, 1494, 3467, 1896, 2779, 2309, 3389, 3682, 1968, 2904],
    "7": [2226, 2565],
    "6": [1450, 1551, 1024],
    "5": [151, 1282],
    "4": [2207],
    "3": [2298, 2841]
  },
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.21.0"
}
constants.py
ADDED
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
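
These are the RGB normalization statistics used for OpenAI-style CLIP preprocessing. A rough torchvision sketch of how they are typically applied follows; transform.py's `image_transform` is the authoritative pipeline, and the 336-pixel resize/crop here simply follows the `image_size` in config.json.

```python
from torchvision import transforms

from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

# Sketch of an OpenCLIP-style eval preprocess built from these constants.
preprocess = transforms.Compose([
    transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(336),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])
```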
factory.py
ADDED
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import torch

from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from model import CLIP, convert_to_custom_text_state_dict,\
    resize_pos_embed, get_cast_dtype
from openai_models import load_openai_model
from pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
    list_pretrained_tags_by_model, download_pretrained_from_hf
from transform import image_transform, AugmentationCfg
from tokenizer import HFTokenizer, tokenize


HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs


def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        with open(cf, 'r') as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}


_rescan_model_configs()  # initial populate of model config registry


def list_models():
    """ enumerate available model architectures based on config files """
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """ add model config path or file and update registry """
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()


def get_model_config(model_name):
    if model_name in _MODEL_CONFIGS:
        return deepcopy(_MODEL_CONFIGS[model_name])
    else:
        return None


def get_tokenizer(model_name):
    if model_name.startswith(HF_HUB_PREFIX):
        tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
    else:
        config = get_model_config(model_name)
        tokenizer = HFTokenizer(
            config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
    return tokenizer


def load_state_dict(checkpoint_path: str, map_location='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if next(iter(state_dict.items()))[0].startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict


def load_checkpoint(model, checkpoint_path, strict=False):
    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)
    resize_pos_embed(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)

    model.num_register_tokens = state_dict["num_register_tokens"]
    model.neuron_dict = state_dict["neuron_dict"]
    model.visual.num_register_tokens = state_dict["num_register_tokens"]
    model.visual.neuron_dict = state_dict["neuron_dict"]

    return incompatible_keys


def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
):
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        pretrained_cfg = config['preprocess_cfg']
        model_cfg = config['model_cfg']
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        pretrained_cfg = {}
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
            quick_gelu=force_quick_gelu,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)

        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        if custom_text:
            if is_hf_model:
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            if "coca" in model_name:
                raise ValueError('Coca is not implemented')
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                raise ValueError('CustomTextCLIP is not implemented')
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                model.to(device=device, dtype=dtype)
                from transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)
                model.apply(_convert_ln)
            else:
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        else:
            model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    # set image / mean metadata from pretrained_cfg if available, or use default
    model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
    model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    return model


def create_loss(args):
    if args.distill:
        return DistillClipLoss(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    return ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )


def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
):
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
    )

    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess_train = image_transform(
        model.visual.image_size,
        is_train=True,
        mean=image_mean,
        std=image_std,
        aug_cfg=aug_cfg,
    )
    preprocess_val = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std,
    )

    return model, preprocess_train, preprocess_val


def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
):
    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
        require_pretrained=True,
    )

    if not return_transform:
        return model

    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std,
    )

    return model, preprocess
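
For orientation, here is a minimal sketch of how this factory API is typically driven. It assumes an open_clip-style model_configs/ directory containing an architecture JSON for the model name used (such a directory is not part of this upload), and uses the repo's pytorch_model.bin as a placeholder checkpoint path; adjust both to your setup.

```python
import torch
from PIL import Image

from factory import create_model_and_transforms, get_tokenizer, list_models

print(list_models())  # architecture names discovered under model_configs/

# Assumed config name and checkpoint location; pick a name reported by list_models().
model, _, preprocess_val = create_model_and_transforms(
    "ViT-L-14-336",
    pretrained="pytorch_model.bin",
)
tokenizer = get_tokenizer("ViT-L-14-336")

image = preprocess_val(Image.open("your_image.jpg")).unsqueeze(0)
text = tokenizer(["a photo of a cat", "a photo of a dog"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
```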
imagenet_classes.py
ADDED
OPENAI_IMAGENET_TEMPLATES = (
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
)

IMAGENET_CLASSNAMES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English 
Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", 
"bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", 
"pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
misc.py
ADDED
@@ -0,0 +1,114 @@
from itertools import repeat
import collections.abc

import torch
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d


def freeze_batch_norm_2d(module, module_match={}, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict): Dictionary of full module names to freeze (all if empty)
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    is_match = True
    if module_match:
        is_match = name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for child_name, child in module.named_children():
            full_child_name = '.'.join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)


# Replaces all linear layers with linear_replacement
# TODO: add int8 support for other linear layers including attn and convnets
def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, include_modules, copy_weights)

        if isinstance(module, torch.nn.Linear) and name in include_modules:
            old_module = model._modules[name]
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                model._modules[name].weight.data.copy_(old_module.weight.data)
                if model._modules[name].bias is not None:
                    model._modules[name].bias.data.copy_(old_module.bias)

    return model


def convert_int8_model_to_inference_mode(model):
    for m in model.modules():
        if hasattr(m, 'prepare_for_eval'):
            int8_original_dtype = m.weight.dtype
            m.prepare_for_eval()
            m.int8_original_dtype = int8_original_dtype


def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy

    output: torch.Tensor
        shape (N, C) where N is the number of examples, C the number of classes.
        these are the logits.

    target: torch.Tensor
        shape (N,) where N is the number of examples. Groundtruth class id of each example.

    topk: tuple
        which topk to compute, e.g., topk=(1,5) will compute top-1 and top-5 accuracies

    Returns
    -------

    list of top-k accuracies in the same order as `topk`
    """
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    n = len(target)
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk]

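The helpers above are standard OpenCLIP-style utilities. Below is a minimal usage sketch (not part of the repo) showing how `freeze_batch_norm_2d`, `to_2tuple`, and `accuracy` might be exercised; the tiny torchvision backbone and random logits are illustrative assumptions.

```python
# Illustrative sketch only: exercising the misc.py helpers.
import torch
from torchvision.models import resnet18
from misc import freeze_batch_norm_2d, accuracy, to_2tuple

# Freeze every BatchNorm2d in a small backbone so its statistics stay fixed.
backbone = freeze_batch_norm_2d(resnet18())

# to_2tuple turns a scalar image size into an (H, W) pair.
assert to_2tuple(336) == (336, 336)

# accuracy() expects logits of shape (N, C) and integer targets of shape (N,),
# and returns fractions in [0, 1], one per requested k.
logits = torch.randn(8, 1000)
targets = torch.randint(0, 1000, (8,))
top1, top5 = accuracy(logits, targets, topk=(1, 5))
print(f"top-1 {top1:.2%}, top-5 {top5:.2%}")
```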
model.py
ADDED
@@ -0,0 +1,431 @@
""" CLIP Model

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union, Text

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint


from modified_resnet import ModifiedResNet
from timm_model import TimmModel
from transformer import LayerNorm, QuickGELU, VisionTransformer, TextTransformer, Attention
from misc import to_2tuple


@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    input_patchnorm: bool = False  # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer
    n_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    output_tokens: bool = False

    timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(l, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            attr = getattr(l, "text_projection", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

        if isinstance(l, VisionTransformer):
            # convert vision nn.Parameter projections
            attr = getattr(l, "proj", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat


@dataclass
class CLIPTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value
    hf_model_name: str = None
    hf_tokenizer_name: str = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def get_input_dtype(precision: str):
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNorm
        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            input_patchnorm=vision_cfg.input_patchnorm,
            global_average_pool=vision_cfg.global_average_pool,
            attentional_pool=vision_cfg.attentional_pool,
            n_queries=vision_cfg.n_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        from hf_model import HFTextEncoder
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNorm

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            output_tokens=text_cfg.output_tokens,
            pad_id=text_cfg.pad_id,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
    return text


class CLIP(nn.Module):
    """
    _VITL14_336 = dict(
        openai=_pcfg(
            url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
            hf_hub="timm/vit_large_patch14_clip_336.openai/",
            quick_gelu=True,
        ),
    )

    """
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        print(f"Building vision tower with config: {vision_cfg}")

        print(f"Currently text tower is removed, using only image encoder for feature extraction")
        do_use = False
        if do_use:
            text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
            self.transformer = text.transformer
            self.context_length = text.context_length
            self.vocab_size = text.vocab_size
            self.token_embedding = text.token_embedding
            self.positional_embedding = text.positional_embedding
            self.ln_final = text.ln_final
            self.text_projection = text.text_projection
            self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.num_register_tokens = 0
        self.neuron_dict = None

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False, attn_method: Text = 'direct', num_register_tokens=None, neuron_dict=None):
        if num_register_tokens is None and neuron_dict is None:
            num_register_tokens = self.num_register_tokens
            neuron_dict = self.neuron_dict

        features = self.visual(image, attn_method=attn_method, num_register_tokens=num_register_tokens, neuron_dict=neuron_dict)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        # x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        # x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            num_register_tokens=None,
            neuron_dict=None
    ):

        if num_register_tokens is None and neuron_dict is None:
            num_register_tokens = self.num_register_tokens
            neuron_dict = self.neuron_dict

        image_features = self.encode_image(image, num_register_tokens=num_register_tokens, neuron_dict=neuron_dict, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()


# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    if 'text_projection' in state_dict:
        # old format state_dict, move text tower -> .text
        new_state_dict = {}
        for k, v in state_dict.items():
            if any(k.startswith(p) for p in (
                'text_projection',
                'positional_embedding',
                'token_embedding',
                'transformer',
                'ln_final',
            )):
                k = 'text.' + k
            new_state_dict[k] = v
        return new_state_dict
    return state_dict


def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed

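The `CLIP` class above only builds the vision tower and routes `num_register_tokens` / `neuron_dict` down to it. Below is a rough sketch of driving it directly with test-time registers; the config values, the number of register tokens, and the assumption that `neuron_indices.json` is already in the format the vision tower expects are illustrative, not guaranteed by this file.

```python
# Sketch (not part of the repo): vision-only CLIP with test-time registers.
import json
import torch
from model import CLIP, CLIPVisionCfg, CLIPTextCfg

vision_cfg = CLIPVisionCfg(layers=24, width=1024, patch_size=14, image_size=336)
text_cfg = CLIPTextCfg()  # unused: the text tower is skipped in __init__
clip = CLIP(embed_dim=768, vision_cfg=vision_cfg, text_cfg=text_cfg, quick_gelu=True).eval()

# Test-time register configuration: which neurons get redirected into extra register tokens.
# Format assumed to match what the custom VisionTransformer expects.
with open("neuron_indices.json") as f:
    neuron_dict = json.load(f)
clip.neuron_dict = neuron_dict
clip.num_register_tokens = 4  # assumed value for illustration

with torch.no_grad():
    images = torch.randn(1, 3, 336, 336)
    feats = clip.encode_image(images, normalize=True)  # falls back to the attributes set above
print(feats.shape)  # expected: (1, 768)
```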
model_sanity_check.ipynb
ADDED
@@ -0,0 +1,285 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ba945813",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e7cec94e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os, json, math, torch, tqdm\n",
    "from pathlib import Path\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import ImageFolder\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "import os\n",
    "import itertools\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import transformers\n",
    "from transformers import AutoModel, AutoProcessor, CLIPForImageClassification, AutoConfig, AutoTokenizer\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import ImageNet\n",
    "from torch.utils.data import Subset\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "import matplotlib.ticker as mticker\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - enables 3D plotting\n",
    "import inspect\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms.functional as VF\n",
    "import tqdm\n",
    "\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "from torchvision import transforms\n",
    "from torchvision.transforms import InterpolationMode\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import yaml\n",
    "from pathlib import Path\n",
    "\n",
    "import sys\n",
    "import os\n",
    "\n",
    "from imagenet_classes import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c7a750",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pretrained path from config: /workspace/code/clipL336_TTR\n",
      "✓ Added '/workspace/code/clipL336_TTR' to Python path.\n",
      "✓ Successfully imported 'model' from '/workspace/code/clipL336_TTR'\n",
      "Building vision tower with config: CLIPVisionCfg(layers=24, width=1024, head_width=64, mlp_ratio=4.0, patch_size=14, image_size=336, ls_init_value=None, patch_dropout=0.0, input_patchnorm=False, global_average_pool=False, attentional_pool=False, n_queries=256, attn_pooler_heads=8, output_tokens=False, timm_model_name=None, timm_model_pretrained=False, timm_pool='avg', timm_proj='linear', timm_proj_bias=False, timm_drop=0.0, timm_drop_path=None)\n",
      "✓ Added '/workspace/data/cache/huggingface/modules/transformers_modules/clipL336_TTR' to Python path.\n",
      "✓ Successfully imported 'tokenizer' from '/workspace/data/cache/huggingface/modules/transformers_modules/clipL336_TTR'\n",
      "Custom CLIP model loaded successfully!\n"
     ]
    }
   ],
   "source": [
    "# The root cause turned out to be a broken text encoder..\n",
    "device = \"cuda:7\"\n",
    "model_path = \"/workspace/code/clipL336_TTR\"\n",
    "\n",
    "exp_cfg = AutoConfig.from_pretrained(\"/workspace/code/clipL336_TTR\", trust_remote_code=True)\n",
    "exp_cfg.pretrained_path = model_path \n",
    "# model = AutoModel.from_pretrained(model_path, trust_remote_code=True, local_files_only=True)\n",
    "model = AutoModel.from_pretrained(pretrained_model_name_or_path=model_path, config=exp_cfg, trust_remote_code=True, local_files_only=True)\n",
    "# need to verify that the weights were actually loaded here\n",
    "model = model.to(device)\n",
    "preprocessor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, local_files_only=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, local_files_only=True)\n",
    "# fetch the tokenizer and preprocessor\n",
    "\n",
    "clip_transform = lambda image: preprocessor.image_processor(image, return_tensors=\"pt\")['pixel_values'].squeeze(0)  # hadn't thought of using the image processor this way\n",
    "model_clip = AutoModel.from_pretrained(\"openai/clip-vit-large-patch14-336\").to(device).half()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ed3cbfdc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:23<00:00, 41.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Built text features: torch.Size([768, 1000])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# language head\n",
    "### zeroshot head construction (text encoding) ###\n",
    "with torch.no_grad():\n",
    "    zeroshot_weight = []\n",
    "    for classname in tqdm(IMAGENET_CLASSNAMES):\n",
    "        texts = [template(classname) for template in OPENAI_IMAGENET_TEMPLATES]\n",
    "        text_inputs = preprocessor(text=texts, return_tensors=\"pt\", padding=\"max_length\").to(device)\n",
    "        # text_inputs = model.tokenize(texts).to(device)\n",
    "        # text_features = model.encode_text(text_inputs.input_ids)\n",
    "        text_features = model_clip.get_text_features(**text_inputs)\n",
    "        text_feature = F.normalize(text_features, dim=-1).mean(dim=0)\n",
    "        # text_feature = text_features.mean(dim=0)\n",
    "        text_feature = text_feature / text_feature.norm()\n",
    "        zeroshot_weight.append(text_feature)\n",
    "        \n",
    "    text_features = torch.stack(zeroshot_weight, dim=1).to(device)\n",
    "print(\"Built text features:\", text_features.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e1bd37d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(text_features, \"./zeroshot_classifier.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dbfeaedf",
   "metadata": {},
   "outputs": [],
   "source": [
    "imagenet_dataset = ImageNet(root='/workspace/data/imagenet', split='val', transform=clip_transform)\n",
    "eval_loader = torch.utils.data.DataLoader(imagenet_dataset, batch_size=128, num_workers=16, pin_memory=False, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b0000195",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "half = \"torch.bfloat16\"\n",
    "def evaluate(model, loader, text_feats, max_samples: int | None = None):\n",
    "    model.eval()\n",
    "    top1 = top5 = n = 0\n",
    "    pbar = tqdm(loader, desc=\"Evaluating\", unit=\"batch\")\n",
    "    with torch.no_grad():\n",
    "        for images, labels in pbar:\n",
    "            if max_samples and n >= max_samples:\n",
    "                break\n",
    "            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)\n",
    "            with torch.autocast(device_type=\"cuda\"):\n",
    "                # the test-time processing function still needs to be implemented here\n",
    "                feats = model.encode_image(images)\n",
    "\n",
    "            feats = feats / feats.norm(dim=-1, keepdim=True)\n",
    "            logits = model.model.logit_scale.exp() * feats @ text_feats \n",
    "            _, pred = logits.topk(5, dim=-1)\n",
    "            top1 += (pred[:, :1] == labels.unsqueeze(1)).sum().item()\n",
    "            top5 += (pred == labels.unsqueeze(1)).sum().item()\n",
    "            n += images.size(0)\n",
    "            pbar.set_postfix(samples=n, top1=top1 / n * 100, top5=top5 / n * 100)\n",
    "    return top1 / n * 100, top5 / n * 100\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8795b394",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating:   0%|          | 0/391 [00:00<?, ?batch/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|██████████| 391/391 [10:38<00:00,  1.63s/batch, samples=5e+4, top1=74.9, top5=94.4] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baseline (Top‑1 / Top‑5) on 50,000 imgs: 74.87% / 94.37%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "### baseline evaluator ###\n",
    "### this is not usable as-is right now... far too slow, where is the bottleneck? ###\n",
    "# what is going wrong this time\n",
    "# the architecture definition appears to have been corrupted somewhere\n",
    "# reproduce the reported performance...\n",
    "\n",
    "BASELINE_SAMPLES = 50000  # set to None for full 50 k\n",
    "acc1, acc5 = evaluate(model, eval_loader, text_features, max_samples=BASELINE_SAMPLES)\n",
    "print(f\"Baseline (Top‑1 / Top‑5) on {BASELINE_SAMPLES or len(imagenet_dataset):,} imgs: {acc1:.2f}% / {acc5:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aa82bb4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

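The sanity-check notebook builds the ImageNet zero-shot head once and saves it to `zeroshot_classifier.pt` (shape `[768, 1000]`, one L2-normalized text embedding per class). A minimal sketch of reusing that saved head at inference time is shown below; it assumes `model`, `preprocessor`, and a PIL `image` exist as in the notebook session above, so it is a continuation rather than a standalone script.

```python
# Sketch: zero-shot prediction with the pre-built text classifier (continues the notebook session).
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
text_feats = torch.load("zeroshot_classifier.pt", map_location=device)  # (768, 1000)

# `model` / `preprocessor` as loaded above; `image` is any PIL image (assumed to exist).
pixel_values = preprocessor.image_processor(image, return_tensors="pt")["pixel_values"].to(device)
with torch.no_grad():
    feats = model.encode_image(pixel_values)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    logits = model.model.logit_scale.exp() * feats @ text_feats.to(feats.dtype)
pred = logits.argmax(dim=-1).item()  # index into IMAGENET_CLASSNAMES
```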
modeling_custom_clip.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom CLIP Model with Register Tokens - Import Safe Version with Complete File Download
|
| 3 |
+
"""
|
| 4 |
+
import transformers
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 8 |
+
from transformers.utils import logging
|
| 9 |
+
from typing import Optional, Union, Tuple
|
| 10 |
+
import json
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import warnings
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import importlib.util
|
| 16 |
+
|
| 17 |
+
# Suppress all warnings during import
|
| 18 |
+
warnings.filterwarnings("ignore")
|
| 19 |
+
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def safe_import_from_repo(module_name: str, repo_path: str):
|
| 25 |
+
"""
|
| 26 |
+
지정된 로컬 경로(repo_path)에서 파이썬 모듈을 안전하게 임포트합니다.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
module_name (str): 임포트할 모듈의 이름 (예: 'modeling_clip').
|
| 30 |
+
repo_path (str): 모듈이 포함된 디렉토리의 경로.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
The imported module object.
|
| 34 |
+
|
| 35 |
+
Raises:
|
| 36 |
+
ValueError: repo_path가 None이거나 유효한 디렉토리가 아닐 경우.
|
| 37 |
+
ImportError: 지정된 경로에서 모듈을 찾을 수 없을 경우.
|
| 38 |
+
"""
|
| 39 |
+
# 1. repo_path가 유효한지 검사합니다.
|
| 40 |
+
if repo_path is None:
|
| 41 |
+
raise ValueError("The 'repo_path' argument cannot be None.")
|
| 42 |
+
|
| 43 |
+
# pathlib.Path 객체로 변환하여 경로를 쉽게 다룰 수 있도록 합니다.
|
| 44 |
+
repo_path_obj = Path(repo_path)
|
| 45 |
+
|
| 46 |
+
if not repo_path_obj.is_dir():
|
| 47 |
+
raise ValueError(
|
| 48 |
+
f"The specified repo_path does not exist or is not a directory: '{repo_path}'")
|
| 49 |
+
|
| 50 |
+
# 2. 파이썬이 모듈을 찾을 수 있도록 해당 경로를 sys.path에 추가합니다.
|
| 51 |
+
# resolve()를 통해 절대 경로를 사용하고, 문자열로 변환합니다.
|
| 52 |
+
absolute_repo_path = str(repo_path_obj.resolve())
|
| 53 |
+
|
| 54 |
+
if absolute_repo_path not in sys.path:
|
| 55 |
+
# sys.path의 맨 앞에 추가하여 다른 경로보다 우선적으로 탐색되도록 합니다.
|
| 56 |
+
sys.path.insert(0, absolute_repo_path)
|
| 57 |
+
print(f"✓ Added '{absolute_repo_path}' to Python path.")
|
| 58 |
+
|
| 59 |
+
# 3. `importlib`을 사용하여 모듈을 동적으로 임포트합니다.
|
| 60 |
+
try:
|
| 61 |
+
module = importlib.import_module(module_name)
|
| 62 |
+
print(
|
| 63 |
+
f"✓ Successfully imported '{module_name}' from '{absolute_repo_path}'")
|
| 64 |
+
return module
|
| 65 |
+
except ImportError:
|
| 66 |
+
# sys.path에 경로를 추가했음에도 임포트에 실패한 경우,
|
| 67 |
+
# 해당 경로에 모듈 파일(.py)이 없다는 의미입니다.
|
| 68 |
+
raise ImportError(
|
| 69 |
+
f"Module '{module_name}' not found inside the specified path: '{absolute_repo_path}'")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class CustomCLIPConfig(PretrainedConfig):
|
| 73 |
+
model_type = "custom_clip_with_registers"
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
vision_config=None,
|
| 78 |
+
text_config=None,
|
| 79 |
+
num_register_tokens=0,
|
| 80 |
+
neuron_dict=None,
|
| 81 |
+
projection_dim=512,
|
| 82 |
+
logit_scale_init_value=2.6592,
|
| 83 |
+
**kwargs
|
| 84 |
+
):
|
| 85 |
+
super().__init__(**kwargs)
|
| 86 |
+
|
| 87 |
+
self.vision_config = vision_config or {}
|
| 88 |
+
self.text_config = text_config or {}
|
| 89 |
+
self.num_register_tokens = num_register_tokens
|
| 90 |
+
self.neuron_dict = neuron_dict
|
| 91 |
+
self.projection_dim = projection_dim
|
| 92 |
+
self.logit_scale_init_value = logit_scale_init_value
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class CustomCLIPModel(PreTrainedModel):
|
| 96 |
+
config_class = CustomCLIPConfig
|
| 97 |
+
|
| 98 |
+
def __init__(self, config):
|
| 99 |
+
super().__init__(config)
|
| 100 |
+
|
| 101 |
+
# Safe import of custom modules
|
| 102 |
+
try:
|
| 103 |
+
# to strictly load from the local library
|
| 104 |
+
pretrained_path: str | None = getattr(
|
| 105 |
+
config, "pretrained_path", None)
|
| 106 |
+
if pretrained_path is None:
|
| 107 |
+
raise ValueError(
|
| 108 |
+
"The config must have a 'pretrained_path' attribute pointing to the local repository path.")
|
| 109 |
+
else:
|
| 110 |
+
print(f"Pretrained path from config: {pretrained_path}")
|
| 111 |
+
|
| 112 |
+
model_module = safe_import_from_repo('model', pretrained_path)
|
| 113 |
+
self.CLIP = model_module.CLIP
|
| 114 |
+
self.CLIPVisionCfg = model_module.CLIPVisionCfg
|
| 115 |
+
self.CLIPTextCfg = model_module.CLIPTextCfg
|
| 116 |
+
except ImportError as e:
|
| 117 |
+
raise ImportError(
|
| 118 |
+
f"Could not import model components: {e}. Make sure all model files are in the repository.")
|
| 119 |
+
|
| 120 |
+
# Create vision and text configs
|
| 121 |
+
vision_cfg = self.CLIPVisionCfg(
|
| 122 |
+
layers=config.vision_config.get("num_hidden_layers", 12),
|
| 123 |
+
width=config.vision_config.get("hidden_size", 768),
|
| 124 |
+
patch_size=config.vision_config.get("patch_size", 16),
|
| 125 |
+
image_size=config.vision_config.get("image_size", 224),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
text_cfg = self.CLIPTextCfg(
|
| 129 |
+
context_length=config.text_config.get(
|
| 130 |
+
"max_position_embeddings", 77),
|
| 131 |
+
vocab_size=config.text_config.get("vocab_size", 49408),
|
| 132 |
+
width=config.text_config.get("hidden_size", 512),
|
| 133 |
+
layers=config.text_config.get("num_hidden_layers", 12),
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Initialize your custom CLIP model
|
| 137 |
+
self.model = self.CLIP(
|
| 138 |
+
embed_dim=config.projection_dim,
|
| 139 |
+
vision_cfg=vision_cfg,
|
| 140 |
+
text_cfg=text_cfg,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# These will be set when loading the state dict
|
| 144 |
+
# 여기 statedict에서 load하면 않된다. configuration에 떡하니 있으면서 무슨 짓거리냐
|
| 145 |
+
self.neuron_dict = config.neuron_dict
|
| 146 |
+
if self.neuron_dict is None:
|
| 147 |
+
raise ValueError("neuron_dict must be provided in the config.")
|
| 148 |
+
self.num_register_tokens = config.num_register_tokens
|
| 149 |
+
|
| 150 |
+
# These will be loaded separately
|
| 151 |
+
self._tokenizer = None
|
| 152 |
+
self._preprocessor = None
|
| 153 |
+
self._zeroshot_classifier = None
|
| 154 |
+
|
| 155 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
| 156 |
+
"""Override to handle custom parameters and load weights properly"""
|
| 157 |
+
|
| 158 |
+
# Extract custom parameters first
|
| 159 |
+
if 'neuron_dict' in state_dict:
|
| 160 |
+
self.neuron_dict = state_dict.pop('neuron_dict')
|
| 161 |
+
|
| 162 |
+
if 'num_register_tokens' in state_dict:
|
| 163 |
+
self.num_register_tokens = state_dict.pop('num_register_tokens')
|
| 164 |
+
|
| 165 |
+
# Set these values in the model
|
| 166 |
+
if hasattr(self.model, 'visual'):
|
| 167 |
+
self.model.visual.num_register_tokens = self.num_register_tokens
|
| 168 |
+
self.model.visual.neuron_dict = self.neuron_dict
|
| 169 |
+
self.model.num_register_tokens = self.num_register_tokens
|
| 170 |
+
self.model.neuron_dict = self.neuron_dict
|
| 171 |
+
|
| 172 |
+
# Load the weights properly - suppress ALL warnings and errors
|
| 173 |
+
with warnings.catch_warnings():
|
| 174 |
+
warnings.simplefilter("ignore")
|
| 175 |
+
|
| 176 |
+
# Temporarily set logging to critical only
|
| 177 |
+
original_level = logging.get_verbosity()
|
| 178 |
+
logging.set_verbosity_error()
|
| 179 |
+
|
| 180 |
+
try:
|
| 181 |
+
# Load weights directly into self.model
|
| 182 |
+
missing, unexpected = self.model.load_state_dict(
|
| 183 |
+
state_dict, strict=False)
|
| 184 |
+
|
| 185 |
+
# Don't report any missing/unexpected keys to avoid warnings
|
| 186 |
+
|
| 187 |
+
except Exception as e:
|
| 188 |
+
# If direct loading fails, try the parent method silently
|
| 189 |
+
super()._load_from_state_dict(state_dict, prefix, local_metadata, False, [], [], [])
|
| 190 |
+
finally:
|
| 191 |
+
# Restore logging level
|
| 192 |
+
logging.set_verbosity(original_level)
|
| 193 |
+
|
| 194 |
+
@classmethod
|
| 195 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 196 |
+
"""Override to load cleanly and suppress warnings"""
|
| 197 |
+
|
| 198 |
+
# Suppress warnings during loading
|
| 199 |
+
with warnings.catch_warnings():
|
| 200 |
+
warnings.simplefilter("ignore")
|
| 201 |
+
|
| 202 |
+
# Temporarily suppress transformers logging
|
| 203 |
+
original_level = logging.get_verbosity()
|
| 204 |
+
logging.set_verbosity_error()
|
| 205 |
+
|
| 206 |
+
try:
|
| 207 |
+
# Load the model
|
| 208 |
+
model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
| 209 |
+
finally:
|
| 210 |
+
# Restore logging
|
| 211 |
+
logging.set_verbosity(original_level)
|
| 212 |
+
|
| 213 |
+
# Load additional components
|
| 214 |
+
model._load_additional_components(pretrained_model_name_or_path)
|
| 215 |
+
|
| 216 |
+
# Print clean success message
|
| 217 |
+
print("Custom CLIP model loaded successfully!")
|
| 218 |
+
|
| 219 |
+
return model
|
| 220 |
+
|
| 221 |
+
def _load_additional_components(self, pretrained_model_name_or_path):
|
| 222 |
+
"""Load tokenizer, preprocessor, and zero-shot classifier silently"""
|
| 223 |
+
|
| 224 |
+
try:
|
| 225 |
+
from huggingface_hub import hf_hub_download
|
| 226 |
+
|
| 227 |
+
# Load tokenizer
|
| 228 |
+
try:
|
| 229 |
+
# Safe import of tokenizer
|
| 230 |
+
tokenizer_module = safe_import_from_repo(
|
| 231 |
+
'tokenizer', Path(__file__).parent)
|
| 232 |
+
self._tokenizer = tokenizer_module.SimpleTokenizer()
|
| 233 |
+
except ImportError:
|
| 234 |
+
# If tokenizer import fails, create a dummy tokenizer message
|
| 235 |
+
pass
|
| 236 |
+
|
| 237 |
+
# Load preprocessor
|
| 238 |
+
try:
|
| 239 |
+
preprocess_config_file = hf_hub_download(
|
| 240 |
+
repo_id=pretrained_model_name_or_path,
|
| 241 |
+
filename="preprocessor_config.json"
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
with open(preprocess_config_file, 'r') as f:
|
| 245 |
+
preprocess_config = json.load(f)
|
| 246 |
+
|
| 247 |
+
self._create_preprocessor(preprocess_config)
|
| 248 |
+
except:
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
# Load zero-shot classifier
|
| 252 |
+
try:
|
| 253 |
+
                classifier_file = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="zeroshot_classifier.pt"
                )

                # Suppress the torch.load warning
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self._zeroshot_classifier = torch.load(
                        classifier_file, map_location='cpu', weights_only=False)
            except:
                pass

        except:
            pass

    def _create_preprocessor(self, config):
        """Create image preprocessor from config"""
        try:
            from torchvision import transforms

            self._preprocessor = transforms.Compose([
                transforms.Resize(
                    config["image_size"], interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(config["image_size"]),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=config["image_mean"], std=config["image_std"]),
            ])
        except:
            pass

    @property
    def tokenizer(self):
        """Access the tokenizer"""
        return self._tokenizer

    @property
    def preprocessor(self):
        """Access the image preprocessor"""
        return self._preprocessor

    @property
    def zeroshot_classifier(self):
        """Access the zero-shot classifier"""
        return self._zeroshot_classifier

    def tokenize(self, texts, context_length=77):
        """Tokenize text using the loaded tokenizer"""
        if self._tokenizer is None:
            raise ValueError(
                "Tokenizer not available. Make sure tokenizer.py is in the repository.")

        # Safe import of tokenize function
        try:
            tokenizer_module = safe_import_from_repo(
                'tokenizer', Path(__file__).parent)
            return tokenizer_module.tokenize(texts, context_length)
        except ImportError:
            raise ValueError("Could not import tokenize function.")

    def preprocess_image(self, image):
        """Preprocess image using the loaded preprocessor"""
        if self._preprocessor is None:
            raise ValueError(
                "Preprocessor not loaded. Make sure preprocessor_config.json is in the repository.")

        return self._preprocessor(image)

    def forward(self, input_ids=None, pixel_values=None, num_register_tokens=None, neuron_dict=None, **kwargs):
        """Forward pass supporting your custom functionality"""

        if num_register_tokens is None:
            num_register_tokens = self.num_register_tokens
        if neuron_dict is None:
            neuron_dict = self.neuron_dict

        return self.model(
            image=pixel_values,
            text=input_ids,
            num_register_tokens=num_register_tokens,
            neuron_dict=neuron_dict
        )

    def encode_image(self, pixel_values, num_register_tokens=None, neuron_dict=None, **kwargs):
        """Encode images with register token support"""
        if num_register_tokens is None:
            num_register_tokens = self.num_register_tokens
        if neuron_dict is None:
            neuron_dict = self.neuron_dict

        return self.model.encode_image(
            pixel_values,
            num_register_tokens=num_register_tokens,
            neuron_dict=neuron_dict,
            **kwargs
        )

    def encode_text(self, input_ids, **kwargs):
        """Encode text"""
        return self.model.encode_text(input_ids, **kwargs)


# Auto-suppress warnings at module level
transformers.logging.set_verbosity_error()
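For reference, a minimal usage sketch of the wrapper above. The loading path (`AutoModel` with `trust_remote_code=True`) and the repo-id placeholder are assumptions, not confirmed by this file; adjust to the actual exports of modeling_custom_clip.py.

```python
# Hedged sketch: "<this-repo-id>" is a placeholder, and loading via AutoModel/trust_remote_code is assumed.
import torch
from PIL import Image
from transformers import AutoModel

model = AutoModel.from_pretrained("<this-repo-id>", trust_remote_code=True)

image = model.preprocess_image(Image.open("example.jpg")).unsqueeze(0)
tokens = model.tokenize(["a photo of a cat", "a photo of a dog"])

with torch.no_grad():
    # Defaults to the model's configured num_register_tokens / neuron_dict
    image_features = model.encode_image(image)
    text_features = model.encode_text(tokens)

image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
print((image_features @ text_features.T).softmax(dim=-1))
```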
modified_resnet.py
ADDED
@@ -0,0 +1,181 @@
from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F

from misc import freeze_batch_norm_2d


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.act3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.act3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def init_parameters(self):
        if self.attnpool is not None:
            std = self.attnpool.c_proj.in_features ** -0.5
            nn.init.normal_(self.attnpool.q_proj.weight, std=std)
            nn.init.normal_(self.attnpool.k_proj.weight, std=std)
            nn.init.normal_(self.attnpool.v_proj.weight, std=std)
            nn.init.normal_(self.attnpool.c_proj.weight, std=std)

        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for name, param in resnet_block.named_parameters():
                if name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # FIXME support for non-transformer
        pass

    def stem(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
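A small sanity-check sketch for ModifiedResNet. The (3, 4, 6, 3) layer stack, 32 attention-pool heads, and 1024-d output mirror the RN50-style CLIP configuration; they are illustrative values, not settings taken from this repository.

```python
import torch
from modified_resnet import ModifiedResNet

# RN50-style configuration (illustrative): width 64 -> 2048-d features, pooled to 1024-d output
visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
visual.eval()

with torch.no_grad():
    feats = visual(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1024])
```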
neuron_indices.json
ADDED
@@ -0,0 +1 @@
[[12, 42, 39.99140930175781], [12, 983, 34.50058364868164], [12, 3868, 23.993741989135742], [12, 2687, 23.192779541015625], [11, 3784, 14.847213745117188], [11, 987, 14.675474166870117], [11, 3661, 14.301347732543945], [12, 3008, 12.25265884399414], [11, 1967, 11.993508338928223], [12, 3002, 10.681584358215332], [11, 9, 9.61478042602539], [21, 1801, 8.448626518249512], [11, 2555, 6.903197288513184], [11, 1100, 6.859874725341797], [12, 1571, 4.70828104019165], [22, 901, 3.453416109085083], [21, 1550, 3.4134912490844727], [12, 1816, 3.37734055519104], [12, 183, 3.1418349742889404], [8, 745, 3.1221530437469482], [9, 4078, 3.0656824111938477], [9, 815, 3.0607407093048096], [10, 357, 2.7818374633789062], [9, 3618, 2.690423011779785], [10, 1654, 2.6796107292175293], [22, 2184, 2.6561291217803955], [10, 3940, 2.652881383895874], [7, 3228, 2.46209979057312], [10, 2319, 2.308473825454712], [9, 2693, 2.1979129314422607], [21, 1779, 2.1429498195648193], [20, 3077, 2.1137425899505615], [20, 2634, 2.04282808303833], [9, 3973, 2.031193733215332], [21, 3137, 2.026745080947876], [8, 3249, 1.9856672286987305], [8, 2585, 1.9620095491409302], [9, 1983, 1.9459211826324463], [9, 1744, 1.9378128051757812], [9, 1157, 1.749971866607666], [21, 2412, 1.7358660697937012], [10, 2560, 1.6931447982788086], [7, 2550, 1.6547895669937134], [21, 1381, 1.5941085815429688], [22, 1317, 1.560852289199829], [8, 1537, 1.5494486093521118], [8, 200, 1.4573794603347778], [19, 1881, 1.4518368244171143], [8, 1603, 1.416003704071045], [8, 1851, 1.3301061391830444], [8, 3523, 1.321004867553711], [12, 2780, 1.2789242267608643], [13, 1109, 1.2571412324905396], [10, 2559, 1.2549676895141602], [9, 1309, 1.238487958908081], [21, 2193, 1.2044764757156372], [17, 1868, 1.1777989864349365], [21, 1796, 1.1429805755615234], [10, 4009, 1.0898690223693848], [9, 1335, 1.0648274421691895], [22, 2889, 1.0604228973388672], [11, 888, 1.0271515846252441], [15, 415, 1.0207806825637817], [21, 68, 1.0149273872375488], [9, 3049, 0.9941853880882263], [9, 2607, 0.9631124138832092], [9, 2621, 0.954177737236023], [18, 1283, 0.9397207498550415], [9, 2396, 0.9153910875320435], [22, 797, 0.8976885676383972], [12, 2370, 0.8916781544685364], [22, 3026, 0.8911128044128418], [10, 3029, 0.864679217338562], [19, 2881, 0.8607441782951355], [9, 1610, 0.8600460886955261], [22, 3143, 0.8545671105384827], [19, 1149, 0.8446468114852905], [22, 806, 0.8359670042991638], [20, 676, 0.8346958756446838], [18, 3018, 0.8332728147506714], [22, 2714, 0.8295024037361145], [9, 2867, 0.813927948474884], [22, 3888, 0.8113453388214111], [8, 3697, 0.8057175278663635], [22, 1832, 0.7937105894088745], [22, 985, 0.7906701564788818], [22, 3361, 0.783061683177948], [9, 2394, 0.7818043231964111], [22, 3049, 0.7765958309173584], [8, 3137, 0.772114098072052], [10, 951, 0.7676676511764526], [11, 3568, 0.7665989398956299], [8, 2563, 0.7626394629478455], [23, 1137, 0.7513239979743958], [17, 604, 0.7489021420478821], [9, 1924, 0.7423470616340637], [19, 2106, 0.7369621992111206], [9, 2012, 0.7241123914718628], [10, 1903, 0.7238287329673767], [12, 3574, 0.7192695736885071]]
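The file above is a flat JSON list; each entry appears to be a `[layer, neuron_index, score]` triple (that interpretation is an assumption, not stated in the file). A short sketch for loading and grouping it:

```python
import json
from collections import defaultdict

with open("neuron_indices.json") as f:
    entries = json.load(f)  # assumed meaning: [layer, neuron_index, score] per entry

neurons_by_layer = defaultdict(list)
for layer, neuron_idx, score in entries:
    neurons_by_layer[layer].append((neuron_idx, score))

# How many listed neurons fall in each layer
print({layer: len(v) for layer, v in sorted(neurons_by_layer.items())})
```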
openai_models.py
ADDED
@@ -0,0 +1,91 @@
""" OpenAI pretrained model functions

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""

import os
import warnings
from typing import List, Optional, Union

import torch

from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from model import build_model_from_openai_state_dict, get_cast_dtype
from pretrained import *

__all__ = ["list_openai_models", "load_openai_model"]


def list_openai_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list_pretrained_models_by_tag('openai')


def load_openai_model(
        name: str,
        precision: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        cache_dir: Optional[str] = None,
        quick_gelu: Optional[bool] = True
):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    precision: str
        Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
    device : Union[str, torch.device]
        The device to put the loaded model
    cache_dir : Optional[str]
        The directory to cache the downloaded model weights

    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if precision is None:
        precision = 'fp32' if device == 'cpu' else 'fp16'

    if get_pretrained_url(name, 'openai'):
        model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        state_dict = torch.load(model_path, map_location="cpu")

    # Build a non-jit model from the OpenAI jitted model state dict
    cast_dtype = get_cast_dtype(precision)
    try:
        model = build_model_from_openai_state_dict(state_dict or model.state_dict(), quick_gelu=quick_gelu, cast_dtype=cast_dtype)
    except KeyError:
        sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
        model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype, quick_gelu=quick_gelu)

    # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
    model = model.to(device)
    # FIXME support pure fp16/bf16 precision modes
    if precision != 'fp16':
        model.float()
        if precision == 'bf16':
            # for bf16, convert back to low-precision
            convert_weights_to_lp(model, dtype=torch.bfloat16)

    # add mean / std attributes for consistency with OpenCLIP models
    model.visual.image_mean = OPENAI_DATASET_MEAN
    model.visual.image_std = OPENAI_DATASET_STD
    return model
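A minimal sketch of how `load_openai_model` might be called; the model name and device below are illustrative, and the call will download the OpenAI checkpoint on first use.

```python
import torch
from openai_models import list_openai_models, load_openai_model

print(list_openai_models())  # model names carrying the 'openai' pretrain tag

model = load_openai_model("ViT-L-14-336", precision="fp32", device="cpu")
model.eval()

with torch.no_grad():
    feats = model.encode_image(torch.randn(1, 3, 336, 336))
print(feats.shape)
```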
openai_templates.py
ADDED
@@ -0,0 +1,84 @@

OPENAI_IMAGENET_TEMPLATES = (
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
)
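These prompt templates are typically averaged per class to build a zero-shot classifier. A sketch of that standard CLIP recipe follows; the `model` and `tokenize` handles are assumed to come from elsewhere in this repository (e.g. the factory and tokenizer modules).

```python
import torch
from openai_templates import OPENAI_IMAGENET_TEMPLATES


def build_zeroshot_classifier(model, tokenize, classnames, device="cpu"):
    """Average the text embeddings of all templates for each class name (standard CLIP recipe)."""
    weights = []
    with torch.no_grad():
        for name in classnames:
            texts = tokenize([template(name) for template in OPENAI_IMAGENET_TEMPLATES]).to(device)
            embs = model.encode_text(texts)
            embs = embs / embs.norm(dim=-1, keepdim=True)
            mean_emb = embs.mean(dim=0)
            weights.append(mean_emb / mean_emb.norm())
    return torch.stack(weights, dim=1)  # (embed_dim, num_classes)
```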
preprocess.py
ADDED
@@ -0,0 +1,42 @@
## Imports
import numpy as np
import torch
from PIL import Image
import os.path
import argparse
from pathlib import Path
import cv2
from torch.nn import functional as F
from torch.utils.data import DataLoader
import tqdm
import einops
import plotly.express as px
import torch.nn.functional as F
import tqdm
import json
import albumentations
import glob
from torchvision import transforms


def _convert_to_rgb(image):
    return image.convert('RGB')

def _resize(image):
    image = np.array(image)
    image = albumentations.augmentations.geometric.resize.LongestMaxSize(interpolation=Image.BICUBIC,
                                                                         max_size=224)(image=image)
    return Image.fromarray(image['image'])

preprocess = transforms.Compose([
    _resize,
    transforms.CenterCrop(size=(224, 224)),
    _convert_to_rgb,
])


both_preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])
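A short sketch of how the two transforms above compose; the image path is illustrative.

```python
from PIL import Image
from preprocess import preprocess, both_preprocess

img = Image.open("example.jpg")
img_resized = preprocess(img)          # PIL image: longest side -> 224, center crop to 224x224, RGB
tensor = both_preprocess(img_resized)  # normalized CHW tensor with CLIP mean/std
print(tensor.shape)                    # torch.Size([3, 224, 224])
```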
preprocessor_config.json
ADDED
@@ -0,0 +1,19 @@
{
  "crop_size": 336,
  "do_center_crop": true,
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "size": 336
}
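Since this config follows the standard CLIP image-processor schema, it should load with the transformers image processor; a sketch (loading from a local checkout of this repo is assumed):

```python
from transformers import CLIPImageProcessor

# "." assumes the current directory contains preprocessor_config.json; a hub repo id also works
processor = CLIPImageProcessor.from_pretrained(".")
print(processor.size, processor.crop_size, processor.image_mean)
```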
preprocessor_config_bak.json
ADDED
@@ -0,0 +1,16 @@
{
  "image_size": 224,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "interpolation": "bicubic",
  "resize_mode": "center_crop"
}
pretrained.py
ADDED
@@ -0,0 +1,426 @@
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union

from tqdm import tqdm


try:
    from huggingface_hub import hf_hub_download
    hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version='2.20.0')
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False


def _pcfg(url='', hf_hub='', mean=None, std=None):
    return dict(
        url=url,
        hf_hub=hf_hub,
        mean=mean,
        std=std,
    )


_RN50 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
    cc12m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)

_RN50_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
    cc12m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)

_RN101 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)

_RN101_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)

_RN50x4 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
)

_RN50x16 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
)

_RN50x64 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
)

_VITB32 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
    laion2b_e16=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
    # DataComp-M models
    datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),
    commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),
    commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),
    commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),
    commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),
    commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),
    commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),
    # DataComp-S models
    datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),
    commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),
    commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),
    commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),
    commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
    commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
    commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
)

_VITB32_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
)

_VITB16 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
    # DataComp-L models
    datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),
    commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),
    commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),
    commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),
    commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),
    commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
    commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
)

_VITB16_PLUS_240 = dict(
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)

_VITL14 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
    laion2b_s32b_b82k=_pcfg(
        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    # DataComp-XL models
    datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),
    commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
    commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
    commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
)

_VITL14_336 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)

_VITH14 = dict(
    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)

_VITg14 = dict(
    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)

_VITbigG14 = dict(
    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)

_robertaViTB32 = dict(
    laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)

_xlmRobertaBaseViTB32 = dict(
    laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)

_xlmRobertaLargeFrozenViTH14 = dict(
    frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)

_convnext_base = dict(
    laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)

_convnext_base_w = dict(
    laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
    laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)

_convnext_base_w_320 = dict(
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
    laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)

_convnext_large_d = dict(
    laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)

_convnext_large_d_320 = dict(
    laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
    laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)

_convnext_xxlarge = dict(
    laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
    laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
    laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)

_coca_VITB32 = dict(
    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)

_coca_VITL14 = dict(
    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)


_PRETRAINED = {
    "RN50": _RN50,
    "RN50-quickgelu": _RN50_quickgelu,
    "RN101": _RN101,
    "RN101-quickgelu": _RN101_quickgelu,
    "RN50x4": _RN50x4,
    "RN50x16": _RN50x16,
    "RN50x64": _RN50x64,
    "ViT-B-32": _VITB32,
    "ViT-B-32-quickgelu": _VITB32_quickgelu,
    "ViT-B-16": _VITB16,
    "ViT-B-16-plus-240": _VITB16_PLUS_240,
    "ViT-L-14": _VITL14,
    "ViT-L-14-336": _VITL14_336,
    "ViT-H-14": _VITH14,
    "ViT-g-14": _VITg14,
    "ViT-bigG-14": _VITbigG14,
    "roberta-ViT-B-32": _robertaViTB32,
    "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
    "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
    "convnext_base": _convnext_base,
    "convnext_base_w": _convnext_base_w,
    "convnext_base_w_320": _convnext_base_w_320,
    "convnext_large_d": _convnext_large_d,
    "convnext_large_d_320": _convnext_large_d_320,
    "convnext_xxlarge": _convnext_xxlarge,
    "coca_ViT-B-32": _coca_VITB32,
    "coca_ViT-L-14": _coca_VITL14,
    "EVA01-g-14": dict(
        # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt
        laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),
    ),
    "EVA01-g-14-plus": dict(
        # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt
        merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),
    ),
    "EVA02-B-16": dict(
        # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt
        merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),
    ),
    "EVA02-L-14": dict(
        # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt
        merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),
    ),
    "EVA02-L-14-336": dict(
        # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt
        merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),
    ),
    "EVA02-E-14": dict(
        # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt
        laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),
    ),
    "EVA02-E-14-plus": dict(
        # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt
        laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),
    )
}


def _clean_tag(tag: str):
    # normalize pretrained tags
    return tag.lower().replace('-', '_')


def list_pretrained(as_str: bool = False):
    """ returns list of pretrained models
    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
    """
    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]


def list_pretrained_models_by_tag(tag: str):
    """ return all models having the specified pretrain tag """
    models = []
    tag = _clean_tag(tag)
    for k in _PRETRAINED.keys():
        if tag in _PRETRAINED[k]:
            models.append(k)
    return models


def list_pretrained_tags_by_model(model: str):
    """ return all pretrain tags for the specified model architecture """
    tags = []
    if model in _PRETRAINED:
        tags.extend(_PRETRAINED[model].keys())
    return tags


def is_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return False
    return _clean_tag(tag) in _PRETRAINED[model]


def get_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return {}
    model_pretrained = _PRETRAINED[model]
    return model_pretrained.get(_clean_tag(tag), {})


def get_pretrained_url(model: str, tag: str):
    cfg = get_pretrained_cfg(model, _clean_tag(tag))
    return cfg.get('url', '')


def download_pretrained_from_url(
        url: str,
        cache_dir: Union[str, None] = None,
):
    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if expected_sha256:
            if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
                return download_target
            else:
                warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
        else:
            return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

    return download_target


def has_hf_hub(necessary=False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def download_pretrained_from_hf(
        model_id: str,
        filename: str = 'open_clip_pytorch_model.bin',
        revision=None,
        cache_dir: Union[str, None] = None,
):
    has_hf_hub(True)
    cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
    return cached_file


def download_pretrained(
        cfg: Dict,
        force_hf_hub: bool = False,
        cache_dir: Union[str, None] = None,
):
    target = ''
    if not cfg:
        return target

    download_url = cfg.get('url', '')
    download_hf_hub = cfg.get('hf_hub', '')
    if download_hf_hub and force_hf_hub:
        # use HF hub even if url exists
        download_url = ''

    if download_url:
        target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
    elif download_hf_hub:
        has_hf_hub(True)
        # we assume the hf_hub entries in pretrained config combine model_id + filename in
        # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
        model_id, filename = os.path.split(download_hf_hub)
        if filename:
            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
        else:
            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)

    return target
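A quick sketch of the lookup and download helpers defined above; model name and tag are illustrative.

```python
from pretrained import (
    list_pretrained_tags_by_model,
    get_pretrained_cfg,
    download_pretrained,
)

print(list_pretrained_tags_by_model("ViT-L-14-336"))  # e.g. ['openai']

cfg = get_pretrained_cfg("ViT-L-14-336", "openai")
checkpoint_path = download_pretrained(cfg)            # downloads to ~/.cache/clip by default
print(checkpoint_path)
```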
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c7a812f61be88b4148e6c910ea245178ff3663263d54680cdb99dd6bcaed9b32
size 1711950230
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch>=1.9.0
transformers>=4.21.0
torchvision>=0.10.0
Pillow
numpy
shared.py
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import re
|
| 4 |
+
import random
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import json
|
| 7 |
+
def get_gpu_memory_usage():
|
| 8 |
+
"""Returns a list of GPU memory usage in MB."""
|
| 9 |
+
try:
|
| 10 |
+
# Run nvidia-smi command and capture the output
|
| 11 |
+
result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
|
| 12 |
+
stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
| 13 |
+
|
| 14 |
+
# Check if the command was successful
|
| 15 |
+
if result.returncode != 0:
|
| 16 |
+
raise RuntimeError(f"nvidia-smi command failed with error: {result.stderr}")
|
| 17 |
+
|
| 18 |
+
# Parse the output to get memory usage
|
| 19 |
+
memory_usages = [int(x) for x in result.stdout.strip().split('\n')]
|
| 20 |
+
return memory_usages
|
| 21 |
+
except Exception as e:
|
| 22 |
+
print(f"Error querying GPU memory usage: {e}")
|
| 23 |
+
return []
|
| 24 |
+
|
| 25 |
+
def set_cuda_visible_device():
|
| 26 |
+
"""Sets the CUDA_VISIBLE_DEVICES environment variable to the GPU with the smallest memory usage."""
|
| 27 |
+
memory_usages = get_gpu_memory_usage()
|
| 28 |
+
|
| 29 |
+
if not memory_usages:
|
| 30 |
+
print("No GPU memory usage data available.")
|
| 31 |
+
return
|
| 32 |
+
|
| 33 |
+
# Find the index of the GPU with the smallest memory usage
|
| 34 |
+
min_memory_index = memory_usages.index(min(memory_usages))
|
| 35 |
+
|
| 36 |
+
# Set the CUDA_VISIBLE_DEVICES environment variable
|
| 37 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(min_memory_index)
|
| 38 |
+
print(f"Set CUDA_VISIBLE_DEVICES to GPU {min_memory_index} with {memory_usages[min_memory_index]} MB used.")
|
| 39 |
+
|
| 40 |
+
return str(min_memory_index)
|
| 41 |
+
|
| 42 |
+
os.environ["ASN_ROOT_DIR"] = "/home/nickj/asn/second_order_lens"
|
| 43 |
+
os.chdir(os.environ["ASN_ROOT_DIR"])
|
| 44 |
+
|
| 45 |
+
import numpy as np
|
| 46 |
+
import torch
|
| 47 |
+
from PIL import Image
|
| 48 |
+
import os.path
|
| 49 |
+
import argparse
|
| 50 |
+
from pathlib import Path
|
| 51 |
+
|
| 52 |
+
from tqdm import tqdm
|
| 53 |
+
from utils.factory import create_model_and_transforms, get_tokenizer
|
| 54 |
+
from PIL import Image, ImageDraw
|
| 55 |
+
|
| 56 |
+
def get_model(model_name = "ViT-B/16", pretrained = "openai", device = "cuda:0"):
|
| 57 |
+
torch.multiprocessing.set_sharing_strategy("file_system")
|
| 58 |
+
model, _, preprocess = create_model_and_transforms(
|
| 59 |
+
model_name, pretrained=pretrained, force_quick_gelu=True,
|
| 60 |
+
)
|
| 61 |
+
model.to(device)
|
| 62 |
+
model.eval()
|
| 63 |
+
context_length = model.context_length
|
| 64 |
+
vocab_size = model.vocab_size
|
| 65 |
+
|
| 66 |
+
return {
|
| 67 |
+
"model": model,
|
| 68 |
+
"model_name": model_name,
|
| 69 |
+
"pretrained": pretrained,
|
| 70 |
+
"preprocess": preprocess,
|
| 71 |
+
"context_length": context_length,
|
| 72 |
+
"vocab_size": vocab_size
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
img_path = "/datasets/ilsvrc_2024-01-04_1913/val/n04398044/ILSVRC2012_val_00042447.JPEG"
|
| 76 |
+
# img_path = "./sample.jpeg"
|
| 77 |
+
def load_images(preprocess, image_folder = "/datasets/ilsvrc/current/val", count = 100, images_only = True):
|
| 78 |
+
file_list = []
|
| 79 |
+
|
| 80 |
+
for root, dirs, files in os.walk(image_folder):
|
| 81 |
+
for file in files:
|
| 82 |
+
file_list.append(os.path.join(root, file))
|
| 83 |
+
|
| 84 |
+
if count > len(file_list):
|
| 85 |
+
sampled_files = file_list
|
| 86 |
+
else:
|
| 87 |
+
sampled_files = random.sample(file_list, count)
|
| 88 |
+
|
| 89 |
+
image_files = []
|
| 90 |
+
|
| 91 |
+
for filename in sampled_files:
|
| 92 |
+
image_files.append(preprocess(Image.open(filename)))
|
| 93 |
+
if images_only:
|
| 94 |
+
return image_files
|
| 95 |
+
else:
|
| 96 |
+
return image_files, sampled_files
|
| 97 |
+
|
| 98 |
+
def calc_neuron_potentials(model, attn_layers = (1, 2), include_layernorm = True):
|
| 99 |
+
# Calculates the attention-shifting potential scores for every neuron to the attention heads defined by the given layers (relative to the MLP layer)
|
| 100 |
+
|
| 101 |
+
embed_dim = model.visual.transformer.resblocks[0].attn.out_proj.in_features
|
| 102 |
+
num_heads = model.visual.transformer.resblocks[0].attn.num_heads
|
| 103 |
+
head_dim = embed_dim // num_heads
|
| 104 |
+
layers = len(model.visual.transformer.resblocks)
|
| 105 |
+
|
| 106 |
+
results = dict()
|
| 107 |
+
|
| 108 |
+
for neuron_layer in tqdm(range(layers), desc = "Calculating attention shifting potentials"):
|
| 109 |
+
neuron_projection = model.visual.transformer.resblocks[neuron_layer].state_dict()["mlp.c_proj.weight"]
|
| 110 |
+
for l_attn in range(min(layers, neuron_layer + attn_layers[0]), min(layers, neuron_layer + attn_layers[1])):
|
| 111 |
+
ln_vector = model.visual.transformer.resblocks[l_attn].ln_1.state_dict()["weight"]
|
| 112 |
+
attn_matrix = model.visual.transformer.resblocks[l_attn].state_dict()["attn.in_proj_weight"]
|
| 113 |
+
W_Q, W_K, W_V = (attn_matrix[:embed_dim].reshape(num_heads, head_dim, -1),
|
| 114 |
+
attn_matrix[embed_dim:2*embed_dim].reshape(num_heads, head_dim, -1),
|
| 115 |
+
attn_matrix[2*embed_dim:].reshape(num_heads, head_dim, -1))
|
| 116 |
+
|
| 117 |
+
for head_idx in range(num_heads):
|
| 118 |
+
W_Q_h, W_K_h = W_Q[head_idx], W_K[head_idx]
|
| 119 |
+
effects = []
|
| 120 |
+
for i in range(neuron_projection.shape[1]):
|
| 121 |
+
if include_layernorm:
|
| 122 |
+
neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ (neuron_projection[:, i] * ln_vector))
|
| 123 |
+
else:
|
| 124 |
+
neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ neuron_projection[:, i])
|
| 125 |
+
effects.append(neuron_attn_effect)
|
| 126 |
+
|
| 127 |
+
results[(neuron_layer, l_attn, head_idx)] = torch.tensor(effects)
|
| 128 |
+
return results
|
| 129 |
+
|
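Restated as a formula (this is just the computation in the loop above, written out): for a neuron $n$ in layer $\ell$, a later layer $\ell'$ with `ln_1` weight $\gamma_{\ell'}$, and head $h$ with query/key projections $W_Q^{(h)}, W_K^{(h)}$, the score is

```latex
\text{potential}(n,\,\ell',\,h) \;=\; \bigl\lVert\, W_Q^{(h)\top} W_K^{(h)} \bigl( \gamma_{\ell'} \odot w_n \bigr) \bigr\rVert_2
```

where $w_n$ is the $n$-th column of the layer-$\ell$ MLP output projection (`mlp.c_proj.weight`); dropping the $\gamma_{\ell'}$ factor corresponds to `include_layernorm=False`.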
| 130 |
+
def calc_top_asns(shift_potentials, top_k = 10, per = "layer", layers_away = 1):
|
| 131 |
+
num_layers = max([key[1] for key in shift_potentials.keys()]) + 1 # the last layer has no ASNs by definition
|
| 132 |
+
num_heads = max([key[2] for key in shift_potentials.keys()])
|
| 133 |
+
|
| 134 |
+
top_asns = []
|
| 135 |
+
for layer in range(num_layers - layers_away):
|
| 136 |
+
if per == "layer":
|
| 137 |
+
potentials = []
|
| 138 |
+
for head_idx in range(num_heads):
|
| 139 |
+
potentials.append(shift_potentials[(layer, layer + layers_away, head_idx)])
|
| 140 |
+
potentials = torch.max(torch.stack(potentials, dim = 0), dim = 0).values
|
| 141 |
+
_, sorted_indices = torch.sort(potentials, descending = True)
|
| 142 |
+
top_asns.append(sorted_indices[:top_k].tolist())
|
| 143 |
+
elif per == "head":
|
| 144 |
+
top_layer_asns = []
|
| 145 |
+
for head_idx in range(num_heads):
|
| 146 |
+
_, sorted_indices = torch.sort(shift_potentials[(layer, layer + layers_away, head_idx)], descending = True)
|
| 147 |
+
top_layer_asns.append(sorted_indices[:top_k].tolist())
|
| 148 |
+
top_asns.append(top_layer_asns)
|
| 149 |
+
else:
|
| 150 |
+
raise ValueError(f"Invalid per value: {per}")
|
| 151 |
+
return top_asns
|
| 152 |
+
|
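An illustrative call chain for the two functions above (the parameter values shown are the defaults, not prescriptive):

```python
# Sketch: rank neurons by their attention-shifting potential, one list per layer.
potentials = calc_neuron_potentials(model, attn_layers=(1, 2), include_layernorm=True)
top_per_layer = calc_top_asns(potentials, top_k=10, per="layer", layers_away=1)
print(top_per_layer[0])  # indices of the 10 highest-potential neurons in layer 0
```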
| 153 |
+
def aggregate_attn_map(attn_map, layer, head):
|
| 154 |
+
num_tokens = attn_map.shape[-1]
|
| 155 |
+
assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square"
|
| 156 |
+
|
| 157 |
+
num_patches = int((num_tokens - 1) ** 0.5)
|
| 158 |
+
aggregate_scores = torch.sum(attn_map[:, layer, head, 1:, 1:], dim = 1).reshape((1, num_patches, num_patches))
|
| 159 |
+
return aggregate_scores
|
| 160 |
+
|
| 161 |
+
def attn_map_cls_token(attn_map, layer, head):
|
| 162 |
+
# Gets the attention map for the CLS token
|
| 163 |
+
num_tokens = attn_map.shape[-1]
|
| 164 |
+
assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square"
|
| 165 |
+
|
| 166 |
+
num_patches = int((num_tokens - 1) ** 0.5)
|
| 167 |
+
attn_map_reshaped = attn_map[:, layer, head, 0, 1:].reshape((1, num_patches, num_patches))
|
| 168 |
+
return attn_map_reshaped
|
| 169 |
+
|
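For reference, a self-contained shape check for the helper above, using a random tensor with 197 tokens (196 patches plus CLS, as in a ViT-B/16 at 224px):

```python
# Illustrative only: (batch, layers, heads, tokens, tokens) attention tensor.
fake_attn = torch.rand(1, 12, 12, 197, 197)
cls_grid = attn_map_cls_token(fake_attn, layer=11, head=3)
print(cls_grid.shape)  # torch.Size([1, 14, 14])
```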
| 170 |
+
def visualize_attn_shift(attn_map1, attn_map2, image, display=True, out=None, min_diff=None, max_diff=None):
|
| 171 |
+
import matplotlib.pyplot as plt
|
| 172 |
+
import numpy as np
|
| 173 |
+
|
| 174 |
+
# Subtract attn_map1 from attn_map2
|
| 175 |
+
diff_map = attn_map2 - attn_map1
|
| 176 |
+
|
| 177 |
+
# Convert the image to RGBA
|
| 178 |
+
image = image.convert("RGBA")
|
| 179 |
+
overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
|
| 180 |
+
draw = ImageDraw.Draw(overlay)
|
| 181 |
+
|
| 182 |
+
# Calculate the size of each attention block
|
| 183 |
+
block_size_x = image.size[0] / diff_map.shape[0]
|
| 184 |
+
block_size_y = image.size[1] / diff_map.shape[1]
|
| 185 |
+
|
| 186 |
+
# Create a colormap
|
| 187 |
+
cmap = plt.get_cmap('coolwarm_r')  # reversed 'coolwarm' diverging colormap
|
| 188 |
+
|
| 189 |
+
# Get the min and max values for scaling the colormap
|
| 190 |
+
if max_diff is None:
|
| 191 |
+
max_diff = diff_map.max()
|
| 192 |
+
if min_diff is None:
|
| 193 |
+
min_diff = diff_map.min()
|
| 194 |
+
|
| 195 |
+
for i in range(diff_map.shape[0]):
|
| 196 |
+
for j in range(diff_map.shape[1]):
|
| 197 |
+
# Get the color from the colormap
|
| 198 |
+
intensity = diff_map[i, j]
|
| 199 |
+
normalized_intensity = (intensity - min_diff) / (max_diff - min_diff) # Scale to [0, 1]
|
| 200 |
+
rgba_color = cmap(1 - normalized_intensity) # Invert the normalized intensity
|
| 201 |
+
color = tuple(int(c * 255) for c in rgba_color[:3]) + (int(rgba_color[3] * 128),)
|
| 202 |
+
|
| 203 |
+
# Draw the rectangle on the overlay with transparency
|
| 204 |
+
draw.rectangle(
|
| 205 |
+
[j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
|
| 206 |
+
fill=color  # color already includes the alpha channel
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Composite the overlay with the original image
|
| 210 |
+
combined = Image.alpha_composite(image, overlay)
|
| 211 |
+
|
| 212 |
+
if display:
|
| 213 |
+
# Display the result
|
| 214 |
+
combined.show()
|
| 215 |
+
|
| 216 |
+
# Show the color scale
|
| 217 |
+
plt.figure(figsize=(6, 1))
|
| 218 |
+
plt.imshow([np.linspace(min_diff, max_diff, 256)], cmap='coolwarm_r', aspect='auto')
|
| 219 |
+
plt.gca().set_visible(False)
|
| 220 |
+
plt.colorbar(orientation="horizontal")
|
| 221 |
+
plt.show()
|
| 222 |
+
|
| 223 |
+
if out is not None:
|
| 224 |
+
combined.save(out)
|
| 225 |
+
|
| 226 |
+
return combined
|
| 227 |
+
|
| 228 |
+
def visualize_attn_shift_binary(attn_map1, attn_map2, image, display=True, out=None):
|
| 229 |
+
# Creates a visualization of the attention shift where green = positive, red = negative values.
|
| 230 |
+
# This is useful when outliers in the difference map would otherwise compress the values around 0 into a single color
|
| 231 |
+
# Subtract attn_map1 from attn_map2
|
| 232 |
+
diff_map = attn_map2 - attn_map1
|
| 233 |
+
|
| 234 |
+
# Normalize the difference map to range [0, 1] for visualization
|
| 235 |
+
diff_map_normalized = (diff_map - diff_map.min()) / (diff_map.max() - diff_map.min())
|
| 236 |
+
# Convert the image to RGBA
|
| 237 |
+
image = image.convert("RGBA")
|
| 238 |
+
overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
|
| 239 |
+
draw = ImageDraw.Draw(overlay)
|
| 240 |
+
|
| 241 |
+
# Calculate the size of each attention block
|
| 242 |
+
block_size_x = image.size[0] / diff_map.shape[0]
|
| 243 |
+
block_size_y = image.size[1] / diff_map.shape[1]
|
| 244 |
+
|
| 245 |
+
for i in range(diff_map.shape[0]):
|
| 246 |
+
for j in range(diff_map.shape[1]):
|
| 247 |
+
# Calculate the color intensity based on the difference
|
| 248 |
+
intensity = diff_map_normalized[i, j]
|
| 249 |
+
alpha = int(255 * 0.5) # Tone down the alpha to 50%
|
| 250 |
+
if diff_map[i, j] > 0:
|
| 251 |
+
color = (0, int(255 * intensity), 0, alpha) # Green for positive
|
| 252 |
+
else:
|
| 253 |
+
color = (int(255 * (1 - intensity)), 0, 0, alpha) # Red for negative
|
| 254 |
+
|
| 255 |
+
# Draw the rectangle on the overlay
|
| 256 |
+
draw.rectangle(
|
| 257 |
+
[j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
|
| 258 |
+
fill=color
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# Composite the overlay with the original image
|
| 262 |
+
combined = Image.alpha_composite(image, overlay)
|
| 263 |
+
|
| 264 |
+
if display:
|
| 265 |
+
# Display the result
|
| 266 |
+
combined.show()
|
| 267 |
+
|
| 268 |
+
if out is not None:
|
| 269 |
+
combined.save(out)
|
| 270 |
+
|
| 271 |
+
return combined
|
| 272 |
+
|
| 273 |
+
def is_outlier(mean, std, value):
|
| 274 |
+
return value < mean - 2 * std or value > mean + 2 * std
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def get_neuron_activations(images, prs_group, model, device = "cuda:0"):
|
| 278 |
+
# Returns neuron activations in shape (num_images, num_layers, num_patches, num_neurons)
|
| 279 |
+
random_neuron_acts = []
|
| 280 |
+
for image in tqdm(images, desc="Processing images"):
|
| 281 |
+
prs_group.reinit()
|
| 282 |
+
image_input = image.unsqueeze(0).to(device)
|
| 283 |
+
representation = model.encode_image(
|
| 284 |
+
image_input, attn_method="head", normalize=False
|
| 285 |
+
)
|
| 286 |
+
prs_group.finalize()
|
| 287 |
+
gelu_outs = prs_group.post_gelu_outputs()
|
| 288 |
+
random_neuron_acts.append(gelu_outs)
|
| 289 |
+
random_neuron_acts = torch.stack(random_neuron_acts, dim = 0)
|
| 290 |
+
return random_neuron_acts
|
| 291 |
+
|
| 292 |
+
def normalize_array(arr):
|
| 293 |
+
min_val = np.min(arr)
|
| 294 |
+
max_val = np.max(arr)
|
| 295 |
+
# Avoid division by zero if all values are the same
|
| 296 |
+
if max_val - min_val == 0:
|
| 297 |
+
return np.zeros_like(arr)
|
| 298 |
+
normalized_arr = (arr - min_val) / (max_val - min_val)
|
| 299 |
+
return normalized_arr
|
| 300 |
+
|
| 301 |
+
def np_l2(arr1, arr2):
|
| 302 |
+
return np.linalg.norm(arr1 - arr2)
|
| 303 |
+
|
| 304 |
+
def best_class(classifier, representation):
|
| 305 |
+
cs = torch.cosine_similarity(classifier, representation.permute(1, 0), dim = 0)
|
| 306 |
+
return torch.argmax(cs).item(), cs[torch.argmax(cs).item()].item()
|
| 307 |
+
|
| 308 |
+
def load_group_attn_shifts(timestamp):
|
| 309 |
+
# Load from Supp1B
|
| 310 |
+
results_dir = "./results/supp1B"
|
| 311 |
+
# dirs = [os.path.join(results_dir, d) for d in os.listdir(results_dir)
|
| 312 |
+
# if os.path.isdir(os.path.join(results_dir, d))]
|
| 313 |
+
# latest_dir = max(dirs, key=os.path.getmtime)
|
| 314 |
+
|
| 315 |
+
latest_dir = os.path.join(results_dir, timestamp)
|
| 316 |
+
print(f"Using latest results directory: {latest_dir}")
|
| 317 |
+
|
| 318 |
+
# Load metadata
|
| 319 |
+
with open(os.path.join(latest_dir, "metadata.json"), "r") as f:
|
| 320 |
+
metadata = json.load(f)
|
| 321 |
+
|
| 322 |
+
# Load memory-mapped files
|
| 323 |
+
attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
|
| 324 |
+
dtype=np.float32,
|
| 325 |
+
mode='r',
|
| 326 |
+
shape=tuple(metadata["attention_maps_shape"]))
|
| 327 |
+
|
| 328 |
+
resblocks = np.memmap(os.path.join(latest_dir, "resblocks.mmap"),
|
| 329 |
+
dtype=np.float32,
|
| 330 |
+
mode='r',
|
| 331 |
+
shape=tuple(metadata["resblocks_shape"]))
|
| 332 |
+
|
| 333 |
+
# Get file list from metadata
|
| 334 |
+
file_list = metadata.get("file_list", [])
|
| 335 |
+
|
| 336 |
+
# Get top_k values
|
| 337 |
+
top_k_values = metadata.get("top_k_values", [0])
|
| 338 |
+
|
| 339 |
+
return {
|
| 340 |
+
"attn_maps": attn_maps,
|
| 341 |
+
"resblocks": resblocks,
|
| 342 |
+
"metadata": metadata,
|
| 343 |
+
"file_list": file_list,
|
| 344 |
+
"top_k_values": top_k_values,
|
| 345 |
+
"num_layers": metadata.get("num_layers", 0),
|
| 346 |
+
"num_images": metadata.get("num_images", 0),
|
| 347 |
+
"num_heads": metadata.get("num_heads", 0)
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
def load_individual_attn_shifts(timestamp, supp = "supp1D"):
|
| 351 |
+
results_dir = f"./results/{supp}"
|
| 352 |
+
# dirs = [os.path.join(results_dir, d) for d in os.listdir(results_dir)
|
| 353 |
+
# if os.path.isdir(os.path.join(results_dir, d))]
|
| 354 |
+
# latest_dir = max(dirs, key=os.path.getmtime)
|
| 355 |
+
|
| 356 |
+
latest_dir = os.path.join(results_dir, timestamp)
|
| 357 |
+
print(f"Using latest results directory: {latest_dir}")
|
| 358 |
+
|
| 359 |
+
# Load metadata
|
| 360 |
+
with open(os.path.join(latest_dir, "metadata.json"), "r") as f:
|
| 361 |
+
metadata = json.load(f)
|
| 362 |
+
|
| 363 |
+
# Load memory-mapped files
|
| 364 |
+
attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
|
| 365 |
+
dtype=np.float32,
|
| 366 |
+
mode='r',
|
| 367 |
+
shape=tuple(metadata["attention_maps_shape"]))
|
| 368 |
+
|
| 369 |
+
baseline_attn_maps = np.memmap(os.path.join(latest_dir, "baseline_attention_maps.mmap"),
|
| 370 |
+
dtype=np.float32,
|
| 371 |
+
mode='r',
|
| 372 |
+
shape=tuple(metadata["baseline_attention_maps_shape"]))
|
| 373 |
+
|
| 374 |
+
neuron_activations = np.memmap(os.path.join(latest_dir, "neuron_activations.mmap"),
|
| 375 |
+
dtype=np.float32,
|
| 376 |
+
mode='r',
|
| 377 |
+
shape=tuple(metadata["neuron_activations_shape"]))
|
| 378 |
+
|
| 379 |
+
baseline_neuron_activations = np.memmap(os.path.join(latest_dir, "baseline_neuron_activations.mmap"),
|
| 380 |
+
dtype=np.float32,
|
| 381 |
+
mode='r',
|
| 382 |
+
shape=tuple(metadata["baseline_neuron_activations_shape"]))
|
| 383 |
+
|
| 384 |
+
ablated_neurons = np.memmap(os.path.join(latest_dir, "ablated_neurons.mmap"),
|
| 385 |
+
dtype=np.float32,
|
| 386 |
+
mode='r',
|
| 387 |
+
shape=tuple(metadata["ablated_neurons_shape"]))
|
| 388 |
+
|
| 389 |
+
# Get file list from metadata
|
| 390 |
+
file_list = metadata.get("file_list", [])
|
| 391 |
+
|
| 392 |
+
# Get k value
|
| 393 |
+
k = metadata.get("k", 25)
|
| 394 |
+
|
| 395 |
+
return {
|
| 396 |
+
"attn_maps": attn_maps,
|
| 397 |
+
"baseline_attn_maps": baseline_attn_maps,
|
| 398 |
+
"neuron_activations": neuron_activations,
|
| 399 |
+
"baseline_neuron_activations": baseline_neuron_activations,
|
| 400 |
+
"ablated_neurons": ablated_neurons,
|
| 401 |
+
"metadata": metadata,
|
| 402 |
+
"file_list": file_list,
|
| 403 |
+
"k": k,
|
| 404 |
+
"num_layers": metadata.get("num_layers", 12),
|
| 405 |
+
"num_images": metadata.get("num_images", 100),
|
| 406 |
+
"model_name": metadata.get("model_name", "ViT-B-16"),
|
| 407 |
+
"pretrained": metadata.get("pretrained", "openai")
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
def find_register_neurons_cuda(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500):
|
| 411 |
+
num_layers = len(model.visual.transformer.resblocks)
|
| 412 |
+
highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
|
| 413 |
+
num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]
|
| 414 |
+
random_images = load_images(preprocess, count=processed_image_cnt)
|
| 415 |
+
neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
|
| 416 |
+
alignment_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
|
| 417 |
+
image_count = 0
|
| 418 |
+
|
| 419 |
+
for i in tqdm(range(len(random_images)), desc="Processing random images"):
|
| 420 |
+
image = random_images[i].unsqueeze(0).to(device)
|
| 421 |
+
prs_group.reinit()
|
| 422 |
+
|
| 423 |
+
with torch.inference_mode():
|
| 424 |
+
representation = model.encode_image(
|
| 425 |
+
image, attn_method="head", normalize=False
|
| 426 |
+
)
|
| 427 |
+
prs_group.finalize()
|
| 428 |
+
|
| 429 |
+
baseline_neuron_acts = prs_group.post_gelu_outputs().to(device)
|
| 430 |
+
baseline_resblock_outputs = prs_group.resblock_outputs().to(device)
|
| 431 |
+
|
| 432 |
+
# Calculate norm map using torch
|
| 433 |
+
norm_map = torch.norm(baseline_resblock_outputs[-1], dim=1)
|
| 434 |
+
filtered_norms = norm_map.clone()
|
| 435 |
+
filtered_norms[filtered_norms < register_norm_threshold] = 0
|
| 436 |
+
|
| 437 |
+
# Get register locations as a tensor
|
| 438 |
+
register_locations = torch.where(filtered_norms > register_norm_threshold)[0]
|
| 439 |
+
|
| 440 |
+
if len(register_locations) == 0:
|
| 441 |
+
continue
|
| 442 |
+
|
| 443 |
+
image_count += 1
|
| 444 |
+
|
| 445 |
+
# Process all layers vectorized
|
| 446 |
+
for layer in range(num_layers):
|
| 447 |
+
# Get absolute activations for all neurons in this layer
|
| 448 |
+
act_layer = torch.abs(baseline_neuron_acts[layer]) # Shape: [seq_len, num_neurons]
|
| 449 |
+
|
| 450 |
+
# Check sparsity condition for all neurons at once
|
| 451 |
+
sparse_neurons = torch.sum(act_layer < 0.5, dim=0) >= 0.5 * act_layer.shape[0] # Shape: [num_neurons]
|
| 452 |
+
|
| 453 |
+
# Skip computation if no neurons meet the condition
|
| 454 |
+
if not torch.any(sparse_neurons):
|
| 455 |
+
continue
|
| 456 |
+
|
| 457 |
+
# Get values at register locations for all neurons simultaneously
|
| 458 |
+
# This creates a tensor of shape [num_register_locations, num_neurons]
|
| 459 |
+
register_values = act_layer[register_locations]
|
| 460 |
+
|
| 461 |
+
# For neurons that pass sparsity condition, compute mean at register locations
|
| 462 |
+
# First, compute mean for all neurons (this is fast)
|
| 463 |
+
neuron_means = register_values.mean(dim=0) # Shape: [num_neurons]
|
| 464 |
+
|
| 465 |
+
# Then zero out means for neurons that don't pass sparsity condition
|
| 466 |
+
neuron_means = neuron_means * sparse_neurons.float()
|
| 467 |
+
|
| 468 |
+
# Store in score tensor
|
| 469 |
+
neuron_scores[i, layer] = neuron_means
|
| 470 |
+
|
| 471 |
+
# Aggregate per-image scores and rank neurons
|
| 472 |
+
mean_neuron_scores = neuron_scores[:image_count].mean(dim=0)
|
| 473 |
+
mean_alignment_scores = alignment_scores[:image_count].mean(dim=0)
|
| 474 |
+
|
| 475 |
+
# Flatten and find top values
|
| 476 |
+
flattened_scores = mean_neuron_scores.flatten()
|
| 477 |
+
sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True)
|
| 478 |
+
|
| 479 |
+
flattened_alignment = mean_alignment_scores.flatten()
|
| 480 |
+
sorted_alignment_values, sorted_alignment_indices = torch.sort(flattened_alignment, descending=True)
|
| 481 |
+
|
| 482 |
+
# Convert indices to layer/neuron pairs
|
| 483 |
+
top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices]
|
| 484 |
+
top_alignment_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_alignment_indices]
|
| 485 |
+
|
| 486 |
+
register_norms = [
|
| 487 |
+
(layer, neuron, sorted_values[i].item())
|
| 488 |
+
for i, (layer, neuron) in enumerate(top_indices)
|
| 489 |
+
if layer <= highest_layer
|
| 490 |
+
]
|
| 491 |
+
|
| 492 |
+
best_alignment_scores = [
|
| 493 |
+
(layer, neuron, sorted_alignment_values[i].item())
|
| 494 |
+
for i, (layer, neuron) in enumerate(top_alignment_indices)
|
| 495 |
+
if layer <= highest_layer
|
| 496 |
+
]
|
| 497 |
+
|
| 498 |
+
return register_norms, best_alignment_scores
|
| 499 |
+
|
| 500 |
+
def find_register_neurons(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500):
|
| 501 |
+
num_layers = len(model.visual.transformer.resblocks)
|
| 502 |
+
highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
|
| 503 |
+
num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]
|
| 504 |
+
|
| 505 |
+
random_images = load_images(preprocess, count = processed_image_cnt)
|
| 506 |
+
neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons))
|
| 507 |
+
for i in tqdm(range(len(random_images)), desc="Processing random images"):
|
| 508 |
+
image = random_images[i].unsqueeze(0).to(device)
|
| 509 |
+
|
| 510 |
+
prs_group.reinit()
|
| 511 |
+
with torch.no_grad():
|
| 512 |
+
representation = model.encode_image(
|
| 513 |
+
image, attn_method="head", normalize=False
|
| 514 |
+
)
|
| 515 |
+
prs_group.finalize()
|
| 516 |
+
|
| 517 |
+
# Gather neuron activations and resblock outputs
|
| 518 |
+
baseline_neuron_acts = prs_group.post_gelu_outputs().cpu().numpy()
|
| 519 |
+
baseline_resblock_outputs = prs_group.resblock_outputs().cpu().numpy()
|
| 520 |
+
|
| 521 |
+
# Calculate norms of the last resblock outputs; only consider the patches of the activation maps that correspond to registers
|
| 522 |
+
norms = np.linalg.norm(baseline_resblock_outputs[-1], axis=1)
|
| 523 |
+
norms[norms < register_norm_threshold] = 0
|
| 524 |
+
register_locations = np.where(norms > register_norm_threshold)[0]
|
| 525 |
+
|
| 526 |
+
# register_neurons = []
|
| 527 |
+
for layer in range(num_layers):
|
| 528 |
+
for neuron in range(num_neurons):
|
| 529 |
+
neuron_map = baseline_neuron_acts[layer, :, neuron]
|
| 530 |
+
mask = np.zeros_like(neuron_map, dtype=bool)
|
| 531 |
+
mask[register_locations] = True
|
| 532 |
+
neuron_map[~mask] = 0
|
| 533 |
+
if np.any(neuron_map < 0):
|
| 534 |
+
continue
|
| 535 |
+
# dist = np.linalg.norm(normalize_array(norms) - normalize_array(neuron_map))
|
| 536 |
+
# register_neurons.append((layer, neuron, dist.item(), neuron_map[register_locations].mean()))
|
| 537 |
+
|
| 538 |
+
neuron_scores[i, layer, neuron] = torch.tensor(neuron_map[register_locations].mean())
|
| 539 |
+
mean_neuron_scores = neuron_scores.mean(dim=0)
|
| 540 |
+
# Flatten the 2D tensor to find global top values
|
| 541 |
+
flattened_scores = mean_neuron_scores.flatten()
|
| 542 |
+
sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True)
|
| 543 |
+
|
| 544 |
+
# Convert flat indices back to 2D coordinates (layer, neuron)
|
| 545 |
+
top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices]
|
| 546 |
+
|
| 547 |
+
return [(layer, neuron, sorted_values[i].item()) for i, (layer, neuron) in enumerate(top_indices) if layer <= highest_layer]
|
| 548 |
+
|
| 549 |
+
|
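A hedged sketch of how the two variants above might be invoked; `prs_group` is assumed to be the hook container used elsewhere in this repository, and its construction is not shown here:

```python
# Assumes `model`, `preprocess`, and a `prs_group` hook container already exist.
register_neurons = find_register_neurons(
    model, preprocess, prs_group,
    register_norm_threshold=30, highest_layer=-1,
    device="cuda:0", processed_image_cnt=100,
)
print(register_neurons[:5])  # highest-scoring (layer, neuron, score) triples first
```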
| 550 |
+
def plot_attn_maps(attn_maps, image_idx):
|
| 551 |
+
|
| 552 |
+
num_layers, num_heads, patch_height, patch_width = attn_maps.shape
|
| 553 |
+
print(f"Shape of image_shifts: {attn_maps.shape}")
|
| 554 |
+
|
| 555 |
+
# Create a grid of plots for all layers and heads
|
| 556 |
+
fig, axes = plt.subplots(num_layers, num_heads, figsize=(2*num_heads, 2*num_layers))
|
| 557 |
+
fig.suptitle(f'Attention Shift Maps for Image #{image_idx}', fontsize=16)
|
| 558 |
+
|
| 559 |
+
# Import the correct module for make_axes_locatable
|
| 560 |
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
| 561 |
+
|
| 562 |
+
# Plot each layer-head combination
|
| 563 |
+
for layer in range(num_layers):
|
| 564 |
+
# Determine min and max for this layer for consistent colorbar scaling within the layer
|
| 565 |
+
layer_vmin = attn_maps[layer].min().item()
|
| 566 |
+
layer_vmax = attn_maps[layer].max().item()
|
| 567 |
+
|
| 568 |
+
for head in range(num_heads):
|
| 569 |
+
# Get the current axis (handle both 2D and 1D cases)
|
| 570 |
+
if num_layers == 1 and num_heads == 1:
|
| 571 |
+
ax = axes
|
| 572 |
+
elif num_layers == 1:
|
| 573 |
+
ax = axes[head]
|
| 574 |
+
elif num_heads == 1:
|
| 575 |
+
ax = axes[layer]
|
| 576 |
+
else:
|
| 577 |
+
ax = axes[layer, head]
|
| 578 |
+
|
| 579 |
+
# Plot the attention shift map with layer-specific normalization
|
| 580 |
+
im = ax.imshow(attn_maps[layer, head], cmap='viridis', vmin=layer_vmin, vmax=layer_vmax)
|
| 581 |
+
|
| 582 |
+
# Remove ticks for cleaner appearance
|
| 583 |
+
ax.set_xticks([])
|
| 584 |
+
ax.set_yticks([])
|
| 585 |
+
|
| 586 |
+
# Add layer and head labels only on the edges
|
| 587 |
+
if head == 0:
|
| 588 |
+
ax.set_ylabel(f'Layer {layer}')
|
| 589 |
+
if layer == num_layers-1:
|
| 590 |
+
ax.set_xlabel(f'Head {head}')
|
| 591 |
+
|
| 592 |
+
# Add a colorbar for each layer (only once per row)
|
| 593 |
+
if head == num_heads-1:
|
| 594 |
+
# Create a colorbar that's properly sized relative to the plot
|
| 595 |
+
divider = make_axes_locatable(ax)
|
| 596 |
+
cax = divider.append_axes("right", size="5%", pad=0.05)
|
| 597 |
+
plt.colorbar(im, cax=cax)
|
| 598 |
+
|
| 599 |
+
# Adjust layout to make room for the colorbars
|
| 600 |
+
plt.tight_layout()
|
| 601 |
+
return plt
|
| 602 |
+
|
| 603 |
+
def calculate_iou(output, target):
|
| 604 |
+
intersection = output * (output == target)
|
| 605 |
+
area_inter = intersection.sum().item()
|
| 606 |
+
area_pred = output.sum().item()
|
| 607 |
+
area_target = target.sum().item()
|
| 608 |
+
union = area_pred + area_target - area_inter
|
| 609 |
+
iou = area_inter / union
|
| 610 |
+
return area_inter, union, iou
|
| 611 |
+
|
| 612 |
+
def calculate_pixel_accuracy(output, target):
|
| 613 |
+
correct = output * (output == target)
|
| 614 |
+
correct = correct.sum().item()
|
| 615 |
+
total = target.sum().item()
|
| 616 |
+
return correct, total, correct / total
|
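A quick worked example for the two metrics above, with 0/1 masks (1 = foreground):

```python
pred   = torch.tensor([[1, 1, 0], [0, 1, 0]])
target = torch.tensor([[1, 0, 0], [0, 1, 1]])
print(calculate_iou(pred, target))             # (2, 4, 0.5)
print(calculate_pixel_accuracy(pred, target))  # (2, 3, 0.666...)
```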
special_tokens_map.json
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
{"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
|
timm_model.py
ADDED
|
@@ -0,0 +1,149 @@
|
| 1 |
+
""" timm model adapter
|
| 2 |
+
|
| 3 |
+
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import timm
|
| 13 |
+
from timm.models.layers import Mlp, to_2tuple
|
| 14 |
+
try:
|
| 15 |
+
# old timm imports < 0.8.1
|
| 16 |
+
from timm.models.layers.attention_pool2d import RotAttentionPool2d
|
| 17 |
+
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
|
| 18 |
+
except ImportError:
|
| 19 |
+
# new timm imports >= 0.8.1
|
| 20 |
+
from timm.layers import RotAttentionPool2d
|
| 21 |
+
from timm.layers import AttentionPool2d as AbsAttentionPool2d
|
| 22 |
+
except ImportError:
|
| 23 |
+
timm = None
|
| 24 |
+
|
| 25 |
+
from misc import freeze_batch_norm_2d
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TimmModel(nn.Module):
|
| 29 |
+
""" timm model adapter
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
model_name,
|
| 35 |
+
embed_dim,
|
| 36 |
+
image_size=224,
|
| 37 |
+
pool='avg',
|
| 38 |
+
proj='linear',
|
| 39 |
+
proj_bias=False,
|
| 40 |
+
drop=0.,
|
| 41 |
+
drop_path=None,
|
| 42 |
+
patch_drop=None,
|
| 43 |
+
pretrained=False,
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
if timm is None:
|
| 47 |
+
raise RuntimeError("Please `pip install timm` to use timm models.")
|
| 48 |
+
self.image_size = to_2tuple(image_size)
|
| 49 |
+
|
| 50 |
+
# setup kwargs that may not be common across all models
|
| 51 |
+
timm_kwargs = {}
|
| 52 |
+
if drop_path is not None:
|
| 53 |
+
timm_kwargs['drop_path_rate'] = drop_path
|
| 54 |
+
if patch_drop is not None:
|
| 55 |
+
timm_kwargs['patch_drop_rate'] = patch_drop
|
| 56 |
+
|
| 57 |
+
custom_pool = pool in ('abs_attn', 'rot_attn')
|
| 58 |
+
if not proj and not custom_pool:
|
| 59 |
+
# use network classifier head as projection if no proj specified and no custom pooling used
|
| 60 |
+
self.trunk = timm.create_model(
|
| 61 |
+
model_name,
|
| 62 |
+
num_classes=embed_dim,
|
| 63 |
+
global_pool=pool,
|
| 64 |
+
pretrained=pretrained,
|
| 65 |
+
**timm_kwargs,
|
| 66 |
+
)
|
| 67 |
+
prev_chs = embed_dim
|
| 68 |
+
else:
|
| 69 |
+
self.trunk = timm.create_model(
|
| 70 |
+
model_name,
|
| 71 |
+
pretrained=pretrained,
|
| 72 |
+
**timm_kwargs,
|
| 73 |
+
)
|
| 74 |
+
feat_size = self.trunk.default_cfg.get('pool_size', None)
|
| 75 |
+
feature_ndim = 1 if not feat_size else 2
|
| 76 |
+
if custom_pool:
|
| 77 |
+
assert feature_ndim == 2
|
| 78 |
+
# if attn pooling used, remove both classifier and default pool
|
| 79 |
+
self.trunk.reset_classifier(0, global_pool='')
|
| 80 |
+
else:
|
| 81 |
+
# reset global pool if pool config set, otherwise leave as network default
|
| 82 |
+
reset_kwargs = dict(global_pool=pool) if pool else {}
|
| 83 |
+
self.trunk.reset_classifier(0, **reset_kwargs)
|
| 84 |
+
prev_chs = self.trunk.num_features
|
| 85 |
+
|
| 86 |
+
head_layers = OrderedDict()
|
| 87 |
+
|
| 88 |
+
# Add custom pooling to head
|
| 89 |
+
if pool == 'abs_attn':
|
| 90 |
+
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
|
| 91 |
+
prev_chs = embed_dim
|
| 92 |
+
elif pool == 'rot_attn':
|
| 93 |
+
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
|
| 94 |
+
prev_chs = embed_dim
|
| 95 |
+
|
| 96 |
+
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
|
| 97 |
+
if proj == 'linear':
|
| 98 |
+
head_layers['drop'] = nn.Dropout(drop)
|
| 99 |
+
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
|
| 100 |
+
elif proj == 'mlp':
|
| 101 |
+
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
|
| 102 |
+
else:
|
| 103 |
+
assert not proj, f'Unknown projection type {proj}.'
|
| 104 |
+
|
| 105 |
+
self.head = nn.Sequential(head_layers)
|
| 106 |
+
|
| 107 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
| 108 |
+
""" lock modules
|
| 109 |
+
Args:
|
| 110 |
+
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
| 111 |
+
"""
|
| 112 |
+
if not unlocked_groups:
|
| 113 |
+
# lock full model
|
| 114 |
+
for param in self.trunk.parameters():
|
| 115 |
+
param.requires_grad = False
|
| 116 |
+
if freeze_bn_stats:
|
| 117 |
+
freeze_batch_norm_2d(self.trunk)
|
| 118 |
+
else:
|
| 119 |
+
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
| 120 |
+
try:
|
| 121 |
+
# FIXME import here until API stable and in an official release
|
| 122 |
+
from timm.models.helpers import group_parameters, group_modules
|
| 123 |
+
except ImportError:
|
| 124 |
+
raise RuntimeError(
|
| 125 |
+
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
|
| 126 |
+
matcher = self.trunk.group_matcher()
|
| 127 |
+
gparams = group_parameters(self.trunk, matcher)
|
| 128 |
+
max_layer_id = max(gparams.keys())
|
| 129 |
+
max_layer_id = max_layer_id - unlocked_groups
|
| 130 |
+
for group_idx in range(max_layer_id + 1):
|
| 131 |
+
group = gparams[group_idx]
|
| 132 |
+
for param in group:
|
| 133 |
+
self.trunk.get_parameter(param).requires_grad = False
|
| 134 |
+
if freeze_bn_stats:
|
| 135 |
+
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
| 136 |
+
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
| 137 |
+
freeze_batch_norm_2d(self.trunk, gmodules)
|
| 138 |
+
|
| 139 |
+
@torch.jit.ignore
|
| 140 |
+
def set_grad_checkpointing(self, enable=True):
|
| 141 |
+
try:
|
| 142 |
+
self.trunk.set_grad_checkpointing(enable)
|
| 143 |
+
except Exception as e:
|
| 144 |
+
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
|
| 145 |
+
|
| 146 |
+
def forward(self, x):
|
| 147 |
+
x = self.trunk(x)
|
| 148 |
+
x = self.head(x)
|
| 149 |
+
return x
|
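An illustrative construction of the adapter above (requires `pip install timm`; the model name and embedding dim are arbitrary examples, not the configuration shipped in this repo):

```python
import torch

# Wrap a timm ViT trunk and project its pooled features to a 512-d CLIP embedding.
vision_tower = TimmModel(
    model_name="vit_base_patch16_224",
    embed_dim=512,
    image_size=224,
    proj="linear",
    pretrained=False,
)
dummy = torch.randn(1, 3, 224, 224)
print(vision_tower(dummy).shape)  # torch.Size([1, 512])
```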
tokenizer.py
ADDED
|
@@ -0,0 +1,214 @@
|
| 1 |
+
""" CLIP tokenizer
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
| 4 |
+
"""
|
| 5 |
+
import gzip
|
| 6 |
+
import html
|
| 7 |
+
import os
|
| 8 |
+
from functools import lru_cache
|
| 9 |
+
from typing import Union, List
|
| 10 |
+
|
| 11 |
+
import ftfy
|
| 12 |
+
import regex as re
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# https://stackoverflow.com/q/62691279
|
| 16 |
+
import os
|
| 17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@lru_cache()
|
| 21 |
+
def default_bpe():
|
| 22 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab/bpe_simple_vocab_16e6.txt.gz")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@lru_cache()
|
| 26 |
+
def bytes_to_unicode():
|
| 27 |
+
"""
|
| 28 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
| 29 |
+
The reversible bpe codes work on unicode strings.
|
| 30 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
| 31 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
| 32 |
+
This is a significant percentage of your normal, say, 32K bpe vocab.
|
| 33 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
| 34 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
| 35 |
+
"""
|
| 36 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
| 37 |
+
cs = bs[:]
|
| 38 |
+
n = 0
|
| 39 |
+
for b in range(2**8):
|
| 40 |
+
if b not in bs:
|
| 41 |
+
bs.append(b)
|
| 42 |
+
cs.append(2**8+n)
|
| 43 |
+
n += 1
|
| 44 |
+
cs = [chr(n) for n in cs]
|
| 45 |
+
return dict(zip(bs, cs))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_pairs(word):
|
| 49 |
+
"""Return set of symbol pairs in a word.
|
| 50 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
| 51 |
+
"""
|
| 52 |
+
pairs = set()
|
| 53 |
+
prev_char = word[0]
|
| 54 |
+
for char in word[1:]:
|
| 55 |
+
pairs.add((prev_char, char))
|
| 56 |
+
prev_char = char
|
| 57 |
+
return pairs
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def basic_clean(text):
|
| 61 |
+
text = ftfy.fix_text(text)
|
| 62 |
+
text = html.unescape(html.unescape(text))
|
| 63 |
+
return text.strip()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def whitespace_clean(text):
|
| 67 |
+
text = re.sub(r'\s+', ' ', text)
|
| 68 |
+
text = text.strip()
|
| 69 |
+
return text
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class SimpleTokenizer(object):
|
| 73 |
+
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
| 74 |
+
self.byte_encoder = bytes_to_unicode()
|
| 75 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 76 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
| 77 |
+
merges = merges[1:49152-256-2+1]
|
| 78 |
+
merges = [tuple(merge.split()) for merge in merges]
|
| 79 |
+
vocab = list(bytes_to_unicode().values())
|
| 80 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
| 81 |
+
for merge in merges:
|
| 82 |
+
vocab.append(''.join(merge))
|
| 83 |
+
if not special_tokens:
|
| 84 |
+
special_tokens = ['<start_of_text>', '<end_of_text>']
|
| 85 |
+
else:
|
| 86 |
+
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
|
| 87 |
+
vocab.extend(special_tokens)
|
| 88 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
| 89 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 90 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
| 91 |
+
self.cache = {t:t for t in special_tokens}
|
| 92 |
+
special = "|".join(special_tokens)
|
| 93 |
+
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
| 94 |
+
|
| 95 |
+
self.vocab_size = len(self.encoder)
|
| 96 |
+
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
| 97 |
+
|
| 98 |
+
def bpe(self, token):
|
| 99 |
+
if token in self.cache:
|
| 100 |
+
return self.cache[token]
|
| 101 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
| 102 |
+
pairs = get_pairs(word)
|
| 103 |
+
|
| 104 |
+
if not pairs:
|
| 105 |
+
return token+'</w>'
|
| 106 |
+
|
| 107 |
+
while True:
|
| 108 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
| 109 |
+
if bigram not in self.bpe_ranks:
|
| 110 |
+
break
|
| 111 |
+
first, second = bigram
|
| 112 |
+
new_word = []
|
| 113 |
+
i = 0
|
| 114 |
+
while i < len(word):
|
| 115 |
+
try:
|
| 116 |
+
j = word.index(first, i)
|
| 117 |
+
new_word.extend(word[i:j])
|
| 118 |
+
i = j
|
| 119 |
+
except:
|
| 120 |
+
new_word.extend(word[i:])
|
| 121 |
+
break
|
| 122 |
+
|
| 123 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
| 124 |
+
new_word.append(first+second)
|
| 125 |
+
i += 2
|
| 126 |
+
else:
|
| 127 |
+
new_word.append(word[i])
|
| 128 |
+
i += 1
|
| 129 |
+
new_word = tuple(new_word)
|
| 130 |
+
word = new_word
|
| 131 |
+
if len(word) == 1:
|
| 132 |
+
break
|
| 133 |
+
else:
|
| 134 |
+
pairs = get_pairs(word)
|
| 135 |
+
word = ' '.join(word)
|
| 136 |
+
self.cache[token] = word
|
| 137 |
+
return word
|
| 138 |
+
|
| 139 |
+
def encode(self, text):
|
| 140 |
+
bpe_tokens = []
|
| 141 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
| 142 |
+
for token in re.findall(self.pat, text):
|
| 143 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
| 144 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
| 145 |
+
return bpe_tokens
|
| 146 |
+
|
| 147 |
+
def decode(self, tokens):
|
| 148 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
| 149 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
| 150 |
+
return text
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
_tokenizer = SimpleTokenizer()
|
| 154 |
+
|
| 155 |
+
def decode(output_ids: torch.Tensor):
|
| 156 |
+
output_ids = output_ids.cpu().numpy()
|
| 157 |
+
return _tokenizer.decode(output_ids)
|
| 158 |
+
|
| 159 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
| 160 |
+
"""
|
| 161 |
+
Returns the tokenized representation of given input string(s)
|
| 162 |
+
|
| 163 |
+
Parameters
|
| 164 |
+
----------
|
| 165 |
+
texts : Union[str, List[str]]
|
| 166 |
+
An input string or a list of input strings to tokenize
|
| 167 |
+
context_length : int
|
| 168 |
+
The context length to use; all CLIP models use 77 as the context length
|
| 169 |
+
|
| 170 |
+
Returns
|
| 171 |
+
-------
|
| 172 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
| 173 |
+
"""
|
| 174 |
+
if isinstance(texts, str):
|
| 175 |
+
texts = [texts]
|
| 176 |
+
|
| 177 |
+
sot_token = _tokenizer.encoder["<start_of_text>"]
|
| 178 |
+
eot_token = _tokenizer.encoder["<end_of_text>"]
|
| 179 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
| 180 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
| 181 |
+
|
| 182 |
+
for i, tokens in enumerate(all_tokens):
|
| 183 |
+
if len(tokens) > context_length:
|
| 184 |
+
tokens = tokens[:context_length] # Truncate
|
| 185 |
+
tokens[-1] = eot_token
|
| 186 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
| 187 |
+
|
| 188 |
+
return result
|
| 189 |
+
|
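For reference, a minimal call to `tokenize()` as defined above:

```python
tokens = tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)   # torch.Size([2, 77])
print(tokens[0, :4])  # start-of-text id followed by the first BPE ids
```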
| 190 |
+
|
| 191 |
+
class HFTokenizer:
|
| 192 |
+
"""HuggingFace tokenizer wrapper"""
|
| 193 |
+
|
| 194 |
+
def __init__(self, tokenizer_name: str):
|
| 195 |
+
from transformers import AutoTokenizer
|
| 196 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
| 197 |
+
|
| 198 |
+
def save_pretrained(self, dest):
|
| 199 |
+
self.tokenizer.save_pretrained(dest)
|
| 200 |
+
|
| 201 |
+
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
|
| 202 |
+
# same cleaning as for default tokenizer, except lowercasing
|
| 203 |
+
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
|
| 204 |
+
if isinstance(texts, str):
|
| 205 |
+
texts = [texts]
|
| 206 |
+
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
| 207 |
+
input_ids = self.tokenizer(
|
| 208 |
+
texts,
|
| 209 |
+
return_tensors='pt',
|
| 210 |
+
max_length=context_length,
|
| 211 |
+
padding='max_length',
|
| 212 |
+
truncation=True,
|
| 213 |
+
).input_ids
|
| 214 |
+
return input_ids
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
| 1 |
+
{
|
| 2 |
+
"unk_token": {
|
| 3 |
+
"content": "<|endoftext|>",
|
| 4 |
+
"single_word": false,
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"__type": "AddedToken"
|
| 9 |
+
},
|
| 10 |
+
"bos_token": {
|
| 11 |
+
"content": "<|startoftext|>",
|
| 12 |
+
"single_word": false,
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"rstrip": false,
|
| 15 |
+
"normalized": true,
|
| 16 |
+
"__type": "AddedToken"
|
| 17 |
+
},
|
| 18 |
+
"eos_token": {
|
| 19 |
+
"content": "<|endoftext|>",
|
| 20 |
+
"single_word": false,
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"rstrip": false,
|
| 23 |
+
"normalized": true,
|
| 24 |
+
"__type": "AddedToken"
|
| 25 |
+
},
|
| 26 |
+
"pad_token": "<|endoftext|>",
|
| 27 |
+
"add_prefix_space": false,
|
| 28 |
+
"errors": "replace",
|
| 29 |
+
"do_lower_case": true,
|
| 30 |
+
"name_or_path": "openai/clip-vit-base-patch32",
|
| 31 |
+
"model_max_length": 77,
|
| 32 |
+
"special_tokens_map_file": "/home/suraj/.cache/huggingface/transformers/18a566598f286c9139f88160c99f84eec492a26bd22738fa9cb44d5b7e0a5c76.cce1206abbad28826f000510f22f354e53e66a97f7c23745a7dfe27609cc07f5",
|
| 33 |
+
"tokenizer_class": "CLIPTokenizer"
|
| 34 |
+
}
|
tokenizer_config_bak.json
ADDED
|
@@ -0,0 +1,10 @@
|
| 1 |
+
{
|
| 2 |
+
"tokenizer_class": "SimpleTokenizer",
|
| 3 |
+
"vocab_size": 49408,
|
| 4 |
+
"context_length": 77,
|
| 5 |
+
"bpe_path": "vocab/bpe_simple_vocab_16e6.txt.gz",
|
| 6 |
+
"special_tokens": [
|
| 7 |
+
"<start_of_text>",
|
| 8 |
+
"<end_of_text>"
|
| 9 |
+
]
|
| 10 |
+
}
|
transform.py
ADDED
|
@@ -0,0 +1,133 @@
|
| 1 |
+
import warnings
|
| 2 |
+
from dataclasses import dataclass, asdict
|
| 3 |
+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torchvision.transforms.functional as F
|
| 8 |
+
|
| 9 |
+
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
| 10 |
+
CenterCrop
|
| 11 |
+
|
| 12 |
+
from constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class AugmentationCfg:
|
| 17 |
+
scale: Tuple[float, float] = (0.9, 1.0)
|
| 18 |
+
ratio: Optional[Tuple[float, float]] = None
|
| 19 |
+
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
|
| 20 |
+
interpolation: Optional[str] = None
|
| 21 |
+
re_prob: Optional[float] = None
|
| 22 |
+
re_count: Optional[int] = None
|
| 23 |
+
use_timm: bool = False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ResizeMaxSize(nn.Module):
|
| 27 |
+
|
| 28 |
+
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
| 29 |
+
super().__init__()
|
| 30 |
+
if not isinstance(max_size, int):
|
| 31 |
+
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
| 32 |
+
self.max_size = max_size
|
| 33 |
+
self.interpolation = interpolation
|
| 34 |
+
self.fn = min if fn == 'min' else max  # fn='max' (the default) should use max
|
| 35 |
+
self.fill = fill
|
| 36 |
+
|
| 37 |
+
def forward(self, img):
|
| 38 |
+
if isinstance(img, torch.Tensor):
|
| 39 |
+
height, width = img.shape[:2]
|
| 40 |
+
else:
|
| 41 |
+
width, height = img.size
|
| 42 |
+
scale = self.max_size / float(max(height, width))
|
| 43 |
+
if scale != 1.0:
|
| 44 |
+
new_size = tuple(round(dim * scale) for dim in (height, width))
|
| 45 |
+
img = F.resize(img, new_size, self.interpolation)
|
| 46 |
+
pad_h = self.max_size - new_size[0]
|
| 47 |
+
pad_w = self.max_size - new_size[1]
|
| 48 |
+
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
| 49 |
+
return img
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _convert_to_rgb(image):
|
| 53 |
+
return image.convert('RGB')
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def image_transform(
|
| 57 |
+
image_size: int,
|
| 58 |
+
is_train: bool,
|
| 59 |
+
mean: Optional[Tuple[float, ...]] = None,
|
| 60 |
+
std: Optional[Tuple[float, ...]] = None,
|
| 61 |
+
resize_longest_max: bool = False,
|
| 62 |
+
fill_color: int = 0,
|
| 63 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
| 64 |
+
):
|
| 65 |
+
mean = mean or OPENAI_DATASET_MEAN
|
| 66 |
+
if not isinstance(mean, (list, tuple)):
|
| 67 |
+
mean = (mean,) * 3
|
| 68 |
+
|
| 69 |
+
std = std or OPENAI_DATASET_STD
|
| 70 |
+
if not isinstance(std, (list, tuple)):
|
| 71 |
+
std = (std,) * 3
|
| 72 |
+
|
| 73 |
+
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
| 74 |
+
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
| 75 |
+
image_size = image_size[0]
|
| 76 |
+
|
| 77 |
+
if isinstance(aug_cfg, dict):
|
| 78 |
+
aug_cfg = AugmentationCfg(**aug_cfg)
|
| 79 |
+
else:
|
| 80 |
+
aug_cfg = aug_cfg or AugmentationCfg()
|
| 81 |
+
normalize = Normalize(mean=mean, std=std)
|
| 82 |
+
if is_train:
|
| 83 |
+
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
| 84 |
+
use_timm = aug_cfg_dict.pop('use_timm', False)
|
| 85 |
+
if use_timm:
|
| 86 |
+
from timm.data import create_transform # timm can still be optional
|
| 87 |
+
if isinstance(image_size, (tuple, list)):
|
| 88 |
+
assert len(image_size) >= 2
|
| 89 |
+
input_size = (3,) + image_size[-2:]
|
| 90 |
+
else:
|
| 91 |
+
input_size = (3, image_size, image_size)
|
| 92 |
+
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
|
| 93 |
+
aug_cfg_dict.setdefault('interpolation', 'random')
|
| 94 |
+
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
|
| 95 |
+
train_transform = create_transform(
|
| 96 |
+
input_size=input_size,
|
| 97 |
+
is_training=True,
|
| 98 |
+
hflip=0.,
|
| 99 |
+
mean=mean,
|
| 100 |
+
std=std,
|
| 101 |
+
re_mode='pixel',
|
| 102 |
+
**aug_cfg_dict,
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
train_transform = Compose([
|
| 106 |
+
RandomResizedCrop(
|
| 107 |
+
image_size,
|
| 108 |
+
scale=aug_cfg_dict.pop('scale'),
|
| 109 |
+
interpolation=InterpolationMode.BICUBIC,
|
| 110 |
+
),
|
| 111 |
+
_convert_to_rgb,
|
| 112 |
+
ToTensor(),
|
| 113 |
+
normalize,
|
| 114 |
+
])
|
| 115 |
+
if aug_cfg_dict:
|
| 116 |
+
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
|
| 117 |
+
return train_transform
|
| 118 |
+
else:
|
| 119 |
+
if resize_longest_max:
|
| 120 |
+
transforms = [
|
| 121 |
+
ResizeMaxSize(image_size, fill=fill_color)
|
| 122 |
+
]
|
| 123 |
+
else:
|
| 124 |
+
transforms = [
|
| 125 |
+
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
| 126 |
+
CenterCrop(image_size),
|
| 127 |
+
]
|
| 128 |
+
transforms.extend([
|
| 129 |
+
_convert_to_rgb,
|
| 130 |
+
ToTensor(),
|
| 131 |
+
normalize,
|
| 132 |
+
])
|
| 133 |
+
return Compose(transforms)
|
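A short usage sketch for `image_transform()`; 224 is an illustrative resolution, not necessarily the one used by the checkpoint in this repo:

```python
from PIL import Image

# Evaluation-time pipeline: Resize -> CenterCrop -> RGB -> ToTensor -> Normalize.
preprocess_val = image_transform(image_size=224, is_train=False)
tensor = preprocess_val(Image.new("RGB", (640, 480)))
print(tensor.shape)  # torch.Size([3, 224, 224])
```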
transformer.py
ADDED
|
@@ -0,0 +1,872 @@
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import math
|
| 3 |
+
from typing import Callable, Optional, Sequence, Tuple, Text
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from torch.utils.checkpoint import checkpoint
|
| 9 |
+
import numbers
|
| 10 |
+
import einops
|
| 11 |
+
import numpy as np
|
| 12 |
+
from misc import to_2tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerNorm(nn.Module):
|
| 16 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
normalized_shape,
|
| 21 |
+
eps: float = 1e-5,
|
| 22 |
+
elementwise_affine: bool = True,
|
| 23 |
+
device=None,
|
| 24 |
+
dtype=None,
|
| 25 |
+
):
|
| 26 |
+
super().__init__()
|
| 27 |
+
if isinstance(normalized_shape, numbers.Integral):
|
| 28 |
+
normalized_shape = (normalized_shape,)
|
| 29 |
+
self.normalized_shape = tuple(normalized_shape)
|
| 30 |
+
self.eps = eps
|
| 31 |
+
self.elementwise_affine = elementwise_affine
|
| 32 |
+
if self.elementwise_affine:
|
| 33 |
+
self.weight = torch.nn.Parameter(
|
| 34 |
+
torch.empty(self.normalized_shape)
|
| 35 |
+
)
|
| 36 |
+
self.bias = torch.nn.Parameter(
|
| 37 |
+
torch.empty(self.normalized_shape)
|
| 38 |
+
)
|
| 39 |
+
else:
|
| 40 |
+
self.register_parameter("weight", None)
|
| 41 |
+
self.register_parameter("bias", None)
|
| 42 |
+
|
| 43 |
+
def forward(self, x: torch.Tensor):
|
| 44 |
+
orig_type = x.dtype
|
| 45 |
+
assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
|
| 46 |
+
dims = [-(i + 1) for i in range(len(self.normalized_shape))]
|
| 47 |
+
mean = x.mean(dim=dims, keepdim=True)
|
| 48 |
+
mean_x2 = (x**2).mean(dim=dims, keepdim=True)
|
| 49 |
+
var = mean_x2 - mean**2
|
| 50 |
+
x_norm = (x - mean) / torch.sqrt(var + self.eps)
|
| 51 |
+
if self.elementwise_affine:
|
| 52 |
+
x_norm = self.weight * x_norm + self.bias
|
| 53 |
+
return x_norm.to(orig_type)
|
| 54 |
+
|
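As a quick sanity check (not part of the original file), the hand-rolled normalization above should agree with `torch.nn.functional.layer_norm` up to numerical precision:

```python
import torch
import torch.nn.functional as F

ln = LayerNorm(64)
torch.nn.init.ones_(ln.weight)   # weights are created uninitialized above
torch.nn.init.zeros_(ln.bias)

x = torch.randn(2, 10, 64)
ref = F.layer_norm(x, (64,), ln.weight, ln.bias, eps=ln.eps)
print(torch.allclose(ln(x), ref, atol=1e-5))  # True
```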
| 55 |
+
|
| 56 |
+
class QuickGELU(nn.Module):
|
| 57 |
+
def forward(self, x: torch.Tensor):
|
| 58 |
+
return x * torch.sigmoid(1.702 * x)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class LayerScale(nn.Module):
|
| 62 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.inplace = inplace
|
| 65 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 66 |
+
|
| 67 |
+
def forward(self, x):
|
| 68 |
+
raise ValueError("Not implemented")
|
| 69 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class PatchDropout(nn.Module):
|
| 73 |
+
"""
|
| 74 |
+
https://arxiv.org/abs/2212.00794
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, prob, exclude_first_token=True):
|
| 78 |
+
super().__init__()
|
| 79 |
+
assert 0 <= prob < 1.0
|
| 80 |
+
self.prob = prob
|
| 81 |
+
self.exclude_first_token = exclude_first_token
|
| 82 |
+
|
| 83 |
+
def forward(self, x):
|
| 84 |
+
if not self.training or self.prob == 0.0:
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
if self.exclude_first_token:
|
| 88 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
| 89 |
+
else:
|
| 90 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
| 91 |
+
|
| 92 |
+
batch = x.size()[0]
|
| 93 |
+
num_tokens = x.size()[1]
|
| 94 |
+
|
| 95 |
+
batch_indices = torch.arange(batch)
|
| 96 |
+
batch_indices = batch_indices[..., None]
|
| 97 |
+
|
| 98 |
+
keep_prob = 1 - self.prob
|
| 99 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
| 100 |
+
|
| 101 |
+
rand = torch.randn(batch, num_tokens)
|
| 102 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
| 103 |
+
|
| 104 |
+
x = x[batch_indices, patch_indices_keep]
|
| 105 |
+
|
| 106 |
+
if self.exclude_first_token:
|
| 107 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 108 |
+
|
| 109 |
+
return x
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Attention(nn.Module):
|
| 113 |
+
def __init__(
|
| 114 |
+
self,
|
| 115 |
+
dim,
|
| 116 |
+
num_heads=8,
|
| 117 |
+
qkv_bias=True,
|
| 118 |
+
scaled_cosine=False,
|
| 119 |
+
scale_heads=False,
|
| 120 |
+
logit_scale_max=math.log(1.0 / 0.01),
|
| 121 |
+
attn_drop=0.0,
|
| 122 |
+
proj_drop=0.0,
|
| 123 |
+
):
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.scaled_cosine = scaled_cosine
|
| 126 |
+
self.scale_heads = scale_heads
|
| 127 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 128 |
+
self.num_heads = num_heads
|
| 129 |
+
self.head_dim = dim // num_heads
|
| 130 |
+
self.scale = self.head_dim**-0.5
|
| 131 |
+
self.logit_scale_max = logit_scale_max
|
| 132 |
+
|
| 133 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
| 134 |
+
if qkv_bias:
|
| 135 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
| 136 |
+
else:
|
| 137 |
+
self.in_proj_bias = None
|
| 138 |
+
|
| 139 |
+
if self.scaled_cosine:
|
| 140 |
+
self.logit_scale = nn.Parameter(
|
| 141 |
+
torch.log(10 * torch.ones((num_heads, 1, 1)))
|
| 142 |
+
)
|
| 143 |
+
else:
|
| 144 |
+
self.logit_scale = None
|
| 145 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 146 |
+
if self.scale_heads:
|
| 147 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
| 148 |
+
else:
|
| 149 |
+
self.head_scale = None
|
| 150 |
+
self.out_proj = nn.Linear(dim, dim)
|
| 151 |
+
self.out_drop = nn.Dropout(proj_drop)
|
| 152 |
+
|
| 153 |
+
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
| 154 |
+
L, N, C = x.shape
|
| 155 |
+
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
| 156 |
+
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
| 157 |
+
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
| 158 |
+
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
| 159 |
+
|
| 160 |
+
if self.logit_scale is not None:
|
| 161 |
+
attn = torch.bmm(
|
| 162 |
+
F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)
|
| 163 |
+
)
|
| 164 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
| 165 |
+
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
| 166 |
+
attn = attn.view(-1, L, L)
|
| 167 |
+
else:
|
| 168 |
+
q = q * self.scale
|
| 169 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
| 170 |
+
|
| 171 |
+
if attn_mask is not None:
|
| 172 |
+
if attn_mask.dtype == torch.bool:
|
| 173 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
| 174 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
| 175 |
+
attn_mask = new_attn_mask
|
| 176 |
+
attn += attn_mask
|
| 177 |
+
|
| 178 |
+
attn = attn.softmax(dim=-1)
|
| 179 |
+
attn = self.attn_drop(attn)
|
| 180 |
+
|
| 181 |
+
x = torch.bmm(attn, v)
|
| 182 |
+
if self.head_scale is not None:
|
| 183 |
+
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
| 184 |
+
x = x.view(-1, L, C)
|
| 185 |
+
x = x.transpose(0, 1).reshape(L, N, C)
|
| 186 |
+
x = self.out_proj(x)
|
| 187 |
+
x = self.out_drop(x)
|
| 188 |
+
return x
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class AttentionalPooler(nn.Module):
|
| 192 |
+
def __init__(
|
| 193 |
+
self,
|
| 194 |
+
d_model: int,
|
| 195 |
+
context_dim: int,
|
| 196 |
+
n_head: int = 8,
|
| 197 |
+
n_queries: int = 256,
|
| 198 |
+
norm_layer: Callable = LayerNorm,
|
| 199 |
+
):
|
| 200 |
+
super().__init__()
|
| 201 |
+
self.query = nn.Parameter(torch.randn(n_queries, d_model))
|
| 202 |
+
self.attn = nn.MultiheadAttention(
|
| 203 |
+
d_model, n_head, kdim=context_dim, vdim=context_dim
|
| 204 |
+
)
|
| 205 |
+
self.ln_q = norm_layer(d_model)
|
| 206 |
+
self.ln_k = norm_layer(context_dim)
|
| 207 |
+
|
| 208 |
+
def forward(self, x: torch.Tensor):
|
| 209 |
+
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
|
| 210 |
+
N = x.shape[1]
|
| 211 |
+
q = self.ln_q(self.query)
|
| 212 |
+
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
|
| 213 |
+
return out.permute(1, 0, 2) # LND -> NLD
|
| 214 |
+
|
| 215 |
+
def _repeat(self, query, N: int):
|
| 216 |
+
return query.unsqueeze(1).repeat(1, N, 1)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class MLP(nn.Module):
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
d_model: int,
|
| 223 |
+
mlp_width: int,
|
| 224 |
+
act_layer: Callable = nn.GELU,
|
| 225 |
+
layer_id: Optional[int] = None,
|
| 226 |
+
):
|
| 227 |
+
super().__init__()
|
| 228 |
+
self.c_fc = nn.Linear(d_model, mlp_width)
|
| 229 |
+
self.gelu = act_layer()
|
| 230 |
+
self.c_proj = nn.Linear(mlp_width, d_model)
|
| 231 |
+
self.layer_id = layer_id
|
| 232 |
+
|
| 233 |
+
def forward(self, x, neuron_dict=None, num_register_tokens=0):
|
| 234 |
+
x = self.c_fc(x)
|
| 235 |
+
|
| 236 |
+
# If we have a dictionary of modifications and this layer is in it
|
| 237 |
+
if neuron_dict is not None and self.layer_id in neuron_dict and num_register_tokens>0:
|
| 238 |
+
neurons = neuron_dict[self.layer_id]
|
| 239 |
+
|
| 240 |
+
# Apply GELU to all activations
|
| 241 |
+
x_after_gelu = self.gelu(x)
|
| 242 |
+
|
| 243 |
+
original_activations = x_after_gelu.clone()
|
| 244 |
+
# Create new activation map for specified neurons
|
| 245 |
+
new_activation_map = torch.zeros(
|
| 246 |
+
(x_after_gelu.shape[0], x_after_gelu.shape[1], len(neurons)),
|
| 247 |
+
device=x_after_gelu.device,
|
| 248 |
+
).to(x_after_gelu.dtype)
|
| 249 |
+
|
| 250 |
+
max_values = torch.max(original_activations[:, :, neurons], dim=1, keepdim=True).values
|
| 251 |
+
|
| 252 |
+
new_activation_map[:, -num_register_tokens:, :] = max_values
|
| 253 |
+
new_activation_map[:,0,:] = x_after_gelu[:,0,neurons]
|
| 254 |
+
|
| 255 |
+
x_after_gelu[:,:,neurons] = new_activation_map
|
| 256 |
+
x = x_after_gelu
|
| 257 |
+
else:
|
| 258 |
+
x = self.gelu(x)
|
| 259 |
+
|
| 260 |
+
x = self.c_proj(x)
|
| 261 |
+
return x
|
| 262 |
+
|
| 263 |
+
# TODO: the problem seems to stem from the fact that this is not the custom attention.
|
| 264 |
+
class MultiheadAttention(nn.Module):
|
| 265 |
+
def __init__(
|
| 266 |
+
self,
|
| 267 |
+
embed_dim,
|
| 268 |
+
num_heads,
|
| 269 |
+
dropout=0.0,
|
| 270 |
+
bias=True,
|
| 271 |
+
add_bias_kv=False,
|
| 272 |
+
add_zero_attn=False,
|
| 273 |
+
kdim=None,
|
| 274 |
+
vdim=None,
|
| 275 |
+
batch_first=False,
|
| 276 |
+
device=None,
|
| 277 |
+
dtype=None,
|
| 278 |
+
):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.embed_dim = embed_dim
|
| 281 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 282 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 283 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 284 |
+
self.q_out = nn.Identity()
|
| 285 |
+
self.k_out = nn.Identity()
|
| 286 |
+
self.v_out = nn.Identity()
|
| 287 |
+
self.qkv_out = nn.Identity()
|
| 288 |
+
self.attn_map = nn.Identity()
|
| 289 |
+
|
| 290 |
+
self.num_heads = num_heads
|
| 291 |
+
self.dropout = dropout
|
| 292 |
+
self.batch_first = batch_first
|
| 293 |
+
self.head_dim = embed_dim // num_heads
|
| 294 |
+
assert (
|
| 295 |
+
self.head_dim * num_heads == self.embed_dim
|
| 296 |
+
), "embed_dim must be divisible by num_heads"
|
| 297 |
+
self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim)))
|
| 298 |
+
|
| 299 |
+
if bias:
|
| 300 |
+
self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
|
| 301 |
+
else:
|
| 302 |
+
self.register_parameter("in_proj_bias", None)
|
| 303 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 304 |
+
|
| 305 |
+
if add_bias_kv:
|
| 306 |
+
self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim)))
|
| 307 |
+
self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim)))
|
| 308 |
+
else:
|
| 309 |
+
self.bias_k = self.bias_v = None
|
| 310 |
+
|
| 311 |
+
self.add_zero_attn = add_zero_attn
|
| 312 |
+
|
| 313 |
+
def forward_direct(self, x, attn_mask=None):
|
| 314 |
+
B, N, C = x.shape
|
| 315 |
+
qkv = x @ self.in_proj_weight.T + self.in_proj_bias
|
| 316 |
+
qkv = self.qkv_out(qkv)
|
| 317 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 318 |
+
# B, S, 3, H, d -> 3, B, H, S, d batch first computation
|
| 319 |
+
# This step seems to be where the computed results start to diverge.
|
| 320 |
+
q, k, v = qkv.unbind(0)
|
| 321 |
+
|
| 322 |
+
q = self.q_out(q)
|
| 323 |
+
k = self.k_out(k)
|
| 324 |
+
v = self.v_out(v)
|
| 325 |
+
|
| 326 |
+
dk = q.size()[-1]
|
| 327 |
+
q = q / math.sqrt(dk)
|
| 328 |
+
attn = q @ k.transpose(-2, -1)
|
| 329 |
+
if attn_mask is not None:
|
| 330 |
+
attn += attn_mask
|
| 331 |
+
attn = attn.softmax(dim=-1)
|
| 332 |
+
attn = self.attn_map(attn)
|
| 333 |
+
x = attn @ v
|
| 334 |
+
|
| 335 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 336 |
+
x = x @ self.out_proj.weight.T + self.out_proj.bias
|
| 337 |
+
return x
|
| 338 |
+
|
| 339 |
+
def _split_qkv_weight(self):
|
| 340 |
+
q_weight, k_weight, v_weight = (
|
| 341 |
+
self.in_proj_weight[: self.embed_dim].reshape(
|
| 342 |
+
self.num_heads, self.head_dim, -1
|
| 343 |
+
),
|
| 344 |
+
self.in_proj_weight[self.embed_dim : self.embed_dim * 2].reshape(
|
| 345 |
+
self.num_heads, self.head_dim, -1
|
| 346 |
+
),
|
| 347 |
+
self.in_proj_weight[self.embed_dim * 2 :].reshape(
|
| 348 |
+
self.num_heads, self.head_dim, -1
|
| 349 |
+
),
|
| 350 |
+
)
|
| 351 |
+
return q_weight, k_weight, v_weight
|
| 352 |
+
|
| 353 |
+
def _split_qkv_bias(self):
|
| 354 |
+
q_bias, k_bias, v_bias = (
|
| 355 |
+
self.in_proj_bias[: self.embed_dim].reshape(
|
| 356 |
+
1, self.num_heads, 1, self.head_dim
|
| 357 |
+
),
|
| 358 |
+
self.in_proj_bias[self.embed_dim : self.embed_dim * 2].reshape(
|
| 359 |
+
1, self.num_heads, 1, self.head_dim
|
| 360 |
+
),
|
| 361 |
+
self.in_proj_bias[self.embed_dim * 2 :].reshape(
|
| 362 |
+
1, self.num_heads, 1, self.head_dim
|
| 363 |
+
),
|
| 364 |
+
)
|
| 365 |
+
return q_bias, k_bias, v_bias
|
| 366 |
+
|
| 367 |
+
def forward_qkv(self, x, attn_mask=None):
|
| 368 |
+
B, N, C = x.shape
|
| 369 |
+
q_weight, k_weight, v_weight = (
|
| 370 |
+
self.in_proj_weight[: self.embed_dim],
|
| 371 |
+
self.in_proj_weight[self.embed_dim : self.embed_dim * 2],
|
| 372 |
+
self.in_proj_weight[self.embed_dim * 2 :],
|
| 373 |
+
)
|
| 374 |
+
q_bias, k_bias, v_bias = (
|
| 375 |
+
self.in_proj_bias[: self.embed_dim],
|
| 376 |
+
self.in_proj_bias[self.embed_dim : self.embed_dim * 2],
|
| 377 |
+
self.in_proj_bias[self.embed_dim * 2 :],
|
| 378 |
+
)
|
| 379 |
+
q = (x @ q_weight.T + q_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 380 |
+
k = (x @ k_weight.T + k_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 381 |
+
v = (x @ v_weight.T + v_bias).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 382 |
+
|
| 383 |
+
dk = q.size()[-1]
|
| 384 |
+
q = q / math.sqrt(dk)
|
| 385 |
+
attn = q @ k.transpose(-2, -1)
|
| 386 |
+
if attn_mask is not None:
|
| 387 |
+
attn += attn_mask
|
| 388 |
+
attn = attn.softmax(dim=-1)
|
| 389 |
+
x = torch.einsum("bhnm,bhmc->bhnmc", attn, v)
|
| 390 |
+
x = x.sum(axis=3).transpose(1, 2).reshape(B, N, C)
|
| 391 |
+
x = x @ self.out_proj.weight.T + self.out_proj.bias
|
| 392 |
+
return x
|
| 393 |
+
|
| 394 |
+
def forward_per_head(self, x, attn_mask=None):
|
| 395 |
+
B, N, C = x.shape
|
| 396 |
+
q_weight, k_weight, v_weight = self._split_qkv_weight()
|
| 397 |
+
q_bias, k_bias, v_bias = self._split_qkv_bias()
|
| 398 |
+
q = torch.einsum("bnc,hdc->bhnd", x, q_weight) + q_bias
|
| 399 |
+
k = torch.einsum("bnc,hdc->bhnd", x, k_weight) + k_bias
|
| 400 |
+
v = torch.einsum("bnc,hdc->bhnd", x, v_weight) + v_bias
|
| 401 |
+
|
| 402 |
+
dk = q.size()[-1]
|
| 403 |
+
q = q / math.sqrt(dk)
|
| 404 |
+
attn = q @ k.transpose(-2, -1)
|
| 405 |
+
if attn_mask is not None:
|
| 406 |
+
attn += attn_mask
|
| 407 |
+
attn = attn.softmax(dim=-1)
|
| 408 |
+
x = torch.einsum("bhnm,bhmc->bnmhc", attn, v)
|
| 409 |
+
x = torch.einsum(
|
| 410 |
+
"bnmhc,dhc->bnmhd",
|
| 411 |
+
x,
|
| 412 |
+
self.out_proj.weight.reshape(self.embed_dim, self.num_heads, self.head_dim),
|
| 413 |
+
)
|
| 414 |
+
x = x.sum(axis=[2, 3]) + self.out_proj.bias
|
| 415 |
+
return x
|
| 416 |
+
|
| 417 |
+
def _get_ov_circuit(self):
|
| 418 |
+
reshaped_o = self.out_proj.weight.reshape(
|
| 419 |
+
self.embed_dim, self.num_heads, self.head_dim
|
| 420 |
+
)
|
| 421 |
+
_, _, v_weight = self._split_qkv_weight()
|
| 422 |
+
_, _, v_bias = self._split_qkv_bias()
|
| 423 |
+
ov_circuit = torch.einsum("onh,nhi->oni", reshaped_o, v_weight)
|
| 424 |
+
ov_bias_circuit = torch.einsum("onh,bnxh->bnxo", reshaped_o, v_bias)
|
| 425 |
+
return ov_circuit, ov_bias_circuit
|
| 426 |
+
|
| 427 |
+
def forward_ov_circuit(self, x, attn_mask=None):
|
| 428 |
+
B, N, C = x.shape
|
| 429 |
+
q_weight, k_weight, _ = self._split_qkv_weight()
|
| 430 |
+
q_bias, k_bias, _ = self._split_qkv_bias()
|
| 431 |
+
q = torch.einsum("bnc,hdc->bhnd", x, q_weight) + q_bias
|
| 432 |
+
k = torch.einsum("bnc,hdc->bhnd", x, k_weight) + k_bias
|
| 433 |
+
ov, ov_bias = self._get_ov_circuit()
|
| 434 |
+
v = torch.einsum("bnc,dhc->bhnd", x, ov) + ov_bias
|
| 435 |
+
|
| 436 |
+
dk = q.size()[-1]
|
| 437 |
+
q = q / math.sqrt(dk)
|
| 438 |
+
attn = q @ k.transpose(-2, -1)
|
| 439 |
+
if attn_mask is not None:
|
| 440 |
+
attn += attn_mask
|
| 441 |
+
attn = attn.softmax(dim=-1)
|
| 442 |
+
x = torch.einsum("bhnm,bhmc->bnmhc", attn, v)
|
| 443 |
+
x = x.sum(axis=[2, 3]) + self.out_proj.bias
|
| 444 |
+
return x
|
| 445 |
+
|
| 446 |
+
def forward(self, x, attn_mask=None, method: Text = "ov_circuit"):
|
| 447 |
+
if method == "direct":
|
| 448 |
+
return self.forward_direct(x, attn_mask=attn_mask)
|
| 449 |
+
elif method == "qkv":
|
| 450 |
+
return self.forward_qkv(x, attn_mask=attn_mask)
|
| 451 |
+
elif method == "head":
|
| 452 |
+
return self.forward_per_head(x, attn_mask=attn_mask)
|
| 453 |
+
elif method == "ov_circuit":
|
| 454 |
+
return self.forward_ov_circuit(x, attn_mask=attn_mask)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class ResidualAttentionBlock(nn.Module):
|
| 458 |
+
def __init__(
|
| 459 |
+
self,
|
| 460 |
+
d_model: int,
|
| 461 |
+
n_head: int,
|
| 462 |
+
mlp_ratio: float = 4.0,
|
| 463 |
+
ls_init_value: float = None,
|
| 464 |
+
act_layer: Callable = nn.GELU,
|
| 465 |
+
norm_layer: Callable = LayerNorm,
|
| 466 |
+
layer_id: Optional[int] = None,
|
| 467 |
+
):
|
| 468 |
+
super().__init__()
|
| 469 |
+
self.ln_1 = norm_layer(d_model)
|
| 470 |
+
self.attn = MultiheadAttention(d_model, n_head)
|
| 471 |
+
self.layer_id = layer_id
|
| 472 |
+
|
| 473 |
+
self.ls_1 = (
|
| 474 |
+
LayerScale(d_model, ls_init_value)
|
| 475 |
+
if ls_init_value is not None
|
| 476 |
+
else nn.Identity()
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
self.ln_2 = norm_layer(d_model)
|
| 480 |
+
self.mlp_width = int(d_model * mlp_ratio)
|
| 481 |
+
self.mlp = MLP(
|
| 482 |
+
d_model,
|
| 483 |
+
self.mlp_width,
|
| 484 |
+
act_layer=act_layer,
|
| 485 |
+
layer_id=layer_id,
|
| 486 |
+
)
|
| 487 |
+
self.ls_2 = (
|
| 488 |
+
LayerScale(d_model, ls_init_value)
|
| 489 |
+
if ls_init_value is not None
|
| 490 |
+
else nn.Identity()
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
def attention(
|
| 494 |
+
self,
|
| 495 |
+
q_x: torch.Tensor,
|
| 496 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 497 |
+
method: Text = "direct",
|
| 498 |
+
):
|
| 499 |
+
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
|
| 500 |
+
return self.attn(q_x, attn_mask=attn_mask, method=method)
|
| 501 |
+
|
| 502 |
+
def forward(
|
| 503 |
+
self,
|
| 504 |
+
q_x: torch.Tensor,
|
| 505 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 506 |
+
attn_method: Text = "direct",
|
| 507 |
+
neuron_dict=None,
|
| 508 |
+
num_register_tokens=0
|
| 509 |
+
):
|
| 510 |
+
after_ln1 = self.ln_1(q_x)
|
| 511 |
+
after_attn = self.attention(
|
| 512 |
+
q_x=after_ln1, attn_mask=attn_mask, method=attn_method
|
| 513 |
+
)
|
| 514 |
+
x = q_x + self.ls_1(after_attn)
|
| 515 |
+
after_ln2 = self.ln_2(x)
|
| 516 |
+
after_mlp = self.mlp(after_ln2, neuron_dict=neuron_dict, num_register_tokens=num_register_tokens)
|
| 517 |
+
x = x + self.ls_2(after_mlp)
|
| 518 |
+
return x
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class Transformer(nn.Module):
|
| 522 |
+
def __init__(
|
| 523 |
+
self,
|
| 524 |
+
width: int,
|
| 525 |
+
layers: int,
|
| 526 |
+
heads: int,
|
| 527 |
+
mlp_ratio: float = 4.0,
|
| 528 |
+
ls_init_value: float = None,
|
| 529 |
+
act_layer: Callable = nn.GELU,
|
| 530 |
+
norm_layer: Callable = LayerNorm,
|
| 531 |
+
):
|
| 532 |
+
super().__init__()
|
| 533 |
+
self.width = width
|
| 534 |
+
self.layers = layers
|
| 535 |
+
self.grad_checkpointing = False
|
| 536 |
+
|
| 537 |
+
self.resblocks = nn.ModuleList(
|
| 538 |
+
[
|
| 539 |
+
ResidualAttentionBlock(
|
| 540 |
+
width,
|
| 541 |
+
heads,
|
| 542 |
+
mlp_ratio,
|
| 543 |
+
ls_init_value=ls_init_value,
|
| 544 |
+
act_layer=act_layer,
|
| 545 |
+
norm_layer=norm_layer,
|
| 546 |
+
layer_id=i,
|
| 547 |
+
)
|
| 548 |
+
for i in range(layers)
|
| 549 |
+
]
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
def get_cast_dtype(self) -> torch.dtype:
|
| 553 |
+
if hasattr(self.resblocks[0].mlp.c_fc, "int8_original_dtype"):
|
| 554 |
+
return self.resblocks[0].mlp.c_fc.int8_original_dtype
|
| 555 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
| 556 |
+
|
| 557 |
+
def forward(
|
| 558 |
+
self,
|
| 559 |
+
x: torch.Tensor,
|
| 560 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 561 |
+
attn_method: Text = "direct",
|
| 562 |
+
neuron_dict=None,
|
| 563 |
+
num_register_tokens=0
|
| 564 |
+
):
|
| 565 |
+
for r in self.resblocks:
|
| 566 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
| 567 |
+
raise ValueError("grad_checkpointing not implemented")
|
| 568 |
+
else:
|
| 569 |
+
x = r(
|
| 570 |
+
x,
|
| 571 |
+
attn_mask=attn_mask,
|
| 572 |
+
attn_method=attn_method,
|
| 573 |
+
neuron_dict=neuron_dict,
|
| 574 |
+
num_register_tokens=num_register_tokens
|
| 575 |
+
)
|
| 576 |
+
return x
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
class VisionTransformer(nn.Module):
|
| 580 |
+
output_tokens: torch.jit.Final[bool]
|
| 581 |
+
|
| 582 |
+
def __init__(
|
| 583 |
+
self,
|
| 584 |
+
image_size: int,
|
| 585 |
+
patch_size: int,
|
| 586 |
+
width: int,
|
| 587 |
+
layers: int,
|
| 588 |
+
heads: int,
|
| 589 |
+
mlp_ratio: float,
|
| 590 |
+
ls_init_value: float = None,
|
| 591 |
+
global_average_pool: bool = False,
|
| 592 |
+
attentional_pool: bool = False,
|
| 593 |
+
n_queries: int = 256,
|
| 594 |
+
attn_pooler_heads: int = 8,
|
| 595 |
+
output_dim: int = 512,
|
| 596 |
+
patch_dropout: float = 0.0,
|
| 597 |
+
input_patchnorm: bool = False,
|
| 598 |
+
act_layer: Callable = nn.GELU,
|
| 599 |
+
norm_layer: Callable = LayerNorm,
|
| 600 |
+
output_tokens: bool = False,
|
| 601 |
+
):
|
| 602 |
+
super().__init__()
|
| 603 |
+
self.output_tokens = output_tokens
|
| 604 |
+
image_height, image_width = self.image_size = to_2tuple(image_size)
|
| 605 |
+
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
|
| 606 |
+
self.grid_size = (image_height // patch_height, image_width // patch_width)
|
| 607 |
+
self.output_dim = output_dim
|
| 608 |
+
|
| 609 |
+
self.num_register_tokens = 0
|
| 610 |
+
self.neuron_dict = None
|
| 611 |
+
|
| 612 |
+
self.input_patchnorm = input_patchnorm
|
| 613 |
+
|
| 614 |
+
if input_patchnorm:
|
| 615 |
+
patch_input_dim = patch_height * patch_width * 3
|
| 616 |
+
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
|
| 617 |
+
self.conv1 = nn.Linear(patch_input_dim, width)
|
| 618 |
+
else:
|
| 619 |
+
self.patchnorm_pre_ln = nn.Identity()
|
| 620 |
+
self.conv1 = nn.Conv2d(
|
| 621 |
+
in_channels=3,
|
| 622 |
+
out_channels=width,
|
| 623 |
+
kernel_size=patch_size,
|
| 624 |
+
stride=patch_size,
|
| 625 |
+
bias=False,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
scale = width**-0.5
|
| 629 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
| 630 |
+
self.positional_embedding = nn.Parameter(
|
| 631 |
+
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
self.width = width
|
| 635 |
+
self.scale = scale
|
| 636 |
+
self.extra_token = self.scale * torch.randn(width)
|
| 637 |
+
|
| 638 |
+
self.patch_dropout = (
|
| 639 |
+
PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
self.ln_pre = norm_layer(width)
|
| 643 |
+
self.transformer = Transformer(
|
| 644 |
+
width,
|
| 645 |
+
layers,
|
| 646 |
+
heads,
|
| 647 |
+
mlp_ratio,
|
| 648 |
+
ls_init_value=ls_init_value,
|
| 649 |
+
act_layer=act_layer,
|
| 650 |
+
norm_layer=norm_layer,
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
self.global_average_pool = global_average_pool
|
| 654 |
+
if attentional_pool:
|
| 655 |
+
self.attn_pool = AttentionalPooler(
|
| 656 |
+
output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries
|
| 657 |
+
)
|
| 658 |
+
self.ln_post = norm_layer(output_dim)
|
| 659 |
+
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
|
| 660 |
+
else:
|
| 661 |
+
self.attn_pool = None
|
| 662 |
+
self.ln_post = norm_layer(width)
|
| 663 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
| 664 |
+
|
| 665 |
+
@torch.jit.ignore
|
| 666 |
+
def set_grad_checkpointing(self, enable=True):
|
| 667 |
+
self.transformer.grad_checkpointing = enable
|
| 668 |
+
|
| 669 |
+
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 670 |
+
if self.global_average_pool:
|
| 671 |
+
return x.mean(dim=1), x
|
| 672 |
+
else:
|
| 673 |
+
return x[:, 0], x[:, 1:]
|
| 674 |
+
|
| 675 |
+
def forward(self, x: torch.Tensor, attn_method: Text = "direct", num_register_tokens = None, neuron_dict=None):
|
| 676 |
+
# to patches
|
| 677 |
+
|
| 678 |
+
if num_register_tokens is None and neuron_dict is None:
|
| 679 |
+
num_register_tokens = self.num_register_tokens
|
| 680 |
+
neuron_dict = self.neuron_dict
|
| 681 |
+
|
| 682 |
+
if self.input_patchnorm:
|
| 683 |
+
x = x.reshape(
|
| 684 |
+
x.shape[0],
|
| 685 |
+
x.shape[1],
|
| 686 |
+
self.grid_size[0],
|
| 687 |
+
self.patch_size[0],
|
| 688 |
+
self.grid_size[1],
|
| 689 |
+
self.patch_size[1],
|
| 690 |
+
)
|
| 691 |
+
x = x.permute(0, 2, 4, 1, 3, 5)
|
| 692 |
+
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
|
| 693 |
+
x = self.patchnorm_pre_ln(x)
|
| 694 |
+
x = self.conv1(x)
|
| 695 |
+
else:
|
| 696 |
+
x = self.conv1(x)
|
| 697 |
+
x = x.reshape(x.shape[0], x.shape[1], -1)
|
| 698 |
+
x = x.permute(0, 2, 1)
|
| 699 |
+
|
| 700 |
+
# class embeddings and positional embeddings
|
| 701 |
+
x = torch.cat([
|
| 702 |
+
self.class_embedding.to(x.dtype)
|
| 703 |
+
+ torch.zeros(
|
| 704 |
+
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
|
| 705 |
+
),
|
| 706 |
+
x,
|
| 707 |
+
],
|
| 708 |
+
dim=1,
|
| 709 |
+
)
|
| 710 |
+
x = x + self.positional_embedding.to(x.dtype)
|
| 711 |
+
|
| 712 |
+
extra_token_embeddings = []
|
| 713 |
+
total_patches = x.shape[1] - 1
|
| 714 |
+
for i in range(num_register_tokens):
|
| 715 |
+
extra_token_embeddings.append(
|
| 716 |
+
torch.zeros(
|
| 717 |
+
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
|
| 718 |
+
),
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
# Add extra tokens
|
| 722 |
+
if num_register_tokens > 0:
|
| 723 |
+
x = torch.cat([x, *extra_token_embeddings], dim=1)
|
| 724 |
+
|
| 725 |
+
x = self.patch_dropout(x)
|
| 726 |
+
x = self.ln_pre(x)
|
| 727 |
+
|
| 728 |
+
x = self.transformer(x, attn_mask=None, attn_method=attn_method, neuron_dict=neuron_dict, num_register_tokens=num_register_tokens)
|
| 729 |
+
|
| 730 |
+
if self.attn_pool is not None:
|
| 731 |
+
x = self.attn_pool(x)
|
| 732 |
+
x = self.ln_post(x)
|
| 733 |
+
pooled, tokens = self._global_pool(x)
|
| 734 |
+
else:
|
| 735 |
+
pooled, tokens = self._global_pool(x)
|
| 736 |
+
pooled = self.ln_post(pooled)
|
| 737 |
+
|
| 738 |
+
if self.proj is not None:
|
| 739 |
+
pooled = pooled @ self.proj
|
| 740 |
+
|
| 741 |
+
if self.output_tokens:
|
| 742 |
+
return pooled, tokens
|
| 743 |
+
|
| 744 |
+
return pooled
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
class TextTransformer(nn.Module):
|
| 748 |
+
output_tokens: torch.jit.Final[bool]
|
| 749 |
+
|
| 750 |
+
def __init__(
|
| 751 |
+
self,
|
| 752 |
+
context_length: int = 77,
|
| 753 |
+
vocab_size: int = 49408,
|
| 754 |
+
width: int = 512,
|
| 755 |
+
heads: int = 8,
|
| 756 |
+
layers: int = 12,
|
| 757 |
+
ls_init_value: float = None,
|
| 758 |
+
output_dim: int = 512,
|
| 759 |
+
act_layer: Callable = nn.GELU,
|
| 760 |
+
norm_layer: Callable = LayerNorm,
|
| 761 |
+
embed_cls: bool = False,
|
| 762 |
+
pad_id: int = 0,
|
| 763 |
+
output_tokens: bool = False,
|
| 764 |
+
):
|
| 765 |
+
super().__init__()
|
| 766 |
+
self.output_tokens = output_tokens
|
| 767 |
+
self.num_pos = self.context_length = context_length
|
| 768 |
+
self.vocab_size = vocab_size
|
| 769 |
+
self.width = width
|
| 770 |
+
self.output_dim = output_dim
|
| 771 |
+
self.heads = heads
|
| 772 |
+
self.pad_id = pad_id
|
| 773 |
+
|
| 774 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
| 775 |
+
|
| 776 |
+
if embed_cls:
|
| 777 |
+
self.cls_emb = nn.Parameter(torch.empty(width))
|
| 778 |
+
self.num_pos += 1
|
| 779 |
+
else:
|
| 780 |
+
self.cls_emb = None
|
| 781 |
+
|
| 782 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
| 783 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
|
| 784 |
+
self.transformer = Transformer(
|
| 785 |
+
width=width,
|
| 786 |
+
layers=layers,
|
| 787 |
+
heads=heads,
|
| 788 |
+
ls_init_value=ls_init_value,
|
| 789 |
+
act_layer=act_layer,
|
| 790 |
+
norm_layer=norm_layer,
|
| 791 |
+
)
|
| 792 |
+
self.ln_final = norm_layer(width)
|
| 793 |
+
|
| 794 |
+
self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
|
| 795 |
+
|
| 796 |
+
self.init_parameters()
|
| 797 |
+
|
| 798 |
+
def init_parameters(self):
|
| 799 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 800 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 801 |
+
if self.cls_emb is not None:
|
| 802 |
+
nn.init.normal_(self.cls_emb, std=0.01)
|
| 803 |
+
|
| 804 |
+
proj_std = (self.transformer.width**-0.5) * (
|
| 805 |
+
(2 * self.transformer.layers) ** -0.5
|
| 806 |
+
)
|
| 807 |
+
attn_std = self.transformer.width**-0.5
|
| 808 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
| 809 |
+
for block in self.transformer.resblocks:
|
| 810 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 811 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 812 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 813 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 814 |
+
|
| 815 |
+
if self.text_projection is not None:
|
| 816 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
|
| 817 |
+
|
| 818 |
+
@torch.jit.ignore
|
| 819 |
+
def set_grad_checkpointing(self, enable=True):
|
| 820 |
+
self.transformer.grad_checkpointing = enable
|
| 821 |
+
|
| 822 |
+
def build_attention_mask(self):
|
| 823 |
+
mask = torch.empty(self.num_pos, self.num_pos)
|
| 824 |
+
mask.fill_(float("-inf"))
|
| 825 |
+
mask.triu_(1)
|
| 826 |
+
return mask
|
| 827 |
+
|
| 828 |
+
def build_cls_mask(self, text, cast_dtype: torch.dtype):
|
| 829 |
+
cls_mask = (text != self.pad_id).unsqueeze(1)
|
| 830 |
+
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
|
| 831 |
+
additive_mask = torch.empty(
|
| 832 |
+
cls_mask.shape, dtype=cast_dtype, device=cls_mask.device
|
| 833 |
+
)
|
| 834 |
+
additive_mask.fill_(0)
|
| 835 |
+
additive_mask.masked_fill_(~cls_mask, float("-inf"))
|
| 836 |
+
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
|
| 837 |
+
return additive_mask
|
| 838 |
+
|
| 839 |
+
def _repeat(self, t, N: int):
|
| 840 |
+
return t.reshape(1, 1, -1).repeat(N, 1, 1)
|
| 841 |
+
|
| 842 |
+
def forward(self, text, attn_method: Text = "direct"):
|
| 843 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
| 844 |
+
seq_len = text.shape[1]
|
| 845 |
+
|
| 846 |
+
x = self.token_embedding(text).to(cast_dtype)
|
| 847 |
+
attn_mask = self.attn_mask
|
| 848 |
+
if self.cls_emb is not None:
|
| 849 |
+
seq_len += 1
|
| 850 |
+
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
|
| 851 |
+
cls_mask = self.build_cls_mask(text, cast_dtype)
|
| 852 |
+
attn_mask = (
|
| 853 |
+
attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
|
| 857 |
+
x = self.transformer(x, attn_mask=attn_mask, attn_method=attn_method)
|
| 858 |
+
|
| 859 |
+
if self.cls_emb is not None:
|
| 860 |
+
pooled, tokens = x[:, -1], x[:, :-1]
|
| 861 |
+
pooled = self.ln_final(pooled)
|
| 862 |
+
else:
|
| 863 |
+
x = self.ln_final(x)
|
| 864 |
+
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
|
| 865 |
+
|
| 866 |
+
if self.text_projection is not None:
|
| 867 |
+
pooled = pooled @ self.text_projection
|
| 868 |
+
|
| 869 |
+
if self.output_tokens:
|
| 870 |
+
return pooled, tokens
|
| 871 |
+
|
| 872 |
+
return pooled
|
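For orientation, the pieces above fit together as follows: `VisionTransformer.forward` appends `num_register_tokens` blank tokens after the patch tokens, and `MLP.forward` reroutes the layers/neurons listed in `neuron_dict` so that their peak activations land on those extra tokens instead of on patch tokens. A minimal sketch of wiring this up by hand is below; the hyperparameters and the layout of `neuron_indices.json` are illustrative assumptions, not the exact values used by this checkpoint.

```python
# Hedged sketch: enable test-time registers on a freshly constructed vision tower.
import json
import torch
from transformer import VisionTransformer  # this repo's transformer.py

# Assumed ViT-L/14-336 hyperparameters; the shipped config JSON is authoritative.
vit = VisionTransformer(
    image_size=336, patch_size=14, width=1024, layers=24, heads=16,
    mlp_ratio=4.0, output_dim=768, output_tokens=True,
)

# Assumed layout of neuron_indices.json: {"<layer_id>": [neuron indices], ...}
with open("neuron_indices.json") as f:
    neuron_dict = {int(k): v for k, v in json.load(f).items()}

vit.neuron_dict = neuron_dict   # which MLP neurons get rerouted
vit.num_register_tokens = 4     # how many test-time registers to append

images = torch.randn(2, 3, 336, 336)
pooled, tokens = vit(images)    # registers are appended inside forward()
patch_tokens = tokens[:, : tokens.shape[1] - vit.num_register_tokens]
print(pooled.shape, patch_tokens.shape)
```

In practice the pretrained weights would be loaded first (see `factory.py` / `pretrained.py`); the snippet only illustrates the control flow of `neuron_dict` and `num_register_tokens`.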
utils.py
ADDED
|
@@ -0,0 +1,34 @@
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
class SaveOcassionally:
|
| 7 |
+
def __init__(self, out, every_sec = None, every_count = None):
|
| 8 |
+
assert every_sec != None or every_count != None
|
| 9 |
+
|
| 10 |
+
self.out = out
|
| 11 |
+
self.curr_time = time.time()
|
| 12 |
+
self.every_sec = every_sec
|
| 13 |
+
self.cnt = 0
|
| 14 |
+
self.every_count = every_count
|
| 15 |
+
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 16 |
+
|
| 17 |
+
if "TIMESTAMP" in self.out:
|
| 18 |
+
self.out = self.out.replace("TIMESTAMP", self.timestamp)
|
| 19 |
+
print(f"Replacing TIMESTAMP with {self.timestamp}")
|
| 20 |
+
|
| 21 |
+
# Ensure the directory exists
|
| 22 |
+
out_dir = os.path.abspath(os.path.dirname(self.out))
|
| 23 |
+
if not os.path.exists(out_dir):
|
| 24 |
+
os.makedirs(out_dir)
|
| 25 |
+
|
| 26 |
+
def save(self, obj):
|
| 27 |
+
self.cnt += 1
|
| 28 |
+
if self.every_sec != None and time.time() - self.curr_time > self.every_sec:
|
| 29 |
+
torch.save(obj, self.out)
|
| 30 |
+
elif self.every_count != None and self.cnt % self.every_count == 0:
|
| 31 |
+
torch.save(obj, self.out)
|
| 32 |
+
|
| 33 |
+
def force_save(self, obj):
|
| 34 |
+
torch.save(obj, self.out)
|
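A short usage sketch for the helper above: `SaveOcassionally` writes a checkpoint either after a time interval (`every_sec`) or on every N-th call (`every_count`), and `force_save` writes unconditionally. The output path below is a placeholder.

```python
# Hypothetical usage of SaveOcassionally; the output path is a placeholder.
import torch
from utils import SaveOcassionally

saver = SaveOcassionally("outputs/results_TIMESTAMP.pt", every_count=100)

results = []
for step in range(1000):
    results.append(step)
    saver.save(results)        # persists on every 100th call
saver.force_save(results)      # always writes the final state
```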
utils/utils.py
ADDED
|
@@ -0,0 +1,34 @@
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
class SaveOcassionally:
|
| 7 |
+
def __init__(self, out, every_sec = None, every_count = None):
|
| 8 |
+
assert every_sec != None or every_count != None
|
| 9 |
+
|
| 10 |
+
self.out = out
|
| 11 |
+
self.curr_time = time.time()
|
| 12 |
+
self.every_sec = every_sec
|
| 13 |
+
self.cnt = 0
|
| 14 |
+
self.every_count = every_count
|
| 15 |
+
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 16 |
+
|
| 17 |
+
if "TIMESTAMP" in self.out:
|
| 18 |
+
self.out = self.out.replace("TIMESTAMP", self.timestamp)
|
| 19 |
+
print(f"Replacing TIMESTAMP with {self.timestamp}")
|
| 20 |
+
|
| 21 |
+
# Ensure the directory exists
|
| 22 |
+
out_dir = os.path.abspath(os.path.dirname(self.out))
|
| 23 |
+
if not os.path.exists(out_dir):
|
| 24 |
+
os.makedirs(out_dir)
|
| 25 |
+
|
| 26 |
+
def save(self, obj):
|
| 27 |
+
self.cnt += 1
|
| 28 |
+
if self.every_sec != None and time.time() - self.curr_time > self.every_sec:
|
| 29 |
+
torch.save(obj, self.out)
|
| 30 |
+
elif self.every_count != None and self.cnt % self.every_count == 0:
|
| 31 |
+
torch.save(obj, self.out)
|
| 32 |
+
|
| 33 |
+
def force_save(self, obj):
|
| 34 |
+
torch.save(obj, self.out)
|
vitl14_attention.png
ADDED
|
Git LFS Details
|
vitl14_patchnorms.png
ADDED
|
Git LFS Details
|
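The two PNGs visualize attention maps and patch-token norms for ViT-L/14 with and without test-time registers. A rough sketch of how a patch-norm heatmap like `vitl14_patchnorms.png` can be produced from the vision tower's output tokens follows; the grid size and figure styling are assumptions, not the script that generated the shipped images.

```python
# Hedged sketch: render per-patch token norms as a heatmap.
# Assumes `tokens` has shape (batch, n_patches + n_registers, width),
# e.g. from VisionTransformer(..., output_tokens=True), with a 24x24 patch grid.
import torch
import matplotlib.pyplot as plt

def plot_patch_norms(tokens: torch.Tensor, grid: int = 24, n_registers: int = 0) -> None:
    patch_tokens = tokens[0, : tokens.shape[1] - n_registers]  # drop register tokens
    norms = patch_tokens.norm(dim=-1).reshape(grid, grid)      # L2 norm per patch
    plt.imshow(norms.detach().cpu().numpy(), cmap="viridis")
    plt.colorbar(label="token L2 norm")
    plt.title("patch-token norms")
    plt.savefig("patchnorms.png", bbox_inches="tight")
```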
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
vocab/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
| 3 |
+
size 1356917
|
zeroshot_classifier.pt
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7dff47ac37ed4b67771bf6cf651a55dcf95d22eddc91acce2f54638ec82c6783
|
| 3 |
+
size 1537240
|
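`zeroshot_classifier.pt` appears to hold a precomputed zero-shot text classifier built from the class names and prompt templates included in this repo. A hedged sketch of how such a tensor is typically applied, assuming it stores a `(embed_dim, num_classes)` matrix of unit-normalized class embeddings:

```python
# Assumption: zeroshot_classifier.pt stores a (embed_dim, num_classes) weight matrix.
import torch
import torch.nn.functional as F

classifier = torch.load("zeroshot_classifier.pt", map_location="cpu")

def zero_shot_predict(image_features: torch.Tensor) -> torch.Tensor:
    """image_features: (batch, embed_dim) pooled outputs of the vision tower."""
    image_features = F.normalize(image_features, dim=-1)   # cosine similarity
    logits = 100.0 * image_features @ classifier            # CLIP-style logit scale
    return logits.argmax(dim=-1)                            # predicted class indices
```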