ResNet18 Garbage Classifier

This is a ResNet18 model, pruned and fine-tuned to classify different types of garbage.


Model Details

  • Architecture: ResNet18
  • Task: Image Classification

How to Use for Inference

Here's a Python code snippet demonstrating how to load the model and perform inference on a single image:

import torch
from torchvision import models, transforms
import cv2  # opencv-python, used to read the input image

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

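# Load the model architecture. No pretrained weights are needed because the
# fine-tuned weights are loaded below (torchvision >= 0.13 uses `weights`).
model = models.resnet18(weights=None)
num_ftrs = model.fc.in_features
# Replace the 1000-class ImageNet head with an 8-class head
model.fc = torch.nn.Linear(num_ftrs, 8)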

# Load the trained weights
model.load_state_dict(torch.load('resnet_18_pruned.pth', map_location=device))
model.eval()
model.to(device)

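# Define the class names (index order must match the training labels).
# Note: as published, indices 0 and 2 are both "Garbage"; verify this order
# against the dataset's label mapping.
class_names = ["Garbage", "Cardboard", "Garbage", "Glass", "Metal", "Paper", "Plastic", "Trash"]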

# Define the transformations for inference (only the inference pipeline is provided)
def get_transform(train=False):
    if train:
        raise ValueError("Only the inference transform is defined here; call get_transform(train=False).")
    return transforms.Compose([
        transforms.ToPILImage(),           # OpenCV NumPy array -> PIL image
        transforms.Resize((224, 224)),
        transforms.ToTensor(),             # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet statistics
                             std=[0.229, 0.224, 0.225])
    ])

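def predict_image(model, image_path, transform, class_names):
    model.eval()

    # OpenCV loads images as BGR; convert to RGB (the normalization statistics are in RGB order)
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Add a batch dimension and move the tensor to the same device as the model
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image)
        predicted = outputs.argmax(dim=1).item()

    print(f"Predicted Class ID: {predicted}")
    print(f"Predicted Class: {class_names[predicted]}")
    return class_names[predicted]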

# Example usage: Replace 'path/to/your/image.jpg' with the actual path
image_path = 'path/to/your/image.jpg'
transform = get_transform(train=False)
predict_image(model, image_path, transform, class_names)
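predict_image reports only the most likely class. To also see how confident the model is, the raw logits can be turned into probabilities with softmax. The sketch below is a minimal illustration; `top_k_predictions` is a hypothetical helper, not part of the released code.

```python
import torch


def top_k_predictions(logits: torch.Tensor, class_names: list[str], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most likely (class name, probability) pairs for one image.

    `logits` is the raw model output for a single image, shape (1, num_classes).
    """
    probs = torch.softmax(logits.squeeze(0), dim=0)  # probabilities over classes sum to 1
    top_probs, top_ids = torch.topk(probs, k)         # highest probabilities first
    return [(class_names[i], p) for i, p in zip(top_ids.tolist(), top_probs.tolist())]
```

Inside predict_image, `top_k_predictions(outputs, class_names)` could be called right after `outputs = model(image)`.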

Intended Use

This model is intended for the classification of common garbage types.

Limitations

The accuracy of this model may vary depending on the quality and diversity of the training data. It may not perform well on unseen or unusual types of waste.
The model was trained on the dmedhi/garbage-image-classification-detection dataset for 50 epochs, reaching a validation loss of 1.49.

Accuracy and loss can be optimized with further preprocessing of the dataset.

Pruning

Fine-grained pruning reduced the effective model size from 42.65 MB to just 6.45 MB (15.13% of the original model), and fine-tuning for just 5 epochs recovered the accuracy lost to pruning, bringing it back to the level reached during training.
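The card does not say how weights were selected for pruning. One common form of fine-grained (element-wise) pruning is global magnitude pruning, sketched below with PyTorch's `torch.nn.utils.prune` utilities. The function names and the default sparsity of 0.85 (chosen to roughly match the ~15% of weights kept) are assumptions for illustration, not the author's settings.

```python
import torch
from torch import nn
from torch.nn.utils import prune


def apply_fine_grained_pruning(model: nn.Module, amount: float = 0.85) -> list[tuple[nn.Module, str]]:
    """Mask the `amount` fraction of smallest-magnitude Conv2d/Linear weights, ranked globally."""
    targets = [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]
    # Global L1 (magnitude) pruning: all weights compete in a single ranking,
    # so layers with many small weights lose more of them.
    prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=amount)
    return targets


def finalize_pruning(targets: list[tuple[nn.Module, str]]) -> None:
    """Fold the masks into the weights; pruned entries remain as zeros in dense tensors."""
    for module, name in targets:
        prune.remove(module, name)


def weight_sparsity(model: nn.Module) -> float:
    """Fraction of exactly-zero entries across Conv2d/Linear weights."""
    total = nonzero = 0
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            total += module.weight.numel()
            nonzero += int(torch.count_nonzero(module.weight))
    return 1.0 - nonzero / total
```

Applied to the ResNet18 above, one would call `targets = apply_fine_grained_pruning(model)`, fine-tune for a few epochs with the masks still attached (so pruned weights stay at zero while the rest update), and only then call `finalize_pruning(targets)` before saving.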

The checkpoint in the Files section is still about 44 MB because fine-grained pruning does not remove weights: it only sets them to zero, and those zeros are still stored in the dense tensors. To measure the effective size of a fine-grained pruned model, count its nonzero parameters with count_nonzero():

num_counted_elements = 0
for param in model.parameters():
    num_counted_elements += param.count_nonzero().item()
print(f"Nonzero parameters: {num_counted_elements}")
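To turn a parameter count into a size, multiply by the bytes per element. The sketch below assumes float32 weights (4 bytes each) and that the card's "MB" figures are mebibytes (2^20 bytes); the card states neither. Under those assumptions, torchvision's ResNet18 with the 8-class head has 11,180,616 parameters (11,689,512 minus the 513,000 in the ImageNet head, plus 512 × 8 + 8 = 4,104 for the new head), which works out to 42.65 MiB, or about 44.7 MB in decimal units. The helper name `size_mib` is illustrative.

```python
def size_mib(num_elements: int, bytes_per_element: int = 4) -> float:
    """Storage for num_elements values in MiB, assuming float32 (4 bytes) by default."""
    return num_elements * bytes_per_element / 2**20


# Dense parameter count of torchvision's ResNet18 with an 8-class head
DENSE_PARAMS = 11_689_512 - (512 * 1000 + 1000) + (512 * 8 + 8)

print(f"Dense size: {size_mib(DENSE_PARAMS):.2f} MiB")
```

Passing the nonzero count from the loop above, `size_mib(num_counted_elements)`, gives the effective size. This counts only the stored values and ignores the index overhead a real sparse format would add, so an actual sparse checkpoint would be somewhat larger.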
