feat: add the process of using the pre-trained model
Referring to https://github.com/huggingface/huggingface_hub/issues/595
README.md

@@ -20,3 +20,53 @@ The main ideas are:
- Shifted Patch Tokenization
- Locality Self Attention

# Use the pre-trained model

The model is pre-trained on the CIFAR100 dataset with the following hyperparameters:

```python
# DATA
NUM_CLASSES = 100
INPUT_SHAPE = (32, 32, 3)
BUFFER_SIZE = 512
BATCH_SIZE = 256

# AUGMENTATION
IMAGE_SIZE = 72
PATCH_SIZE = 6
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# OPTIMIZER
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001

# TRAINING
EPOCHS = 50

# ARCHITECTURE
LAYER_NORM_EPS = 1e-6
TRANSFORMER_LAYERS = 8
PROJECTION_DIM = 64
NUM_HEADS = 4
TRANSFORMER_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]
MLP_HEAD_UNITS = [
    2048,
    1024,
]
```
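As a quick sanity check, the patch count follows directly from the image and patch sizes above:

```python
IMAGE_SIZE = 72
PATCH_SIZE = 6

# Each 72x72 image is split into a 12x12 grid of 6x6 patches.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
print(NUM_PATCHES)  # 144
```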

I have used the `AdamW` optimizer with a cosine decay learning rate schedule. You can find the entire implementation in the Keras blog post.

To use the pretrained model:

```python
from huggingface_hub import from_pretrained_keras

loaded_model = from_pretrained_keras("keras-io/vit-small-ds")

# Evaluate on the prepared test split (loss, accuracy, top-5 accuracy).
_, accuracy, top_5_accuracy = loaded_model.evaluate(test_ds)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
```

For an in-depth understanding of the model uploading and downloading process, refer to this [colab notebook](https://colab.research.google.com/drive/1nCMhefqySzG2p8wyXhmeAX5urddQXt49?usp=sharing).

Important: The data augmentation pipeline is excluded from the model. TensorFlow `2.7` has a serialization issue with the augmentation pipeline; you can follow [this GitHub issue](https://github.com/huggingface/huggingface_hub/issues/593) for updates. To send images through the model, apply the augmentation with the `tf.data` `map` API.
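Since augmentation is excluded from the saved model, images must be preprocessed before being fed in. Below is a pure-Python sketch of that map pattern; the `preprocess` helper and the toy dataset are hypothetical, and with TensorFlow you would write `test_ds.map(preprocess)` instead of a list comprehension:

```python
def preprocess(image, label):
    # Hypothetical stand-in for the excluded augmentation pipeline:
    # rescale pixel values from [0, 255] to [0, 1].
    scaled = [[pixel / 255.0 for pixel in row] for row in image]
    return scaled, label

# A toy "dataset" of (image, label) pairs; tf.data applies the same
# transformation lazily via dataset.map(preprocess).
dataset = [([[0, 128], [255, 64]], 3)]
mapped = [preprocess(image, label) for image, label in dataset]
print(mapped[0][0])
```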