PredireFruit360 / app.py
mahfouz28's picture
Update app.py
c6a8c61 verified
import streamlit as st
import torch
from torchvision import transforms, models
from PIL import Image
from datasets import load_dataset
import os
# Vérification du GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Charger le modèle MobileNetV2
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
model.classifier = torch.nn.Sequential(
torch.nn.Dropout(0.2),
torch.nn.Linear(1280, 113) # 113 classes pour Fruits-360
)
# Charger le modèle avec un chemin correct
model_path = "model_fruits360.pth"
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
model.to(device)
model.eval()
# Charger le dataset Fruits-360 pour les labels
dataset = load_dataset("PedroSampaio/fruits-360")
labels = dataset['train'].features['label'].names
# Transformations pour MobileNetV2
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Fonction pour prétraiter l'image
def preprocess_image(image):
image = transform(image).unsqueeze(0).to(device)
return image
# Interface Streamlit
st.title("🍎🍌🍊 Prédiction des Fruits avec IA")
# Étape 1 : Upload d'image
uploaded_file = st.file_uploader("📤 **Choisissez une image**", type=["jpg", "png", "jpeg"])
# Vérification si une image est bien importée
if uploaded_file is not None:
image = Image.open(uploaded_file).convert("RGB")
# Étape 2 : Affichage de l’image sélectionnée
st.image(image, caption="🖼️ **Image sélectionnée**", use_container_width=True)
# Étape 3 : Bouton de validation
if st.button("✅ Valider l'image"):
# Étape 4 : Traitement et Prédiction
image_tensor = preprocess_image(image)
with torch.no_grad():
outputs = model(image_tensor)
_, predicted_class = torch.max(outputs, 1)
fruit_name = labels[int(predicted_class.item())]
# Étape 5 : Affichage du résultat
st.success(f"🎯 **Classe prédite : {fruit_name}**")