mahfouz28 commited on
Commit
cfc4a6e
·
verified ·
1 Parent(s): 0af321a

Upload 6 files

Browse files
Files changed (6) hide show
  1. Figure_1.png +0 -0
  2. README.md +69 -0
  3. deploy.py +76 -0
  4. index.html +72 -0
  5. model_fruits360.pth +3 -0
  6. requierements.txt +182 -0
Figure_1.png ADDED
README.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Predict_Fruit_360
2
+ ## Description
3
+ Ce projet consiste à entraîner un modèle de deep learning pré-entraîné (ResNet50, MobileNetV2, VGG16) sur un dataset personnalisé pour classer des images de fruits. Une fois le modèle fine-tuné, il est déployé sous forme d'API Flask permettant de faire des prédictions sur de nouvelles images.
4
+
5
+ ## Objectifs
6
+ - Charger et préparer le dataset `PedroSampaio/fruits-360`.
7
+ - Utiliser un modèle pré-entraîné (ResNet50, MobileNetV2, VGG16) pour le fine-tuning.
8
+ - Entraîner le modèle et évaluer sa performance.
9
+ - Déployer le modèle via une API Flask.
10
+
11
+ ---
12
+
13
+ ## Structure du Projet
14
+ ```
15
+ 📂 fruits_classifier
16
+ │── 📂 templates
17
+ │ └── index.html
18
+ │── 📂 models
19
+ │ └── train.py # Contient le code d'entraînement et de sauvegarde du modèle
20
+ │── 📂 app
21
+ │ └── deploy.py # Contient l'API Flask pour le déploiement
22
+ │── 📂 static # (Si besoin d'ajouter des fichiers CSS/JS)
23
+ │── requirements.txt
24
+ │── README.md
25
+ ```
26
+
27
+ ---
28
+
29
+ ## Installation et Utilisation
30
+
31
+ ### 1. Installation des dépendances
32
+ Avant de commencer, assurez-vous d'avoir **Python 3.8+** installé. Ensuite, installez les bibliothèques nécessaires :
33
+ ```bash
34
+ pip install torch torchvision flask datasets pillow
35
+ ```
36
+
37
+ ### 2. Entraînement du Modèle
38
+ Le script `models/train.py` charge le dataset, applique des transformations et entraîne un modèle MobileNetV2.
39
+
40
+ Lancer l'entraînement :
41
+ ```bash
42
+ cd models
43
+ python train.py
44
+ ```
45
+ Une fois l'entraînement terminé, le modèle est sauvegardé dans `models/model_fruits360.pth`.
46
+
47
+ ### 3. Déploiement de l'API Flask
48
+ Le script `app/deploy.py` charge le modèle et crée une API Flask permettant de faire des prédictions.
49
+
50
+ Lancer l'API :
51
+ ```bash
52
+ cd app
53
+ python deploy.py
54
+ ```
55
+ L'API sera accessible sur **http://127.0.0.1:5000/**.
56
+
57
+ ### 4. Utilisation de l'Interface Web
58
+ Accéder à l'URL **http://127.0.0.1:5000/** et uploader une image pour obtenir la classe prédite.
59
+
60
+ ---
61
+
62
+ ## Fonctionnalités
63
+ - **Prétraitement des images** : Redimensionnement, normalisation.
64
+ - **Modèles utilisé** : ResNet50, MobileNetV2, VGG16.
65
+ - **Visualisation des performances** : Matrice de confusion et rapport de classification.
66
+ - **Déploiement facile** : API Flask simple et rapide.
67
+
68
+ ## Auteur
69
+ Projet réalisé dans le cadre de la Licence 3 - Computer Vision au **Dakar Institut of Technology**.
deploy.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ import torch
3
+ from torchvision import transforms, models
4
+ from PIL import Image
5
+ from datasets import load_dataset
6
+ import os
7
+
8
+ # Vérification du GPU
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # Charger le modèle MobileNetV2
12
+ model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
13
+ model.classifier = torch.nn.Sequential(
14
+ torch.nn.Dropout(0.2),
15
+ torch.nn.Linear(1280, 113) # 113 classes pour Fruits-360
16
+ )
17
+
18
+ # Charger le modèle avec un chemin correct
19
+ model_path = os.path.join('..', 'models', 'model_fruits360.pth')
20
+ model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
21
+ model.to(device)
22
+ model.eval()
23
+
24
+ # Charger le dataset Fruits-360
25
+ dataset = load_dataset("PedroSampaio/fruits-360")
26
+ labels = dataset['train'].features['label'].names
27
+
28
+ # Chemin pour le dossier templates
29
+ app = Flask(__name__, template_folder=os.path.join('..', ''))
30
+
31
+ # Transformations pour MobileNetV2
32
+ transform = transforms.Compose([
33
+ transforms.Resize((224, 224)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
36
+ ])
37
+
38
+ # Fonction pour prétraiter l'image
39
+ def preprocess_image(image):
40
+ image = transform(image).unsqueeze(0).to(device)
41
+ return image
42
+
43
+ # Route principale
44
+ @app.route('/')
45
+ def index():
46
+ return render_template('index.html')
47
+
48
+ # Route pour la prédiction
49
+ @app.route('/predict', methods=['POST'])
50
+ def predict():
51
+ try:
52
+ if 'file' not in request.files:
53
+ return jsonify({'error': 'Aucun fichier envoyé'})
54
+
55
+ file = request.files['file']
56
+ if file.filename == '':
57
+ return jsonify({'error': 'Fichier vide'})
58
+
59
+ # Charger l'image
60
+ image = Image.open(file.stream).convert("RGB")
61
+ image = preprocess_image(image)
62
+
63
+ # Prédiction
64
+ with torch.no_grad():
65
+ outputs = model(image)
66
+ _, predicted_class = torch.max(outputs, 1)
67
+
68
+ fruit_name = labels[int(predicted_class.item())]
69
+
70
+ return jsonify({'class': fruit_name})
71
+
72
+ except Exception as e:
73
+ return jsonify({'error': str(e)})
74
+
75
+ if __name__ == "__main__":
76
+ app.run(debug=True)
index.html ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Prédiction Fruits 360</title>
7
+ <script src="https://cdn.tailwindcss.com"></script>
8
+ </head>
9
+ <body class="bg-gray-100 p-10">
10
+ <div class="max-w-md mx-auto bg-white p-6 rounded-lg shadow-md">
11
+ <h1 class="text-2xl font-bold mb-4 text-center">Prédiction des Fruits 🍎🍌🍊</h1>
12
+
13
+ <form id="upload-form" action="/predict" method="post" enctype="multipart/form-data" class="space-y-4">
14
+ <input type="file" name="file" accept="image/*" required class="block w-full text-sm text-gray-700 file:mr-4 file:py-2 file:px-4 file:rounded-lg file:border-0 file:text-sm file:font-semibold file:bg-blue-50 file:text-blue-700 hover:file:bg-blue-100">
15
+ <button type="submit" class="w-full bg-blue-500 text-white py-2 rounded-md hover:bg-blue-600">Envoyer</button>
16
+ </form>
17
+
18
+ <div class="mt-6 text-center">
19
+ <h2 class="text-lg font-semibold mb-2">Image sélectionnée :</h2>
20
+ <img id="preview" class="mx-auto w-48 h-48 object-cover rounded-lg hidden">
21
+ </div>
22
+
23
+ <h2 class="text-lg font-semibold mt-6">Résultat de la prédiction :</h2>
24
+ <p id="result" class="text-gray-700 text-center mt-2"></p>
25
+
26
+ <div id="loading" class="text-center mt-4 hidden">⏳ Prédiction en cours...</div>
27
+ </div>
28
+
29
+ <script>
30
+ const form = document.getElementById('upload-form');
31
+ const preview = document.getElementById('preview');
32
+ const result = document.getElementById('result');
33
+ const loading = document.getElementById('loading');
34
+
35
+ form.addEventListener('submit', async (event) => {
36
+ event.preventDefault();
37
+ const formData = new FormData(form);
38
+
39
+ // Afficher l'animation de chargement
40
+ loading.classList.remove('hidden');
41
+ result.textContent = '';
42
+
43
+ const response = await fetch('/predict', {
44
+ method: 'POST',
45
+ body: formData
46
+ });
47
+
48
+ const data = await response.json();
49
+ loading.classList.add('hidden');
50
+
51
+ if (response.ok) {
52
+ result.textContent = `✅ Classe prédite : ${data.class}`;
53
+ } else {
54
+ result.textContent = `❌ Erreur : ${data.error || 'Impossible de prédire'}`;
55
+ }
56
+ });
57
+
58
+ // Afficher l'aperçu de l'image sélectionnée
59
+ form.file.addEventListener('change', (event) => {
60
+ const file = event.target.files[0];
61
+ if (file) {
62
+ const reader = new FileReader();
63
+ reader.onload = (e) => {
64
+ preview.src = e.target.result;
65
+ preview.classList.remove('hidden');
66
+ };
67
+ reader.readAsDataURL(file);
68
+ }
69
+ });
70
+ </script>
71
+ </body>
72
+ </html>
model_fruits360.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db6b74b473e87a9f0deda45c26646f22ca114d434db1a3d9cdcfc9d379a0a4ad
3
+ size 9721326
requierements.txt ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.8
3
+ aiohttp==3.11.13
4
+ aiosignal==1.3.2
5
+ altair==5.5.0
6
+ annotated-types==0.7.0
7
+ anyio==4.8.0
8
+ attrs==25.1.0
9
+ audioop-lts==0.2.1
10
+ beautifulsoup4==4.13.3
11
+ blinker==1.9.0
12
+ cachetools==5.5.2
13
+ certifi==2025.1.31
14
+ cffi==1.17.1
15
+ charset-normalizer==3.4.1
16
+ click==8.1.8
17
+ colorama==0.4.6
18
+ contourpy==1.3.1
19
+ cycler==0.12.1
20
+ dataclasses-json==0.6.7
21
+ datasets==3.4.0
22
+ defusedxml==0.7.1
23
+ dill==0.3.8
24
+ distro==1.9.0
25
+ facenet-pytorch==2.5.2
26
+ faiss-cpu==1.10.0
27
+ fastapi==0.115.8
28
+ ffmpy==0.5.0
29
+ filelock==3.17.0
30
+ filetype==1.2.0
31
+ Flask==3.1.0
32
+ fonttools==4.56.0
33
+ frozenlist==1.5.0
34
+ fsspec==2024.12.0
35
+ gitdb==4.0.12
36
+ GitPython==3.1.44
37
+ google-ai-generativelanguage==0.6.15
38
+ google-api-core==2.24.1
39
+ google-api-python-client==2.163.0
40
+ google-auth==2.38.0
41
+ google-auth-httplib2==0.2.0
42
+ google-generativeai==0.8.4
43
+ googleapis-common-protos==1.69.1
44
+ gradio==5.21.0
45
+ gradio_client==1.7.2
46
+ greenlet==3.1.1
47
+ groovy==0.1.2
48
+ grpcio==1.71.0rc2
49
+ grpcio-status==1.71.0rc2
50
+ h11==0.14.0
51
+ httpcore==1.0.7
52
+ httplib2==0.22.0
53
+ httpx==0.28.1
54
+ httpx-sse==0.4.0
55
+ huggingface-hub==0.29.2
56
+ idna==3.10
57
+ itsdangerous==2.2.0
58
+ Jinja2==3.1.5
59
+ jiter==0.8.2
60
+ joblib==1.4.2
61
+ jsonpatch==1.33
62
+ jsonpointer==3.0.0
63
+ jsonschema==4.23.0
64
+ jsonschema-specifications==2024.10.1
65
+ kiwisolver==1.4.8
66
+ langchain==0.3.20
67
+ langchain-community==0.3.19
68
+ langchain-core==0.3.41
69
+ langchain-google-genai==2.0.10
70
+ langchain-openai==0.3.7
71
+ langchain-text-splitters==0.3.6
72
+ langsmith==0.3.11
73
+ markdown-it-py==3.0.0
74
+ MarkupSafe==2.1.5
75
+ marshmallow==3.26.1
76
+ matplotlib==3.10.0
77
+ mdurl==0.1.2
78
+ mock==5.2.0
79
+ mpmath==1.3.0
80
+ multidict==6.1.0
81
+ multiprocess==0.70.16
82
+ mypy-extensions==1.0.0
83
+ narwhals==1.28.0
84
+ networkx==3.4.2
85
+ numpy==2.1.1
86
+ openai==1.65.3
87
+ opencv-python==4.11.0.86
88
+ orjson==3.10.15
89
+ outcome==1.3.0.post0
90
+ packaging==24.2
91
+ pandas==2.2.3
92
+ pillow==11.1.0
93
+ plotly==6.0.0
94
+ propcache==0.3.0
95
+ proto-plus==1.26.0
96
+ protobuf==5.29.3
97
+ psutil==7.0.0
98
+ py-cpuinfo==9.0.0
99
+ pyarrow==19.0.1
100
+ pyasn1==0.6.1
101
+ pyasn1_modules==0.4.1
102
+ pycparser==2.22
103
+ pydantic==2.10.6
104
+ pydantic-settings==2.8.1
105
+ pydantic_core==2.27.2
106
+ pydeck==0.9.1
107
+ pydub==0.25.1
108
+ Pygments==2.19.1
109
+ PyMuPDF==1.25.3
110
+ pyparsing==3.2.1
111
+ pypdf==5.3.1
112
+ PyPDF2==3.0.1
113
+ PySocks==1.7.1
114
+ python-dateutil==2.9.0.post0
115
+ python-dotenv==1.0.1
116
+ python-multipart==0.0.20
117
+ pytz==2025.1
118
+ PyYAML==6.0.2
119
+ referencing==0.36.2
120
+ regex==2024.11.6
121
+ relaxml==0.1.3
122
+ requests==2.32.3
123
+ requests-toolbelt==1.0.0
124
+ rich==13.9.4
125
+ rpds-py==0.23.1
126
+ rsa==4.9
127
+ ruff==0.11.0
128
+ safehttpx==0.1.6
129
+ safetensors==0.5.3
130
+ scikit-learn==1.6.1
131
+ scipy==1.15.2
132
+ seaborn==0.13.2
133
+ selenium==4.29.0
134
+ semantic-version==2.10.0
135
+ sentence-transformers==3.4.1
136
+ setuptools==75.8.1
137
+ shellingham==1.5.4
138
+ simplejson==3.20.1
139
+ six==1.17.0
140
+ smmap==5.0.2
141
+ sniffio==1.3.1
142
+ sortedcontainers==2.4.0
143
+ soupsieve==2.6
144
+ SQLAlchemy==2.0.38
145
+ starlette==0.45.3
146
+ streamlit==1.42.2
147
+ supervision==0.25.1
148
+ sympy==1.13.1
149
+ tenacity==9.0.0
150
+ threadpoolctl==3.5.0
151
+ three==0.8.0
152
+ tiktoken==0.9.0
153
+ tokenizers==0.21.0
154
+ toml==0.10.2
155
+ tomlkit==0.13.2
156
+ torch==2.6.0
157
+ torchaudio==2.6.0
158
+ torchvision==0.21.0
159
+ tornado==6.4.2
160
+ tqdm==4.67.1
161
+ transformers==4.49.0
162
+ trio==0.29.0
163
+ trio-websocket==0.12.2
164
+ typer==0.15.2
165
+ typing-inspect==0.9.0
166
+ typing_extensions==4.12.2
167
+ tzdata==2025.1
168
+ ultralytics==8.3.80
169
+ ultralytics-thop==2.0.14
170
+ uritemplate==4.1.1
171
+ urllib3==2.3.0
172
+ uvicorn==0.34.0
173
+ watchdog==6.0.0
174
+ webdriver-manager==4.0.2
175
+ websocket-client==1.8.0
176
+ websockets==15.0.1
177
+ Werkzeug==3.1.3
178
+ wheel==0.45.1
179
+ wsproto==1.2.0
180
+ xxhash==3.5.0
181
+ yarl==1.18.3
182
+ zstandard==0.23.0