ResNet18 Garbage Classifier
This is a ResNet18 model pruned & fine-tuned for classifying different types of garbage.
Model Details
- Architecture: ResNet18
- Task: Image Classification
How to Use for Inference
Here's a Python code snippet demonstrating how to load the model and perform inference on a single image:
```python
import cv2
import torch
from torchvision import models, transforms

# Select the device: use the GPU if one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the ResNet18 architecture without pretrained weights
# (`weights=None` replaces the deprecated `pretrained=False`)
model = models.resnet18(weights=None)
# Replace the 1000-class ImageNet head with an 8-class head
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 8)

# Load the trained (pruned and fine-tuned) weights
model.load_state_dict(torch.load("resnet_18_pruned.pth", map_location=device))
model.to(device)
model.eval()

# Class names, in the index order used during training
# (indices 0 and 2 are both labeled "Garbage" in the original list)
class_names = ["Garbage", "Cardboard", "Garbage", "Glass", "Metal", "Paper", "Plastic", "Trash"]

# Define the preprocessing transform for inference
def get_transform(train=False):
    if train:
        raise ValueError("Only the inference transform is defined; call get_transform(train=False).")
    return transforms.Compose([
        transforms.ToPILImage(),  # cv2 gives a NumPy HWC array; convert it to a PIL image
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Normalize with the ImageNet mean and standard deviation
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

def predict_image(model, image_path, transform, class_names):
    model.eval()
    image = cv2.imread(image_path)  # BGR order; returns None if the file cannot be read
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Add a batch dimension and move the tensor to the model's device
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)
    print(f"Predicted Class ID: {predicted.item()}")
    print(f"Predicted Class: {class_names[predicted.item()]}")

# Example usage: replace 'path/to/your/image.jpg' with the actual path
image_path = "path/to/your/image.jpg"
transform = get_transform(train=False)
predict_image(model, image_path, transform, class_names)
```
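To see how confident the model is, rather than only the top label, one option is to apply a softmax to the logits and keep the k most likely classes. The helper below, `top_k_predictions`, is hypothetical and not part of this repository; it assumes the model's raw outputs are logits (as produced by the `Linear` head), and the example logits are made up for illustration.

```python
import torch


def top_k_predictions(logits, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs for one image.

    `logits` is assumed to be the raw model output of shape (1, num_classes).
    """
    # Convert logits to probabilities over the classes
    probs = torch.softmax(logits, dim=1)[0]
    # Pick the k largest probabilities and their class indices
    top_probs, top_ids = torch.topk(probs, k)
    return [(class_names[i], p) for i, p in zip(top_ids.tolist(), top_probs.tolist())]


# Usage with made-up logits for the 8 classes (illustration only)
class_names = ["Garbage", "Cardboard", "Garbage", "Glass", "Metal", "Paper", "Plastic", "Trash"]
logits = torch.tensor([[0.1, 2.0, 0.3, 0.2, 1.5, 0.0, 0.4, -1.0]])
for name, prob in top_k_predictions(logits, class_names, k=3):
    print(f"{name}: {prob:.3f}")
```

Inside `predict_image`, `top_k_predictions(outputs, class_names)` could be called in place of the `torch.max` step.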
Intended Use
This model is intended for the classification of common garbage types.
Limitations
The accuracy of this model may vary depending on the quality and diversity of the training data. It may not perform well on unseen or unusual types of waste.
Trained on the dmedhi/garbage-image-classification-detection dataset for 50 epochs, reaching a validation loss of 1.49.
Accuracy and loss may improve further with additional preprocessing of the dataset.
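As a rough way to read the reported validation loss, assume it is the mean cross-entropy (natural log) over the 8 classes; the card does not name the loss function. Under that assumption it can be compared with the ln 8 ≈ 2.08 loss of a model that spreads probability evenly over all classes, and exp(-loss) gives the geometric-mean probability assigned to the correct class. The constant names below are illustrative.

```python
import math

NUM_CLASSES = 8
VAL_LOSS = 1.49  # reported validation loss, assumed to be mean cross-entropy (natural log)

# Cross-entropy of a classifier that assigns 1/8 to every class
uniform_loss = math.log(NUM_CLASSES)
# Geometric-mean probability given to the true class at the reported loss
mean_true_class_prob = math.exp(-VAL_LOSS)

print(f"Uniform-guess loss: {uniform_loss:.2f}")
print(f"Geometric-mean true-class probability: {mean_true_class_prob:.3f}")
```

At a loss of 1.49 this is about 0.23, above the 1/8 = 0.125 of uniform guessing, so on this measure the model does better than chance.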
Pruning
Fine-grained pruning reduced the effective model size from 42.65 MB to 6.45 MB (15.13% of the original), and fine-tuning for only 5 epochs let the model recover the accuracy it had reached before pruning.
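The card does not say which tool or criterion was used for pruning. As an assumed illustration only, fine-grained (unstructured) magnitude pruning can be sketched with PyTorch's `torch.nn.utils.prune`: the smallest-magnitude weights, ranked across all convolution and linear layers together, are set to zero. The helper names `magnitude_prune` and `weight_sparsity`, the toy network, and `amount=0.85` (picked to roughly mirror the ~15% remaining size reported above) are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def magnitude_prune(model, amount, permanent=True):
    """Zero the `amount` fraction of smallest-magnitude conv/linear weights,
    ranked globally across all of those layers (unstructured L1 pruning)."""
    targets = [(m, "weight") for m in model.modules()
               if isinstance(m, (nn.Conv2d, nn.Linear))]
    prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=amount)
    if permanent:
        # Fold each mask into its weight: zeros remain, the mask buffers are removed
        for module, name in targets:
            prune.remove(module, name)
    return model


def weight_sparsity(model):
    """Fraction of zero entries among conv/linear weights."""
    weights = [m.weight for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
    total = sum(w.numel() for w in weights)
    zeros = sum((w == 0).sum().item() for w in weights)
    return zeros / total


# Toy model for illustration (expects 3x8x8 inputs); the same call applies to a ResNet18
torch.manual_seed(0)
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 6 * 6, 8))
magnitude_prune(toy, amount=0.85)
print(f"Weight sparsity: {weight_sparsity(toy):.2%}")
```

To fine-tune after pruning, as was done for this model, pass `permanent=False` and call `prune.remove` only once fine-tuning is finished: while the masks are attached, the effective weight is `weight_orig * weight_mask`, so pruned entries stay at zero during training.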
The model file listed in the Files section is about 44 MB, because pruning keeps every weight tensor at its full size and only sets the pruned entries to zero. To measure the effective size of a fine-grained pruned model, count the nonzero parameters with count_nonzero():
```python
num_counted_elements = 0
for param in model.parameters():
    num_counted_elements += param.count_nonzero().item()
print(num_counted_elements)
```
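To turn a nonzero count into a size comparable to the figures above, multiply it by the bytes per value. This sketch assumes float32 weights (4 bytes) and 1 MB = 2**20 bytes; `count_nonzero_params` and `size_in_mb` are hypothetical helper names. Under those assumptions, a dense ResNet18 with an 8-class head (11,180,616 parameters) comes to about 42.65 MB, which matches the original size reported above.

```python
import torch
import torch.nn as nn


def count_nonzero_params(model):
    """Number of nonzero entries across all model parameters."""
    return sum(int(p.count_nonzero()) for p in model.parameters())


def size_in_mb(num_params, bytes_per_param=4):
    """Storage for `num_params` values, assuming float32 and 1 MB = 2**20 bytes."""
    return num_params * bytes_per_param / 2**20


# A dense ResNet18 with an 8-class head has 11,180,616 parameters
print(f"Dense ResNet18 (8 classes): {size_in_mb(11_180_616):.2f} MB")

# Small example: a linear layer (8 weights + 2 biases) with half its weights zeroed
layer = nn.Linear(4, 2)
with torch.no_grad():
    layer.weight[:, :2] = 0.0
print(count_nonzero_params(layer))
```

For the pruned model, `size_in_mb(count_nonzero_params(model))` gives the effective size. Treat it as a lower bound on real storage, since a sparse format also has to record where the nonzero values are.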
Model tree for dmedhi/restnet-18-pruned-garbage-classification
- Base model: microsoft/resnet-18