Model Card: Pokemon Generation 1 through 9 Image Classifier

Model Description

This model is a fine-tuned Vision Transformer (ViT), a variant of the transformer encoder architecture (similar to BERT) adapted for image classification. It is built on the base model "google/vit-base-patch16-224-in21k", which was pre-trained in a supervised manner on ImageNet-21k, a large collection of images, at a resolution of 224x224 pixels. This pre-training makes it a strong starting point for a wide range of image recognition tasks.
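
The "patch16-224" in the base model's name refers to 16x16-pixel patches cut from a 224x224 input image. As a rough sketch of what that means for the encoder's input length (assuming the standard ViT setup, which prepends one [CLS] token whose output feeds the classification head), the arithmetic is:

```python
# Input sequence length of a ViT-Base/16 encoder at 224x224 resolution.
image_size = 224   # pixels per side, from "224" in the model name
patch_size = 16    # pixels per patch side, from "patch16" in the model name

patches_per_side = image_size // patch_size   # 14 patches along each side
num_patches = patches_per_side ** 2           # 196 patch embeddings
sequence_length = num_patches + 1             # plus one [CLS] token used for classification

print(patches_per_side, num_patches, sequence_length)  # 14 196 197
```

Images of any other size are resized to 224x224 by the image processor before they reach the encoder.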

The model was fine-tuned on an augmented version of the JJMack/pokemon-classification-gen1-9 dataset, in which each image was supplemented with 5 additional augmented versions.
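
Keeping each original image and adding five augmented copies multiplies the dataset size by six. A minimal sketch of that expansion, using a hypothetical `expand_with_augmentations` helper and assuming the originals were kept alongside the copies (the word "additional" suggests they were):

```python
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")  # stands in for whatever image type is used (e.g. a PIL image)

def expand_with_augmentations(
    samples: List[Tuple[T, str]],
    augment: Callable[[T], T],
    n_copies: int = 5,
) -> List[Tuple[T, str]]:
    """Hypothetical helper: return each (image, label) pair followed by
    n_copies augmented versions that keep the same label."""
    expanded: List[Tuple[T, str]] = []
    for image, label in samples:
        expanded.append((image, label))               # keep the original image
        for _ in range(n_copies):
            expanded.append((augment(image), label))  # augmentation never changes the label
    return expanded

# Toy usage: strings stand in for images, and the "augmentation" just tags them.
data = [("pikachu.png", "Pikachu"), ("eevee.png", "Eevee")]
result = expand_with_augmentations(data, lambda img: img + "+aug")
print(len(result))  # 12, i.e. 2 originals x (1 + 5)
```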

I built this model to learn how to fine-tune a model, and I am writing a LinkedIn article series about the process. The first article is "Building a Real Pokédex - An AI Journey".

Intended Uses

  • Pokemon Classification: The primary intended use of this model is for the classification of Pokemon images.

How to use

Here is how to use this model to classify an image as one of 1025 Pokemon:

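# Use a pipeline as a high-level helper
from PIL import Image
from transformers import pipeline

img = Image.open("<path_to_image_file>").convert("RGB")  # normalize to 3-channel RGB (drops any alpha channel)
classifier = pipeline("image-classification", model="JJMack/pokemon_gen1_9_classifier")
print(classifier(img))  # list of {"label": ..., "score": ...} dicts, best match first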

# Load model directly
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

img = Image.open("<path_to_image_file>").convert("RGB")  # normalize to 3-channel RGB (drops any alpha channel)
model = AutoModelForImageClassification.from_pretrained("JJMack/pokemon_gen1_9_classifier")
processor = ViTImageProcessor.from_pretrained("JJMack/pokemon_gen1_9_classifier")

with torch.no_grad():
    # Resize and normalize the image, returning PyTorch tensors
    inputs = processor(images=img, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits  # shape: (1, number of classes)

predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
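
`argmax` returns only the single best class. To see how confident the model is, one option is to turn the logits into probabilities with a softmax and keep the top few. The sketch below does this in plain Python with a hypothetical `top_k_predictions` helper; with the model loaded as shown, you could call it as `top_k_predictions(logits[0].tolist(), model.config.id2label)`. In PyTorch, `torch.softmax(logits, dim=-1).topk(5)` computes the same thing directly on the tensor.

```python
import math
from typing import Dict, List, Tuple

def top_k_predictions(
    logits: List[float], id2label: Dict[int, str], k: int = 5
) -> List[Tuple[str, float]]:
    """Hypothetical helper: softmax over one image's logits, then the k most likely labels."""
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]        # probabilities summing to 1
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]

# Toy example with three made-up classes and logits.
toy_labels = {0: "Bulbasaur", 1: "Charmander", 2: "Squirtle"}
preds = top_k_predictions([0.5, 2.0, 1.0], toy_labels, k=2)
print(preds[0][0])  # Charmander
```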

Limitations

  • Specialized Task Fine-Tuning: While the model is fine-tuned for Pokemon image classification, its performance may vary when applied to other tasks.
  • Users interested in employing this model for different tasks should explore other fine-tuned models available in the model hub for optimal results.

Training Data

The model's training data came from Bulbapedia. Each image in the training dataset was augmented 5 times with the following augmentations:

 - RandomHorizontalFlip(p=0.5)
 - RandomVerticalFlip(p=0.5)
 - RandomRotation(degrees=30)
 - ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)
 - GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
 - RandomAffine(degrees=0, translate=(0.1, 0.1))
 - RandomPerspective(distortion_scale=0.5, p=0.5)
 - RandomGrayscale(p=0.2)

Training Stats

- eval_loss: 0.7451944351196289
- eval_accuracy: 0.9221343873517787 (about 92.2%)
- eval_runtime: 39.6834 seconds
- eval_samples_per_second: 63.755
- eval_steps_per_second: 7.988
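
The throughput figures also let you back out the approximate size of the evaluation set. The sketch below is plain arithmetic on the reported values; the batch size it points to is an inference, not something the card states.

```python
import math

# Values copied from the Training Stats list.
eval_runtime = 39.6834              # seconds
eval_samples_per_second = 63.755
eval_steps_per_second = 7.988
eval_accuracy = 0.9221343873517787

# The rates are rounded, so round the products back to whole numbers.
n_samples = round(eval_samples_per_second * eval_runtime)  # evaluation images
n_steps = round(eval_steps_per_second * eval_runtime)      # evaluation batches
n_correct = round(eval_accuracy * n_samples)               # correctly classified images

print(n_samples, n_steps, n_correct)  # 2530 317 2333

# 2530 images in 317 batches is consistent with a batch size of 8 (last batch partial).
print(math.ceil(n_samples / 8))  # 317
```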

Model Size

86.6M params, stored in Safetensors format with F32 tensors.