Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from torchvision import transforms, models | |
from PIL import Image | |
from datasets import load_dataset | |
import os | |
# Vérification du GPU | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# Charger le modèle MobileNetV2 | |
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT) | |
model.classifier = torch.nn.Sequential( | |
torch.nn.Dropout(0.2), | |
torch.nn.Linear(1280, 113) # 113 classes pour Fruits-360 | |
) | |
# Charger le modèle avec un chemin correct | |
model_path = "model_fruits360.pth" | |
model.load_state_dict(torch.load(model_path, map_location=device), strict=False) | |
model.to(device) | |
model.eval() | |
# Charger le dataset Fruits-360 pour les labels | |
dataset = load_dataset("PedroSampaio/fruits-360") | |
labels = dataset['train'].features['label'].names | |
# Transformations pour MobileNetV2 | |
transform = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
]) | |
# Fonction pour prétraiter l'image | |
def preprocess_image(image): | |
image = transform(image).unsqueeze(0).to(device) | |
return image | |
# Interface Streamlit | |
st.title("🍎🍌🍊 Prédiction des Fruits avec IA") | |
# Étape 1 : Upload d'image | |
uploaded_file = st.file_uploader("📤 **Choisissez une image**", type=["jpg", "png", "jpeg"]) | |
# Vérification si une image est bien importée | |
if uploaded_file is not None: | |
image = Image.open(uploaded_file).convert("RGB") | |
# Étape 2 : Affichage de l’image sélectionnée | |
st.image(image, caption="🖼️ **Image sélectionnée**", use_container_width=True) | |
# Étape 3 : Bouton de validation | |
if st.button("✅ Valider l'image"): | |
# Étape 4 : Traitement et Prédiction | |
image_tensor = preprocess_image(image) | |
with torch.no_grad(): | |
outputs = model(image_tensor) | |
_, predicted_class = torch.max(outputs, 1) | |
fruit_name = labels[int(predicted_class.item())] | |
# Étape 5 : Affichage du résultat | |
st.success(f"🎯 **Classe prédite : {fruit_name}**") |