---
license: mit
language:
- en
metrics:
- accuracy
pipeline_tag: image-classification
tags:
- code
library_name: transformers
---

# ResNet Cat-Dog Classifier

This repository contains a ResNet-based convolutional neural network that classifies images as either cats or dogs. The network starts from ImageNet-pretrained weights and is fine-tuned via transfer learning on the Kaggle Cats and Dogs dataset, reaching 90.27% accuracy on the test set. PyTorch is used for both training and inference.

## Model Details

### Architecture

- Backbone: ResNet-18
- Input Size: 128x128 RGB images
- Output: Binary classification (cat or dog)

### Training Details

- Dataset: Kaggle Cats and Dogs dataset
- Loss Function: Cross-entropy loss
- Optimizer: Adam
- Learning Rate: 0.001
- Epochs: 15
- Batch Size: 32
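
The hyperparameters above map onto a standard PyTorch fine-tuning loop. The sketch below is an illustration under assumptions, not the author's training script: `train` is a hypothetical helper, and the dataset is assumed to yield `(image_tensor, class_index)` pairs (for example a `torchvision.datasets.ImageFolder`, which assigns indices alphabetically, so `cat` = 0 and `dog` = 1).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


def train(model: nn.Module, dataset: Dataset, epochs: int = 15,
          batch_size: int = 32, lr: float = 1e-3, device: str = "cpu") -> list[float]:
    """Hypothetical helper: cross-entropy + Adam fine-tuning; returns mean loss per epoch."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class ids
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    epoch_losses = []
    for _ in range(epochs):
        model.train()
        total, count = 0.0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact for a smaller last batch.
            total += loss.item() * images.size(0)
            count += images.size(0)
        epoch_losses.append(total / count)
    return epoch_losses
```

The defaults (15 epochs, batch size 32, learning rate 0.001) mirror the values listed in this card; everything else, such as shuffling and the device argument, is an assumption.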

### Performance

- Accuracy: 90.27% on the test set
- Training Time: Approximately 1 hour on an NVIDIA GTX 1080 Ti
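
Top-1 accuracy is the share of test images whose highest-scoring class matches the true label. A minimal sketch of that computation, assuming the test images are wrapped in a `DataLoader` that yields `(images, labels)` batches with integer labels; `evaluate_accuracy` is a hypothetical helper, and the exact evaluation protocol behind the 90.27% figure is not specified here.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients needed during evaluation
def evaluate_accuracy(model: nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Hypothetical helper: fraction of samples whose arg-max prediction equals the label."""
    model.eval()  # disable dropout / use running batch-norm statistics
    correct, total = 0, 0
    for images, labels in loader:
        logits = model(images.to(device))
        preds = logits.argmax(dim=1)
        correct += (preds == labels.to(device)).sum().item()
        total += labels.size(0)
    return correct / total
```

Multiplying the returned fraction by 100 gives a percentage in the same form as the accuracy reported above.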

## Usage

### Installation

- Dependencies: PyTorch, TorchVision, Transformers, Pillow, matplotlib

### Inference

The example below uses the Transformers `pipeline` API, which downloads the model from the Hugging Face Hub and applies its preprocessing internally.

```python
from PIL import Image
from torchvision import transforms
from transformers import pipeline

# Preprocessing for running the network directly in PyTorch: resize to 128x128,
# convert to a tensor, and normalize with ImageNet mean/std. The pipeline below
# applies the preprocessing configured in the model repository, so it does not
# need this transform.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the model from the Hugging Face Hub
classifier = pipeline("image-classification", model="DineshKumar1329/DogCat_Classifier")

# Load an image and force 3 channels (some files are grayscale or RGBA)
image_path = "path/to/your/image.jpg"
image = Image.open(image_path).convert("RGB")

# Batched tensor of shape (1, 3, 128, 128), only needed for direct PyTorch inference
batch = transform(image).unsqueeze(0)

# Make a prediction; results are sorted by score, highest first
result = classifier(image)

# Extract the top predicted label and its score
predicted_label = result[0]["label"]
score = result[0]["score"]

# Output the prediction
print(f"The predicted class for the image is: {predicted_label} (score: {score:.3f})")
```