mahfouz28 commited on
Commit
9c13380
·
verified ·
1 Parent(s): 0a6aa2b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +60 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from torchvision import transforms, models
4
+ from PIL import Image
5
+ from datasets import load_dataset
6
+ import os
7
+
8
+ # Vérification du GPU
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # Charger le modèle MobileNetV2
12
+ model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
13
+ model.classifier = torch.nn.Sequential(
14
+ torch.nn.Dropout(0.2),
15
+ torch.nn.Linear(1280, 113) # 113 classes pour Fruits-360
16
+ )
17
+
18
+ # Charger le modèle avec un chemin correct
19
+ model_path = "model_fruits360.pth"
20
+ model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
21
+ model.to(device)
22
+ model.eval()
23
+
24
+ # Charger le dataset Fruits-360 pour les labels
25
+ dataset = load_dataset("PedroSampaio/fruits-360")
26
+ labels = dataset['train'].features['label'].names
27
+
28
+ # Transformations pour MobileNetV2
29
+ transform = transforms.Compose([
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33
+ ])
34
+
35
+ # Fonction pour prétraiter l'image
36
+ def preprocess_image(image):
37
+ image = transform(image).unsqueeze(0).to(device)
38
+ return image
39
+
40
+ # Interface Streamlit
41
+ st.title("Prédiction des Fruits 🍎🍌🍊")
42
+
43
+ # Upload d'image
44
+ uploaded_file = st.file_uploader("Choisir une image...", type=["jpg", "png", "jpeg"])
45
+
46
+ if uploaded_file is not None:
47
+ # Affichage de l'image
48
+ image = Image.open(uploaded_file).convert("RGB")
49
+ st.image(image, caption="Image sélectionnée", use_column_width=True)
50
+
51
+ # Prédiction
52
+ image_tensor = preprocess_image(image)
53
+ with torch.no_grad():
54
+ outputs = model(image_tensor)
55
+ _, predicted_class = torch.max(outputs, 1)
56
+
57
+ fruit_name = labels[int(predicted_class.item())]
58
+
59
+ # Affichage du résultat
60
+ st.success(f"✅ Classe prédite : **{fruit_name}**")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ streamlit
4
+ datasets
5
+ Pillow