import streamlit as st import torch from torchvision import transforms, models from PIL import Image from datasets import load_dataset import os # Vérification du GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Charger le modèle MobileNetV2 model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT) model.classifier = torch.nn.Sequential( torch.nn.Dropout(0.2), torch.nn.Linear(1280, 113) # 113 classes pour Fruits-360 ) # Charger le modèle avec un chemin correct model_path = "model_fruits360.pth" model.load_state_dict(torch.load(model_path, map_location=device), strict=False) model.to(device) model.eval() # Charger le dataset Fruits-360 pour les labels dataset = load_dataset("PedroSampaio/fruits-360") labels = dataset['train'].features['label'].names # Transformations pour MobileNetV2 transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Fonction pour prétraiter l'image def preprocess_image(image): image = transform(image).unsqueeze(0).to(device) return image # Interface Streamlit st.title("🍎🍌🍊 Prédiction des Fruits avec IA") # Étape 1 : Upload d'image uploaded_file = st.file_uploader("📤 **Choisissez une image**", type=["jpg", "png", "jpeg"]) # Vérification si une image est bien importée if uploaded_file is not None: image = Image.open(uploaded_file).convert("RGB") # Étape 2 : Affichage de l’image sélectionnée st.image(image, caption="🖼️ **Image sélectionnée**", use_container_width=True) # Étape 3 : Bouton de validation if st.button("✅ Valider l'image"): # Étape 4 : Traitement et Prédiction image_tensor = preprocess_image(image) with torch.no_grad(): outputs = model(image_tensor) _, predicted_class = torch.max(outputs, 1) fruit_name = labels[int(predicted_class.item())] # Étape 5 : Affichage du résultat st.success(f"🎯 **Classe prédite : {fruit_name}**")