Update README.md
Browse files
README.md
CHANGED
|
@@ -86,4 +86,145 @@ Le modèle a été entraîné sur un jeu de données **synthétique** créé man
|
|
| 86 |
|
| 87 |
-----
|
| 88 |
|
| 89 |
-
*Ce projet a été développé avec passion par Clemylia pour l'apprentissage du Machine Learning **from scratch** en PyTorch. Contribuez à la ruche \! 💛*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-----
|
| 88 |
|
| 89 |
+
*Ce projet a été développé avec passion par Clemylia pour l'apprentissage du Machine Learning **from scratch** en PyTorch. Contribuez à la ruche \! 💛*
|
| 90 |
+
|
| 91 |
+
**exemples de code d'utilisation**:
|
| 92 |
+
|
| 93 |
+
```
|
| 94 |
+
import torch
|
| 95 |
+
import torch.nn as nn
|
| 96 |
+
import numpy as np
|
| 97 |
+
from huggingface_hub import hf_hub_download
|
| 98 |
+
|
| 99 |
+
# ==============================================================================
|
| 100 |
+
# 1. Configuration et Architecture (DOIT correspondre au modèle entraîné)
|
| 101 |
+
# ==============================================================================
|
| 102 |
+
|
| 103 |
+
# --- Constantes du Modèle ---
|
| 104 |
+
REPO_ID = "Clemylia/BeeAI-Bena"
|
| 105 |
+
FILENAME = "pytorch_model.bin"
|
| 106 |
+
|
| 107 |
+
FLOWER_NAMES = [
|
| 108 |
+
'lavande', 'coquelicot', 'muguet', 'lilas', 'jasmin',
|
| 109 |
+
'marguerite', 'rose', 'tournesol', 'acacia', 'tulipe',
|
| 110 |
+
'pissenlit', 'trèfle', 'bruyère', 'romarin', 'thym',
|
| 111 |
+
'sauge', 'bourrache', 'bleuet', 'primevère', 'camomille'
|
| 112 |
+
]
|
| 113 |
+
NUM_FLOWERS = len(FLOWER_NAMES) # 20
|
| 114 |
+
NUM_DANCE_REPORTS = 10 # 0 à 9
|
| 115 |
+
INPUT_SIZE = NUM_FLOWERS + 2 # 22 (20 fleurs + 2 positions)
|
| 116 |
+
|
| 117 |
+
# --- Rapports de Danse (Pour décoder la sortie 0-9) ---
|
| 118 |
+
RAPPORT_DESCRIPTIONS = [
|
| 119 |
+
"0: La fleur n'est pas butinable, c'est hautement probable qu'elle soit vide de nectar.",
|
| 120 |
+
"1: La fleur n'est pas butinable, c'est très probable qu'elle soit vide de nectar, ou qu'elle en est pas assez.",
|
| 121 |
+
"2: La fleur n'est pas trop butinable, emplacement peu favorable (risque de toiles d'araignée).",
|
| 122 |
+
"3: La fleur est butinable, mais dans la moindre mesure (pour une petite ruche).",
|
| 123 |
+
"4: La fleur est butinable, mais ses ressources sont limitées.",
|
| 124 |
+
"5: La fleur se situe trop loin de la ruche.",
|
| 125 |
+
"6: La fleur se situe à un endroit susceptible d'attirer des frelons, c'est trop dangereux.",
|
| 126 |
+
"7: La fleur est butinable, et possède beaucoup de nectar, mais pas assez pour toute la ruche et les larves.",
|
| 127 |
+
"8: La fleur est butinable, et se situe dans une prairie remplis de fleurs hautement butinable, mais il y a un danger (grenouilles).",
|
| 128 |
+
"9: La fleur est parfaitement butinable."
|
| 129 |
+
]
|
| 130 |
+
|
| 131 |
+
# --- Définition de la classe du modèle ---
|
| 132 |
+
class BeeAI(nn.Module):
|
| 133 |
+
def __init__(self, input_size, num_classes):
|
| 134 |
+
super(BeeAI, self).__init__()
|
| 135 |
+
# L'architecture doit être identique à celle utilisée pour la sauvegarde
|
| 136 |
+
self.fc1 = nn.Linear(input_size, 128)
|
| 137 |
+
self.relu = nn.ReLU()
|
| 138 |
+
self.fc2 = nn.Linear(128, 64)
|
| 139 |
+
self.fc_out = nn.Linear(64, num_classes)
|
| 140 |
+
|
| 141 |
+
def forward(self, x):
|
| 142 |
+
out = self.fc1(x)
|
| 143 |
+
out = self.relu(out)
|
| 144 |
+
out = self.fc2(out)
|
| 145 |
+
out = self.relu(out)
|
| 146 |
+
out = self.fc_out(out)
|
| 147 |
+
return out
|
| 148 |
+
|
| 149 |
+
# ==============================================================================
|
| 150 |
+
# 2. FONCTIONS UTILITAIRES POUR LE CHARGEMENT ET L'INFÉRENCE
|
| 151 |
+
# ==============================================================================
|
| 152 |
+
|
| 153 |
+
def load_bee_ai_model(repo_id, filename, input_size, num_classes):
|
| 154 |
+
"""Télécharge les poids et charge le modèle PyTorch."""
|
| 155 |
+
print(f"🔄 Tentative de téléchargement des poids depuis {repo_id}...")
|
| 156 |
+
|
| 157 |
+
# 1. Télécharge le fichier de poids
|
| 158 |
+
try:
|
| 159 |
+
weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 160 |
+
except Exception as e:
|
| 161 |
+
print(f"❌ Erreur de téléchargement : {e}")
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
# 2. Crée et charge le modèle
|
| 165 |
+
model = BeeAI(input_size, num_classes)
|
| 166 |
+
model.load_state_dict(torch.load(weights_path))
|
| 167 |
+
model.eval() # Passe en mode évaluation pour l'inférence
|
| 168 |
+
|
| 169 |
+
print("✅ Modèle BeeAI chargé avec succès.")
|
| 170 |
+
return model
|
| 171 |
+
|
| 172 |
+
def flower_to_input(flower_name, x_pos, y_pos):
|
| 173 |
+
"""Convertit les entrées utilisateur en tenseur d'entrée (22 dimensions)."""
|
| 174 |
+
try:
|
| 175 |
+
idx = FLOWER_NAMES.index(flower_name.lower())
|
| 176 |
+
except ValueError:
|
| 177 |
+
print(f"⚠️ Avertissement: Fleur '{flower_name}' inconnue. Utilisation d'un vecteur nul.")
|
| 178 |
+
idx = -1
|
| 179 |
+
|
| 180 |
+
one_hot = np.zeros(NUM_FLOWERS, dtype=np.float32)
|
| 181 |
+
if idx != -1:
|
| 182 |
+
one_hot[idx] = 1.0
|
| 183 |
+
|
| 184 |
+
position = np.array([x_pos, y_pos], dtype=np.float32)
|
| 185 |
+
input_vector = np.concatenate((one_hot, position))
|
| 186 |
+
# Ajout d'une dimension pour le batch
|
| 187 |
+
return torch.tensor(input_vector, dtype=torch.float32).unsqueeze(0)
|
| 188 |
+
|
| 189 |
+
def generate_dance_report(model, flower_name, x_pos, y_pos):
|
| 190 |
+
"""Effectue l'inférence et retourne le code de danse décodé."""
|
| 191 |
+
|
| 192 |
+
# Prépare l'entrée
|
| 193 |
+
X_test = flower_to_input(flower_name, x_pos, y_pos)
|
| 194 |
+
|
| 195 |
+
# Inférence
|
| 196 |
+
with torch.no_grad():
|
| 197 |
+
output = model(X_test)
|
| 198 |
+
|
| 199 |
+
# Décodage (obtient l'indice de la probabilité maximale)
|
| 200 |
+
_, predicted_index = torch.max(output.data, 1)
|
| 201 |
+
report_code = predicted_index.item()
|
| 202 |
+
|
| 203 |
+
return report_code, RAPPORT_DESCRIPTIONS[report_code]
|
| 204 |
+
|
| 205 |
+
# ==============================================================================
|
| 206 |
+
# 3. EXÉCUTION (L'expérience utilisateur)
|
| 207 |
+
# ==============================================================================
|
| 208 |
+
|
| 209 |
+
# 1. Charger le modèle
|
| 210 |
+
bee_ai_model = load_bee_ai_model(REPO_ID, FILENAME, INPUT_SIZE, NUM_DANCE_REPORTS)
|
| 211 |
+
|
| 212 |
+
if bee_ai_model:
|
| 213 |
+
|
| 214 |
+
# 2. Définir la situation à tester
|
| 215 |
+
test_flower = 'pissenlit' # Une fleur connue et généralement bonne
|
| 216 |
+
test_x = 4.5
|
| 217 |
+
test_y = 5.5
|
| 218 |
+
|
| 219 |
+
# 3. Générer le rapport !
|
| 220 |
+
code, description = generate_dance_report(bee_ai_model, test_flower, test_x, test_y)
|
| 221 |
+
|
| 222 |
+
# 4. Afficher le résultat
|
| 223 |
+
print(f"\n==============================================")
|
| 224 |
+
print(f" Rapport de Danse de Bena (Bee AI) ")
|
| 225 |
+
print(f"==============================================")
|
| 226 |
+
print(f"Fleur évaluée : {test_flower.upper()} à la position ({test_x}, {test_y})")
|
| 227 |
+
print(f"CODE GÉNÉRÉ : {code}")
|
| 228 |
+
print(f"DESCRIPTION : {description.split(':')[1].strip()}")
|
| 229 |
+
print(f"==============================================")
|
| 230 |
+
```
|