
Fruit Classifier - EfficientNet-B0 (Fruits-360 Merged)

This repository contains a fruit image classification model: an EfficientNet-B0 fine-tuned with PyTorch and torchvision. The model was trained on the Fruits-360 dataset, modified so that specific fruit variants were merged into broader categories (e.g., "Apple Red 1" and "Apple 6" were merged into "Apple"), resulting in 76 distinct classes.
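
The card does not spell out the exact merge rule. Here is a minimal sketch, assuming the merge is driven by the Fruits-360 folder names. It strips trailing variant numbers and uses an explicit override table for multi-word variants such as "Apple Red 1", which digit-stripping alone would turn into "Apple Red". `merge_label`, `build_class_index` and `MERGE_OVERRIDES` are hypothetical names. The alphabetical class ordering mirrors torchvision's `ImageFolder` convention rather than anything confirmed for this model.

```python
import re

# Hypothetical override table: the card only confirms that variants such as
# "Apple Red 1" and "Apple 6" were merged into "Apple".
MERGE_OVERRIDES = {
    "Apple Red 1": "Apple",
}


def merge_label(fine_label: str) -> str:
    """Map a Fruits-360 variant name to a broader category (assumed rule)."""
    if fine_label in MERGE_OVERRIDES:
        return MERGE_OVERRIDES[fine_label]
    # Assumed rule: drop trailing numeric variant tokens, e.g. "Apple 6" -> "Apple".
    return re.sub(r"(\s+\d+)+$", "", fine_label)


def build_class_index(fine_labels):
    """Build an index -> merged-label mapping, sorted alphabetically (assumption)."""
    merged = sorted({merge_label(name) for name in fine_labels})
    return {i: name for i, name in enumerate(merged)}


print(build_class_index(["Apple Red 1", "Apple 6", "Banana", "Cherry 1", "Cherry 2"]))
```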

Training progress and metrics were tracked using Neptune.ai.

Model Description

  • Architecture: EfficientNet-B0 (pre-trained on ImageNet)
  • Fine-tuning Strategy: Transfer learning. The pre-trained base model's weights were frozen, and only the final classifier layer was replaced and trained on the target dataset.
  • Framework: PyTorch / torchvision
  • Task: Image Classification
  • Dataset: Fruits-360 (Merged Classes)
  • Number of Classes: 76

Intended Uses & Limitations

  • Intended Use: Classifying images of fruits belonging to one of the 76 merged categories derived from the Fruits-360 dataset. Suitable for educational purposes, demonstrations, or as a baseline for further development.
  • Limitations:
    • Trained only on the Fruits-360 dataset. Performance on images significantly different from this dataset (e.g., different lighting, backgrounds, occlusions, fruit varieties not present) is not guaranteed.
    • Only recognizes the 76 merged classes it was trained on.
    • Performance may vary depending on input image quality.
    • Not intended for safety-critical applications without rigorous testing and validation.
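
Because the model always outputs one of its 76 classes, an image of an unsupported fruit still receives a label. A common, simple mitigation is to reject predictions whose softmax confidence falls below a threshold. The sketch below assumes `logits` is the model output for a single image and `id2label` is the mapping loaded from `config.json`. The helper name and the 0.5 default threshold are hypothetical and would need tuning on held-out data. High softmax confidence does not guarantee correctness, so this only filters out the most uncertain cases.

```python
import torch


def predict_with_rejection(logits: torch.Tensor, id2label: dict, threshold: float = 0.5):
    """Return (label, confidence), or (None, confidence) when below the threshold."""
    probs = torch.softmax(logits, dim=-1)
    confidence, index = torch.max(probs, dim=-1)
    if confidence.item() < threshold:
        return None, confidence.item()
    # config.json keys are strings, so the index is converted before lookup.
    return id2label.get(str(index.item()), "Unknown Label"), confidence.item()


print(predict_with_rejection(torch.tensor([3.0, 0.0]), {"0": "Apple", "1": "Banana"}))
```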

How to Use

You can load the model and its configuration directly from the Hugging Face Hub using torch, torchvision, huggingface_hub, and Pillow.

import json
import os

import torch
import torchvision.models as models
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# --- 1. Define Model Loading Function ---
def load_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.json"):
    """Loads model state_dict and config from Hugging Face Hub."""

    # Download config file
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, 'r') as f:
        config = json.load(f)

    num_labels = config['num_labels']
    id2label = config['id2label'] # Load label mapping

    # Instantiate the correct architecture (EfficientNet-B0)
    # Load architecture without pre-trained weights, as we'll load our fine-tuned ones
    model = models.efficientnet_b0(weights=None)

    # Modify the classifier head to match the number of classes used during training
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)

    # Download model weights
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

    # Load the state dict onto the CPU; the caller moves the model to the target device.
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    model.eval() # Set to evaluation mode
    print(f"Model loaded successfully from {repo_id} and set to evaluation mode.")
    return model, config, id2label

# --- 2. Define Preprocessing ---
# Use the same transformations as validation during training
IMG_SIZE = (224, 224) # Standard EfficientNet input size
# ImageNet stats often used with EfficientNet pre-training
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# --- 3. Load Model ---
repo_id_to_load = "Bhumong/fruit-classifier-efficientnet-b0" # Hugging Face Hub repo ID
model, config, id2label = load_model_from_hf(repo_id_to_load)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


# --- 4. Prepare Input Image ---
# Example: Load an image file (replace with your image path)
image_path = "path/to/your/fruit_image.jpg" # <-- REPLACE WITH YOUR IMAGE PATH

if not os.path.exists(image_path):
    print(f"Warning: Image path not found: {image_path}")
    print("Skipping prediction. Please provide a valid image path.")
    input_batch = None
else:
    try:
        img = Image.open(image_path).convert("RGB")
        input_tensor = preprocess(img)
        # Add batch dimension (model expects batches)
        input_batch = input_tensor.unsqueeze(0)
        input_batch = input_batch.to(device)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        input_batch = None

# --- 5. Make Prediction ---
if input_batch is not None:
    with torch.no_grad(): # Disable gradient calculations for inference
        output = model(input_batch)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        top_prob, top_catid = torch.max(probabilities, dim=0)

    predicted_label_index = top_catid.item()
    # Use the id2label mapping loaded from config
    predicted_label = id2label.get(str(predicted_label_index), "Unknown Label")
    confidence = top_prob.item()

    print(f"\nPrediction for: {os.path.basename(image_path)}")
    print(f"Predicted Label Index: {predicted_label_index}")
    print(f"Predicted Label: {predicted_label}")
    print(f"Confidence: {confidence:.4f}")
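
To inspect more than the single best guess, the top-k classes can be read off the same logits with `torch.topk`. Here is a minimal sketch, assuming `output[0]` and `id2label` from the script above. `top_k_predictions` is a hypothetical helper. The `str(...)` lookup is needed because JSON object keys in `config.json` are always strings.

```python
import torch


def top_k_predictions(logits: torch.Tensor, id2label: dict, k: int = 3):
    """Return the k most likely (label, probability) pairs for a single image."""
    probs = torch.softmax(logits, dim=-1)
    k = min(k, probs.numel())  # never ask for more classes than exist
    top_probs, top_ids = torch.topk(probs, k)
    return [
        (id2label.get(str(i), "Unknown Label"), p)
        for p, i in zip(top_probs.tolist(), top_ids.tolist())
    ]


example = top_k_predictions(torch.tensor([1.0, 3.0, 2.0]), {"0": "Apple", "1": "Banana", "2": "Cherry"}, k=2)
print(example)
```

With the script above, `for label, prob in top_k_predictions(output[0], id2label, k=5): print(f"{label}: {prob:.4f}")` prints the five most likely classes.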
