Fruit Classifier - EfficientNet-B0 (Fruits-360 Merged)
This repository contains a fruit image classification model based on a fine-tuned EfficientNet-B0 architecture, built with PyTorch and torchvision. The model was trained on the Fruits-360 dataset, modified so that specific fruit variants were merged into broader categories (e.g., "Apple Red 1" and "Apple 6" were merged into "Apple"), resulting in 76 distinct classes.
Training progress and metrics were tracked using Neptune.ai.
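The card does not spell out the exact merge rule. The following is a hypothetical sketch only. It assumes that the merged label is the first word of each Fruits-360 class folder name, which reproduces the "Apple Red 1" / "Apple 6" → "Apple" example above. The names `merge_label` and `build_merged_classes` and the short variant list are illustrative and not taken from the original training code.

```python
from collections import defaultdict


def merge_label(variant_name: str) -> str:
    """Hypothetical merge rule: keep only the first word of a
    Fruits-360 class folder name, e.g. "Apple Red 1" -> "Apple"."""
    return variant_name.split()[0]


def build_merged_classes(variant_names):
    """Group variant folder names under their merged label and assign
    contiguous class indices in sorted (alphabetical) order."""
    groups = defaultdict(list)
    for name in variant_names:
        groups[merge_label(name)].append(name)
    label2id = {label: i for i, label in enumerate(sorted(groups))}
    return dict(groups), label2id


# Illustrative subset of variant folder names (not the full dataset)
variants = ["Apple Red 1", "Apple 6", "Banana", "Banana Red", "Cherry 1"]
groups, label2id = build_merged_classes(variants)
print(label2id)  # {'Apple': 0, 'Banana': 1, 'Cherry': 2}
```

A first-word rule would mishandle multi-word fruit names (a name such as "Passion Fruit" would become "Passion"), so an explicit mapping table may be needed in practice.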
Model Description
- Architecture: EfficientNet-B0 (pre-trained on ImageNet)
- Fine-tuning Strategy: Transfer learning. The pre-trained base model's weights were frozen, and only the final classifier layer was replaced and trained on the target dataset.
- Framework: PyTorch / torchvision
- Task: Image Classification
- Dataset: Fruits-360 (Merged Classes)
- Number of Classes: 76
Intended Uses & Limitations
- Intended Use: Classifying images of fruits belonging to one of the 76 merged categories derived from the Fruits-360 dataset. Suitable for educational purposes, demonstrations, or as a baseline for further development.
- Limitations:
- Trained only on the Fruits-360 dataset. Performance on images significantly different from this dataset (e.g., different lighting, backgrounds, occlusions, fruit varieties not present) is not guaranteed.
- Only recognizes the specific 76 merged classes it was trained on.
- Performance may vary depending on input image quality.
- Not intended for safety-critical applications without rigorous testing and validation.
How to Use
You can load the model and its configuration directly from the Hugging Face Hub using torch, torchvision, and huggingface_hub. The example also uses Pillow (PIL) to read images.
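The loader reads two fields from config.json: `num_labels` and an `id2label` mapping. JSON object keys are always strings, so the mapping is keyed by "0", "1", and so on. This is why predictions are looked up with `str(index)`. Here is a minimal sketch of producing a config in that shape. The helper `make_config` and the three-class list are hypothetical. The real config.json in this repository may contain additional fields.

```python
import json


def make_config(class_names):
    """Build a config dict in the shape the loader reads:
    num_labels plus an id2label mapping with string keys."""
    id2label = {str(i): name for i, name in enumerate(class_names)}
    return {"num_labels": len(class_names), "id2label": id2label}


config = make_config(["Apple", "Banana", "Cherry"])  # hypothetical 3-class example
restored = json.loads(json.dumps(config))
print(restored["id2label"]["1"])  # Banana

# Integer keys do not survive a JSON round trip: they come back as strings
print(json.loads(json.dumps({0: "Apple"})))  # {'0': 'Apple'}
```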
```python
import json
import os

import torch
import torchvision.models as models
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# --- 1. Define Model Loading Function ---
def load_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.json"):
    """Loads the model state_dict and config from the Hugging Face Hub."""
    # Download and parse the config file
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, "r") as f:
        config = json.load(f)
    num_labels = config["num_labels"]
    id2label = config["id2label"]  # Label mapping; JSON keys are strings ("0", "1", ...)

    # Instantiate EfficientNet-B0 without pre-trained weights,
    # because the fine-tuned weights are loaded below
    model = models.efficientnet_b0(weights=None)
    # Replace the classifier head to match the number of classes used during training
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)

    # Download and load the fine-tuned weights.
    # map_location="cpu" lets weights saved on a GPU load on CPU-only machines;
    # the model is moved to the target device after loading.
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # Evaluation mode: disables dropout, uses BatchNorm running stats
    print(f"Model loaded successfully from {repo_id} and set to evaluation mode.")
    return model, config, id2label

# --- 2. Define Preprocessing ---
# Use the same transforms as the validation set during training
IMG_SIZE = (224, 224)  # Standard EfficientNet-B0 input resolution
# ImageNet normalization statistics, matching the ImageNet pre-training
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# --- 3. Load Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
repo_id_to_load = "Bhumong/fruit-classifier-efficientnet-b0"  # Hub repository ID of this model
model, config, id2label = load_model_from_hf(repo_id_to_load)
model.to(device)

# --- 4. Prepare Input Image ---
image_path = "path/to/your/fruit_image.jpg"  # <-- REPLACE WITH YOUR IMAGE PATH
input_batch = None
if not os.path.exists(image_path):
    print(f"Warning: Image path not found: {image_path}")
    print("Skipping prediction. Please provide a valid image path.")
else:
    try:
        img = Image.open(image_path).convert("RGB")
        input_tensor = preprocess(img)
        # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
        input_batch = input_tensor.unsqueeze(0).to(device)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        input_batch = None

# --- 5. Make Prediction ---
if input_batch is not None:
    with torch.no_grad():  # Disable gradient tracking for inference
        output = model(input_batch)
    # output has shape (1, num_labels); take the single image's logits
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_catid = torch.max(probabilities, dim=0)
    predicted_label_index = top_catid.item()
    # Look up the label with a string key, since JSON keys are strings
    predicted_label = id2label.get(str(predicted_label_index), "Unknown Label")
    confidence = top_prob.item()
    print(f"\nPrediction for: {os.path.basename(image_path)}")
    print(f"Predicted Label Index: {predicted_label_index}")
    print(f"Predicted Label: {predicted_label}")
    print(f"Confidence: {confidence:.4f}")
```
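Several merged classes can look alike (for example, different citrus fruits), so it may help to inspect the runners-up as well as the single best label. The helper `top_k_predictions` below is a hypothetical addition and is not part of this repository. It applies `torch.topk` to the softmax probabilities and uses the same string-keyed `id2label` lookup as the example above. The synthetic logits stand in for real model output.

```python
import torch


def top_k_predictions(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs for one image.

    logits: tensor of shape (num_labels,) or (1, num_labels), e.g. model(input_batch).
    id2label: dict with string keys, as loaded from config.json.
    """
    probs = torch.softmax(logits.reshape(-1), dim=0)
    top_probs, top_ids = torch.topk(probs, k=min(k, probs.numel()))
    return [
        (id2label.get(str(i), "Unknown Label"), p)
        for i, p in zip(top_ids.tolist(), top_probs.tolist())
    ]


# Synthetic logits for a hypothetical 3-class model
example_logits = torch.tensor([[2.0, 0.5, 1.0]])
preds = top_k_predictions(example_logits, {"0": "Apple", "1": "Banana", "2": "Cherry"}, k=2)
for label, p in preds:
    print(f"{label}: {p:.4f}")  # Apple ranks first, then Cherry
```

With the model from the usage example, calling `top_k_predictions(output, id2label)` on the `output` computed under `torch.no_grad()` would list the top candidates for the input image.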
Base model: google/efficientnet-b0