tags:
- videogames
- pokemon
pipeline_tag: image-classification
---

# Model Card: Pokemon Generation 1 through 9 Image Classifier

## Model Description

The Fine-Tuned Vision Transformer (ViT) is a variant of the transformer encoder architecture, similar to BERT, that has been adapted for image classification tasks. This specific model, "google/vit-base-patch16-224-in21k", was pre-trained in a supervised manner on the large ImageNet-21k dataset, with images resized to a resolution of 224x224 pixels, making it suitable for a wide range of image recognition tasks.

The model was trained on an augmented version of the JJMack/pokemon-classification-gen1-9 dataset, with five additional augmented versions of each image.
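
For context, fine-tuning starts from the base checkpoint with a freshly initialized classification head sized to the label set. A minimal sketch of that starting point, assuming 1025 classes, one per Pokemon (the id2label mapping below is illustrative, not the model's actual mapping):

```python
from transformers import AutoModelForImageClassification, ViTImageProcessor

# Hypothetical label mapping: one class per Pokemon, 1025 in total
id2label = {i: f"pokemon_{i}" for i in range(1025)}
label2id = {name: i for i, name in id2label.items()}

# from_pretrained swaps in a freshly initialized 1025-way classification head
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=1025,
    id2label=id2label,
    label2id=label2id,
)
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
```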

I built this model to learn how to fine-tune a model, and I am writing a LinkedIn article series about the process. You can find the first article at [Building a Real Pokédex - An AI Journey](https://www.linkedin.com/pulse/building-real-pok%C3%A9dex-ai-journey-jeremy-mack-jc3fc/?trackingId=zWK6TeRJ%2FXLAmv7BKZsQxA%3D%3D).

### Intended Uses

- **Pokemon Classification**: The primary intended use of this model is for the classification of Pokemon images.

### How to use

Here is how to use this model to classify an image as one of the 1025 Pokemon:

```python
# Use a pipeline as a high-level helper
from PIL import Image
from transformers import pipeline

# Load the image to classify
img = Image.open("<path_to_image_file>")

# Build an image-classification pipeline backed by the fine-tuned model
classifier = pipeline("image-classification", model="JJMack/pokemon_gen1_9_classifier")

# Returns a list of {'label': ..., 'score': ...} dicts, highest score first
classifier(img)
```

<hr>

```python
# Load the model directly
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

# Load the image to classify
img = Image.open("<path_to_image_file>")

# Load the fine-tuned model and its matching image processor
model = AutoModelForImageClassification.from_pretrained("JJMack/pokemon_gen1_9_classifier")
processor = ViTImageProcessor.from_pretrained("JJMack/pokemon_gen1_9_classifier")

# Preprocess the image and run a forward pass without tracking gradients
with torch.no_grad():
    inputs = processor(images=img, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits

# Map the highest-scoring logit back to a Pokemon name
predicted_label = logits.argmax(-1).item()
model.config.id2label[predicted_label]
```
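
If you want ranked predictions rather than a single label, the logits can be turned into probabilities with a softmax. A minimal sketch that continues from the snippet above (`logits` and `model` are reused):

```python
# Convert logits to probabilities and take the five highest-scoring classes
probs = torch.softmax(logits, dim=-1)[0]
top5 = torch.topk(probs, k=5)

for score, idx in zip(top5.values, top5.indices):
    print(f"{model.config.id2label[idx.item()]}: {score.item():.3f}")
```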

### Limitations

- **Specialized Task Fine-Tuning**: While the model is adept at Pokemon image classification, its performance may vary when applied to other tasks.
- Users interested in employing this model for different tasks should explore fine-tuned versions available in the model hub for optimal results.

## Training Data

The model's training data came from [Bulbapedia](https://bulbapedia.bulbagarden.net/wiki/Main_Page). Each image in the training dataset was augmented 5 times with the following augmentations; a sketch of how they might be composed follows the list:

```
RandomHorizontalFlip(p=0.5)
RandomVerticalFlip(p=0.5)
RandomRotation(degrees=30)
ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)
GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
RandomAffine(degrees=0, translate=(0.1, 0.1))
RandomPerspective(distortion_scale=0.5, p=0.5)
RandomGrayscale(p=0.2)
```
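
These are standard torchvision transforms. A minimal sketch of how the five augmented copies per image might have been generated (the exact training script is not part of this card, so treat this as illustrative):

```python
from PIL import Image
from torchvision import transforms

# The transforms listed above; each re-samples its random parameters per call
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
    transforms.RandomGrayscale(p=0.2),
])

img = Image.open("<path_to_image_file>").convert("RGB")

# Five stochastic copies of the same source image
augmented_copies = [augment(img) for _ in range(5)]
```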

### Training Stats

```
'eval_loss': 0.7451944351196289,
'eval_accuracy': 0.9221343873517787,
'eval_runtime': 39.6834,
'eval_samples_per_second': 63.755,
'eval_steps_per_second': 7.988
```
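
These stats have the shape of the dictionary returned by `Trainer.evaluate()` in the transformers library. A minimal sketch of how similar numbers could be reproduced on a labeled evaluation split (the dataset argument and its column layout are assumptions, not taken from the original training script):

```python
import numpy as np
import evaluate
from transformers import AutoModelForImageClassification, Trainer, TrainingArguments

def evaluate_classifier(eval_dataset):
    """eval_dataset: a datasets.Dataset with 'pixel_values' and 'label'
    columns (hypothetical layout, preprocessed with the model's processor)."""
    accuracy = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # eval_pred is (logits, labels); accuracy compares argmax to labels
        logits, labels = eval_pred
        return accuracy.compute(predictions=np.argmax(logits, axis=-1), references=labels)

    model = AutoModelForImageClassification.from_pretrained("JJMack/pokemon_gen1_9_classifier")
    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="eval_out", per_device_eval_batch_size=8),
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )
    # Returns eval_loss, eval_accuracy, eval_runtime, eval_samples_per_second, ...
    return trainer.evaluate()
```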

<hr>