Upload 6 files

- Figure_1.png +0 -0
- README.md +69 -0
- deploy.py +76 -0
- index.html +72 -0
- model_fruits360.pth +3 -0
- requierements.txt +182 -0
Figure_1.png
ADDED
README.md
ADDED
@@ -0,0 +1,69 @@
# Predict_Fruit_360

## Description
This project fine-tunes a pre-trained deep learning model (ResNet50, MobileNetV2, VGG16) on a custom dataset to classify fruit images. Once fine-tuned, the model is deployed as a Flask API that serves predictions on new images.

## Objectives
- Load and prepare the `PedroSampaio/fruits-360` dataset.
- Fine-tune a pre-trained model (ResNet50, MobileNetV2, VGG16).
- Train the model and evaluate its performance.
- Deploy the model through a Flask API.

---

## Project Structure
```
📂 fruits_classifier
│── 📂 templates
│   └── index.html
│── 📂 models
│   └── train.py       # Training code and model checkpointing
│── 📂 app
│   └── deploy.py      # Flask API for deployment
│── 📂 static          # (For CSS/JS files, if needed)
│── requirements.txt
│── README.md
```

---

## Installation and Usage

### 1. Installing dependencies
Before starting, make sure **Python 3.8+** is installed. Then install the required libraries:
```bash
pip install torch torchvision flask datasets pillow
```

### 2. Training the Model
The `models/train.py` script loads the dataset, applies the transforms, and trains a MobileNetV2 model.

Start training:
```bash
cd models
python train.py
```
Once training finishes, the model is saved to `models/model_fruits360.pth`.

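The transforms mentioned above mirror the ones applied at inference time in `deploy.py`: resize to 224×224, convert to a tensor in [0, 1], then normalize each channel with the ImageNet mean and std. The per-channel arithmetic can be sketched without torch (the pixel value below is a toy example, not real data):

```python
# ImageNet statistics used by the Normalize transform in deploy.py
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Toy RGB pixel already scaled to [0, 1] (what ToTensor produces)
pixel = [0.5, 0.5, 0.5]

# Normalize: (value - mean) / std, channel by channel
normalized = [(v - m) / s for v, m, s in zip(pixel, mean, std)]
print([round(x, 3) for x in normalized])
```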
### 3. Deploying the Flask API
The `app/deploy.py` script loads the model and exposes a Flask API that serves predictions.

Start the API:
```bash
cd app
python deploy.py
```
The API is then available at **http://127.0.0.1:5000/**.

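The `/predict` endpoint replies with JSON of the form `{"class": ...}` on success or `{"error": ...}` on failure. A minimal sketch of handling that response shape on the client side (the raw payload below is a hypothetical sample, not real API output):

```python
import json

# Hypothetical raw body returned by POST /predict (shape matches deploy.py)
raw = '{"class": "Apple Braeburn"}'

payload = json.loads(raw)
if "error" in payload:
    message = "Prediction failed: " + payload["error"]
else:
    message = "Predicted class: " + payload["class"]
print(message)  # prints "Predicted class: Apple Braeburn"
```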
### 4. Using the Web Interface
Open **http://127.0.0.1:5000/** and upload an image to get the predicted class.

---

## Features
- **Image preprocessing**: resizing and normalization.
- **Models used**: ResNet50, MobileNetV2, VGG16.
- **Performance visualization**: confusion matrix and classification report.
- **Easy deployment**: a simple, fast Flask API.

## Author
Project completed as part of the Licence 3 Computer Vision course at the **Dakar Institute of Technology**.
deploy.py
ADDED
@@ -0,0 +1,76 @@
from flask import Flask, request, jsonify, render_template
import torch
from torchvision import transforms, models
from PIL import Image
from datasets import load_dataset
import os

# Use the GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the MobileNetV2 architecture with a custom classification head
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(0.2),
    torch.nn.Linear(1280, 113)  # 113 classes in Fruits-360
)

# Load the fine-tuned weights from the models directory
model_path = os.path.join('..', 'models', 'model_fruits360.pth')
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
model.to(device)
model.eval()

# Load the Fruits-360 dataset to recover the class names
dataset = load_dataset("PedroSampaio/fruits-360")
labels = dataset['train'].features['label'].names

# Point Flask at the project's templates directory
app = Flask(__name__, template_folder=os.path.join('..', 'templates'))

# Input transforms expected by MobileNetV2
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Preprocess a PIL image into a batched tensor on the right device
def preprocess_image(image):
    image = transform(image).unsqueeze(0).to(device)
    return image

# Main page
@app.route('/')
def index():
    return render_template('index.html')

# Prediction endpoint
@app.route('/predict', methods=['POST'])
def predict():
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file sent'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'Empty file'}), 400

        # Load the image
        image = Image.open(file.stream).convert("RGB")
        image = preprocess_image(image)

        # Prediction
        with torch.no_grad():
            outputs = model(image)
            _, predicted_class = torch.max(outputs, 1)

        fruit_name = labels[int(predicted_class.item())]

        return jsonify({'class': fruit_name})

    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == "__main__":
    app.run(debug=True)
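In `deploy.py`, `torch.max(outputs, 1)` picks the index of the largest logit, which is then used to look up the class name in `labels`. The same step can be sketched in plain Python (the 5-class label set and logit values below are toy stand-ins, not Fruits-360 data):

```python
# Toy stand-ins for deploy.py's `labels` (113 entries there) and model logits
labels = ["Apple", "Banana", "Cherry", "Lemon", "Orange"]
logits = [0.1, 2.7, -0.3, 0.9, 1.4]

# Equivalent of `_, predicted_class = torch.max(outputs, 1)` for one sample
predicted_index = max(range(len(logits)), key=logits.__getitem__)
fruit_name = labels[predicted_index]
print(fruit_name)  # prints "Banana", the class with the highest logit
```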
index.html
ADDED
@@ -0,0 +1,72 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Fruits 360 Prediction</title>
    <script src="https://cdn.tailwindcss.com"></script>
</head>
<body class="bg-gray-100 p-10">
    <div class="max-w-md mx-auto bg-white p-6 rounded-lg shadow-md">
        <h1 class="text-2xl font-bold mb-4 text-center">Fruit Prediction 🍎🍌🍊</h1>

        <form id="upload-form" action="/predict" method="post" enctype="multipart/form-data" class="space-y-4">
            <input type="file" name="file" accept="image/*" required class="block w-full text-sm text-gray-700 file:mr-4 file:py-2 file:px-4 file:rounded-lg file:border-0 file:text-sm file:font-semibold file:bg-blue-50 file:text-blue-700 hover:file:bg-blue-100">
            <button type="submit" class="w-full bg-blue-500 text-white py-2 rounded-md hover:bg-blue-600">Submit</button>
        </form>

        <div class="mt-6 text-center">
            <h2 class="text-lg font-semibold mb-2">Selected image:</h2>
            <img id="preview" class="mx-auto w-48 h-48 object-cover rounded-lg hidden">
        </div>

        <h2 class="text-lg font-semibold mt-6">Prediction result:</h2>
        <p id="result" class="text-gray-700 text-center mt-2"></p>

        <div id="loading" class="text-center mt-4 hidden">⏳ Predicting...</div>
    </div>

    <script>
        const form = document.getElementById('upload-form');
        const preview = document.getElementById('preview');
        const result = document.getElementById('result');
        const loading = document.getElementById('loading');

        form.addEventListener('submit', async (event) => {
            event.preventDefault();
            const formData = new FormData(form);

            // Show the loading indicator
            loading.classList.remove('hidden');
            result.textContent = '';

            const response = await fetch('/predict', {
                method: 'POST',
                body: formData
            });

            const data = await response.json();
            loading.classList.add('hidden');

            if (response.ok) {
                result.textContent = `✅ Predicted class: ${data.class}`;
            } else {
                result.textContent = `❌ Error: ${data.error || 'Unable to predict'}`;
            }
        });

        // Preview the selected image
        form.file.addEventListener('change', (event) => {
            const file = event.target.files[0];
            if (file) {
                const reader = new FileReader();
                reader.onload = (e) => {
                    preview.src = e.target.result;
                    preview.classList.remove('hidden');
                };
                reader.readAsDataURL(file);
            }
        });
    </script>
</body>
</html>
model_fruits360.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db6b74b473e87a9f0deda45c26646f22ca114d434db1a3d9cdcfc9d379a0a4ad
size 9721326
requierements.txt
ADDED
@@ -0,0 +1,182 @@
aiofiles==23.2.1
aiohappyeyeballs==2.4.8
aiohttp==3.11.13
aiosignal==1.3.2
altair==5.5.0
annotated-types==0.7.0
anyio==4.8.0
attrs==25.1.0
audioop-lts==0.2.1
beautifulsoup4==4.13.3
blinker==1.9.0
cachetools==5.5.2
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
colorama==0.4.6
contourpy==1.3.1
cycler==0.12.1
dataclasses-json==0.6.7
datasets==3.4.0
defusedxml==0.7.1
dill==0.3.8
distro==1.9.0
facenet-pytorch==2.5.2
faiss-cpu==1.10.0
fastapi==0.115.8
ffmpy==0.5.0
filelock==3.17.0
filetype==1.2.0
Flask==3.1.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2024.12.0
gitdb==4.0.12
GitPython==3.1.44
google-ai-generativelanguage==0.6.15
google-api-core==2.24.1
google-api-python-client==2.163.0
google-auth==2.38.0
google-auth-httplib2==0.2.0
google-generativeai==0.8.4
googleapis-common-protos==1.69.1
gradio==5.21.0
gradio_client==1.7.2
greenlet==3.1.1
groovy==0.1.2
grpcio==1.71.0rc2
grpcio-status==1.71.0rc2
h11==0.14.0
httpcore==1.0.7
httplib2==0.22.0
httpx==0.28.1
httpx-sse==0.4.0
huggingface-hub==0.29.2
idna==3.10
itsdangerous==2.2.0
Jinja2==3.1.5
jiter==0.8.2
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
langchain==0.3.20
langchain-community==0.3.19
langchain-core==0.3.41
langchain-google-genai==2.0.10
langchain-openai==0.3.7
langchain-text-splitters==0.3.6
langsmith==0.3.11
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.26.1
matplotlib==3.10.0
mdurl==0.1.2
mock==5.2.0
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
mypy-extensions==1.0.0
narwhals==1.28.0
networkx==3.4.2
numpy==2.1.1
openai==1.65.3
opencv-python==4.11.0.86
orjson==3.10.15
outcome==1.3.0.post0
packaging==24.2
pandas==2.2.3
pillow==11.1.0
plotly==6.0.0
propcache==0.3.0
proto-plus==1.26.0
protobuf==5.29.3
psutil==7.0.0
py-cpuinfo==9.0.0
pyarrow==19.0.1
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycparser==2.22
pydantic==2.10.6
pydantic-settings==2.8.1
pydantic_core==2.27.2
pydeck==0.9.1
pydub==0.25.1
Pygments==2.19.1
PyMuPDF==1.25.3
pyparsing==3.2.1
pypdf==5.3.1
PyPDF2==3.0.1
PySocks==1.7.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.20
pytz==2025.1
PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6
relaxml==0.1.3
requests==2.32.3
requests-toolbelt==1.0.0
rich==13.9.4
rpds-py==0.23.1
rsa==4.9
ruff==0.11.0
safehttpx==0.1.6
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.2
seaborn==0.13.2
selenium==4.29.0
semantic-version==2.10.0
sentence-transformers==3.4.1
setuptools==75.8.1
shellingham==1.5.4
simplejson==3.20.1
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
sortedcontainers==2.4.0
soupsieve==2.6
SQLAlchemy==2.0.38
starlette==0.45.3
streamlit==1.42.2
supervision==0.25.1
sympy==1.13.1
tenacity==9.0.0
threadpoolctl==3.5.0
three==0.8.0
tiktoken==0.9.0
tokenizers==0.21.0
toml==0.10.2
tomlkit==0.13.2
torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
tornado==6.4.2
tqdm==4.67.1
transformers==4.49.0
trio==0.29.0
trio-websocket==0.12.2
typer==0.15.2
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2025.1
ultralytics==8.3.80
ultralytics-thop==2.0.14
uritemplate==4.1.1
urllib3==2.3.0
uvicorn==0.34.0
watchdog==6.0.0
webdriver-manager==4.0.2
websocket-client==1.8.0
websockets==15.0.1
Werkzeug==3.1.3
wheel==0.45.1
wsproto==1.2.0
xxhash==3.5.0
yarl==1.18.3
zstandard==0.23.0