Clemylia commited on
Commit
a58d701
·
verified ·
1 Parent(s): 56c90d6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +142 -1
README.md CHANGED
@@ -86,4 +86,145 @@ Le modèle a été entraîné sur un jeu de données **synthétique** créé man
86
 
87
  -----
88
 
89
- *Ce projet a été développé avec passion par Clemylia pour l'apprentissage du Machine Learning **from scratch** en PyTorch. Contribuez à la ruche \! 💛*
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  -----
88
 
89
+ *Ce projet a été développé avec passion par Clemylia pour l'apprentissage du Machine Learning **from scratch** en PyTorch. Contribuez à la ruche \! 💛*
90
+
91
+ **exemples de code d'utilisation**:
92
+
93
+ ```
94
+ import torch
95
+ import torch.nn as nn
96
+ import numpy as np
97
+ from huggingface_hub import hf_hub_download
98
+
99
+ # ==============================================================================
100
+ # 1. Configuration et Architecture (DOIT correspondre au modèle entraîné)
101
+ # ==============================================================================
102
+
103
+ # --- Constantes du Modèle ---
104
+ REPO_ID = "Clemylia/BeeAI-Bena"
105
+ FILENAME = "pytorch_model.bin"
106
+
107
+ FLOWER_NAMES = [
108
+ 'lavande', 'coquelicot', 'muguet', 'lilas', 'jasmin',
109
+ 'marguerite', 'rose', 'tournesol', 'acacia', 'tulipe',
110
+ 'pissenlit', 'trèfle', 'bruyère', 'romarin', 'thym',
111
+ 'sauge', 'bourrache', 'bleuet', 'primevère', 'camomille'
112
+ ]
113
+ NUM_FLOWERS = len(FLOWER_NAMES) # 20
114
+ NUM_DANCE_REPORTS = 10 # 0 à 9
115
+ INPUT_SIZE = NUM_FLOWERS + 2 # 22 (20 fleurs + 2 positions)
116
+
117
+ # --- Rapports de Danse (Pour décoder la sortie 0-9) ---
118
+ RAPPORT_DESCRIPTIONS = [
119
+ "0: La fleur n'est pas butinable, c'est hautement probable qu'elle soit vide de nectar.",
120
+ "1: La fleur n'est pas butinable, c'est très probable qu'elle soit vide de nectar, ou qu'elle en est pas assez.",
121
+ "2: La fleur n'est pas trop butinable, emplacement peu favorable (risque de toiles d'araignée).",
122
+ "3: La fleur est butinable, mais dans la moindre mesure (pour une petite ruche).",
123
+ "4: La fleur est butinable, mais ses ressources sont limitées.",
124
+ "5: La fleur se situe trop loin de la ruche.",
125
+ "6: La fleur se situe à un endroit susceptible d'attirer des frelons, c'est trop dangereux.",
126
+ "7: La fleur est butinable, et possède beaucoup de nectar, mais pas assez pour toute la ruche et les larves.",
127
+ "8: La fleur est butinable, et se situe dans une prairie remplis de fleurs hautement butinable, mais il y a un danger (grenouilles).",
128
+ "9: La fleur est parfaitement butinable."
129
+ ]
130
+
131
+ # --- Définition de la classe du modèle ---
132
+ class BeeAI(nn.Module):
133
+ def __init__(self, input_size, num_classes):
134
+ super(BeeAI, self).__init__()
135
+ # L'architecture doit être identique à celle utilisée pour la sauvegarde
136
+ self.fc1 = nn.Linear(input_size, 128)
137
+ self.relu = nn.ReLU()
138
+ self.fc2 = nn.Linear(128, 64)
139
+ self.fc_out = nn.Linear(64, num_classes)
140
+
141
+ def forward(self, x):
142
+ out = self.fc1(x)
143
+ out = self.relu(out)
144
+ out = self.fc2(out)
145
+ out = self.relu(out)
146
+ out = self.fc_out(out)
147
+ return out
148
+
149
+ # ==============================================================================
150
+ # 2. FONCTIONS UTILITAIRES POUR LE CHARGEMENT ET L'INFÉRENCE
151
+ # ==============================================================================
152
+
153
+ def load_bee_ai_model(repo_id, filename, input_size, num_classes):
154
+ """Télécharge les poids et charge le modèle PyTorch."""
155
+ print(f"🔄 Tentative de téléchargement des poids depuis {repo_id}...")
156
+
157
+ # 1. Télécharge le fichier de poids
158
+ try:
159
+ weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
160
+ except Exception as e:
161
+ print(f"❌ Erreur de téléchargement : {e}")
162
+ return None
163
+
164
+ # 2. Crée et charge le modèle
165
+ model = BeeAI(input_size, num_classes)
166
+ model.load_state_dict(torch.load(weights_path))
167
+ model.eval() # Passe en mode évaluation pour l'inférence
168
+
169
+ print("✅ Modèle BeeAI chargé avec succès.")
170
+ return model
171
+
172
+ def flower_to_input(flower_name, x_pos, y_pos):
173
+ """Convertit les entrées utilisateur en tenseur d'entrée (22 dimensions)."""
174
+ try:
175
+ idx = FLOWER_NAMES.index(flower_name.lower())
176
+ except ValueError:
177
+ print(f"⚠️ Avertissement: Fleur '{flower_name}' inconnue. Utilisation d'un vecteur nul.")
178
+ idx = -1
179
+
180
+ one_hot = np.zeros(NUM_FLOWERS, dtype=np.float32)
181
+ if idx != -1:
182
+ one_hot[idx] = 1.0
183
+
184
+ position = np.array([x_pos, y_pos], dtype=np.float32)
185
+ input_vector = np.concatenate((one_hot, position))
186
+ # Ajout d'une dimension pour le batch
187
+ return torch.tensor(input_vector, dtype=torch.float32).unsqueeze(0)
188
+
189
+ def generate_dance_report(model, flower_name, x_pos, y_pos):
190
+ """Effectue l'inférence et retourne le code de danse décodé."""
191
+
192
+ # Prépare l'entrée
193
+ X_test = flower_to_input(flower_name, x_pos, y_pos)
194
+
195
+ # Inférence
196
+ with torch.no_grad():
197
+ output = model(X_test)
198
+
199
+ # Décodage (obtient l'indice de la probabilité maximale)
200
+ _, predicted_index = torch.max(output.data, 1)
201
+ report_code = predicted_index.item()
202
+
203
+ return report_code, RAPPORT_DESCRIPTIONS[report_code]
204
+
205
+ # ==============================================================================
206
+ # 3. EXÉCUTION (L'expérience utilisateur)
207
+ # ==============================================================================
208
+
209
+ # 1. Charger le modèle
210
+ bee_ai_model = load_bee_ai_model(REPO_ID, FILENAME, INPUT_SIZE, NUM_DANCE_REPORTS)
211
+
212
+ if bee_ai_model:
213
+
214
+ # 2. Définir la situation à tester
215
+ test_flower = 'pissenlit' # Une fleur connue et généralement bonne
216
+ test_x = 4.5
217
+ test_y = 5.5
218
+
219
+ # 3. Générer le rapport !
220
+ code, description = generate_dance_report(bee_ai_model, test_flower, test_x, test_y)
221
+
222
+ # 4. Afficher le résultat
223
+ print(f"\n==============================================")
224
+ print(f" Rapport de Danse de Bena (Bee AI) ")
225
+ print(f"==============================================")
226
+ print(f"Fleur évaluée : {test_flower.upper()} à la position ({test_x}, {test_y})")
227
+ print(f"CODE GÉNÉRÉ : {code}")
228
+ print(f"DESCRIPTION : {description.split(':')[1].strip()}")
229
+ print(f"==============================================")
230
+ ```