---
license: mit
language:
- en
metrics:
- accuracy
tags:
- code
---

# ResNet Cat-Dog Classifier

This repository contains a ResNet-based convolutional neural network trained to classify images as either cats or dogs. The model achieves an accuracy of 90.27% on a test set. It was built with transfer learning: a ResNet-18 backbone pretrained on ImageNet was fine-tuned on the Kaggle Cats and Dogs dataset. PyTorch is used for both training and inference.

## Model Details

### Architecture:
- Backbone: ResNet-18
- Input Size: 128x128 RGB images
- Output: Binary classification with two logits (index 0 = Cat, index 1 = Dog)

### Training Details:
- Dataset: Kaggle Cats and Dogs dataset
- Loss Function: Cross-entropy loss
- Optimizer: Adam
- Learning Rate: 0.001
- Epochs: 15
- Batch Size: 32

### Performance:
- Accuracy: 90.27% on test images
- Training Time: Approximately 1 hour on an NVIDIA RTX 3050 Ti

## Results:

![image/png](https://cdn-uploads.huggingface.co/production/uploads/65dd9dc387458446d0a9da32/IwtLTneWqyRz1uYuYY1Kp.png)

## Usage

### Installation:
- Dependencies: PyTorch, TorchVision, Pillow, matplotlib

### Inference:

```python
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.models import resnet18

# Rebuild the architecture: ResNet-18 with a 2-class output layer.
# weights=None (replaces the deprecated pretrained=False) because the
# fine-tuned weights are loaded from the checkpoint below.
model = resnet18(weights=None)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)

# Load the trained model state_dict; map_location lets a GPU-saved checkpoint load on CPU
model_path = 'cat_dog_classifier.pth'
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

# Define the transformation (ensure it matches the training preprocessing)
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def load_image(image_path):
    # Convert to RGB so grayscale or RGBA files still produce 3 channels
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension
    return image

def predict_image(model, image_path):
    image = load_image(image_path)
    model.eval()
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)
    # Class index 0 is Cat, index 1 is Dog
    return "Cat" if predicted.item() == 0 else "Dog"

def plot_image(image_path, prediction):
    image = Image.open(image_path)
    plt.imshow(image)
    plt.title(f'Predicted: {prediction}')
    plt.axis('off')
    plt.show()

# Example usage
image_path = "path.jpeg"
prediction = predict_image(model, image_path)
print(f'The predicted class for the image is: {prediction}')
plot_image(image_path, prediction)
```

Example output: `The predicted class for the image is: Cat`
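In the preprocessing pipeline, `transforms.ToTensor` scales 8-bit pixel values to [0, 1] and `transforms.Normalize` then computes `(x - mean[c]) / std[c]` for each channel `c`. The mean and standard deviation values are the standard ImageNet channel statistics, which fits the ImageNet-pretrained backbone. This pure-Python sketch reproduces that arithmetic for single pixels; `to_unit_range` and `normalize_pixel` are hypothetical helpers that only mirror what the two transforms do.

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet per-channel means (R, G, B)
STD = (0.229, 0.224, 0.225)   # ImageNet per-channel standard deviations


def to_unit_range(rgb):
    """Scale 8-bit RGB values to [0, 1], as transforms.ToTensor does."""
    return tuple(v / 255.0 for v in rgb)


def normalize_pixel(rgb_unit):
    """Apply (x - mean) / std channel by channel, as transforms.Normalize does."""
    return tuple((x - m) / s for x, m, s in zip(rgb_unit, MEAN, STD))


white = normalize_pixel(to_unit_range((255, 255, 255)))
black = normalize_pixel(to_unit_range((0, 0, 0)))
print([round(v, 3) for v in white])  # [2.249, 2.429, 2.64]
print([round(v, 3) for v in black])  # [-2.118, -2.036, -1.804]
```

Normalized inputs therefore fall roughly in [-2.1, 2.7]; feeding the model unnormalized [0, 1] tensors would shift every input away from the distribution it was trained on.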
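The reported 90.27% is a test-set accuracy: the fraction of test images whose arg-max prediction equals the true label. The card does not include its evaluation script, so the following is only a minimal sketch of how such a figure could be computed over a `DataLoader`. The `evaluate_accuracy` helper is hypothetical, and it assumes the test images receive the same 128x128 resize and normalization as during training.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()
def evaluate_accuracy(model: nn.Module, loader: DataLoader,
                      device: torch.device = torch.device("cpu")) -> float:
    """Return the percentage of samples whose arg-max prediction matches the label."""
    model.eval()  # disable dropout and use BatchNorm running statistics
    model.to(device)
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)  # 0 = Cat, 1 = Dog
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    return 100.0 * correct / total
```

Calling it on the classifier from the inference example with a loader over the test split would give a number directly comparable to the accuracy in the Performance section.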
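`predict_image` returns only the arg-max label. When a confidence score is also useful, the two logits can be passed through a softmax. The sketch below assumes a batch preprocessed like the output of `load_image`; `predict_with_confidence` and `CLASS_NAMES` are hypothetical names, and softmax outputs of a fine-tuned network are not guaranteed to be calibrated probabilities.

```python
from typing import List, Tuple

import torch
from torch import nn

CLASS_NAMES = ("Cat", "Dog")  # same index order as predict_image


@torch.no_grad()
def predict_with_confidence(model: nn.Module,
                            batch: torch.Tensor) -> List[Tuple[str, float]]:
    """Return (label, softmax probability) for each image in a preprocessed batch.

    `batch` has shape (N, 3, 128, 128), e.g. the tensor returned by load_image().
    """
    model.eval()
    probs = torch.softmax(model(batch), dim=1)  # shape (N, 2); each row sums to 1
    confidence, index = probs.max(dim=1)        # winning probability and class index
    return [(CLASS_NAMES[i], p) for i, p in zip(index.tolist(), confidence.tolist())]
```

For example, `predict_with_confidence(model, load_image(image_path))` returns a one-element list such as `[("Cat", p)]`, where `p` is the softmax probability of the predicted class.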