alexfremont commited on
Commit
0053356
·
1 Parent(s): 2c455c2

Add model management endpoints and database fetch functionality

Browse files
Files changed (5) hide show
  1. api/prediction.py +3 -2
  2. api/router.py +112 -1
  3. config/settings.py +9 -0
  4. db/models.py +47 -0
  5. models/loader.py +87 -40
api/prediction.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from fastapi import APIRouter, HTTPException
3
  from fastapi.responses import JSONResponse
4
  from PIL import Image
5
  from io import BytesIO
@@ -12,11 +12,12 @@ from schemas.requests import BatchPredictRequest
12
  from models.loader import get_model
13
  from steps.preprocess import process_image
14
  from config.settings import IMAGE_SIZE, NUM_THREADS
 
15
 
16
  logger = logging.getLogger(__name__)
17
  router = APIRouter()
18
 
19
- @router.post("/batch_predict")
20
  async def batch_predict(request: BatchPredictRequest):
21
  """Endpoint pour prédire à partir de plusieurs images."""
22
  model_name = request.modelName
 
1
  import torch
2
+ from fastapi import APIRouter, HTTPException, Depends
3
  from fastapi.responses import JSONResponse
4
  from PIL import Image
5
  from io import BytesIO
 
12
  from models.loader import get_model
13
  from steps.preprocess import process_image
14
  from config.settings import IMAGE_SIZE, NUM_THREADS
15
+ from api.router import verify_api_key
16
 
17
  logger = logging.getLogger(__name__)
18
  router = APIRouter()
19
 
20
+ @router.post("/batch_predict", dependencies=[Depends(verify_api_key)])
21
  async def batch_predict(request: BatchPredictRequest):
22
  """Endpoint pour prédire à partir de plusieurs images."""
23
  model_name = request.modelName
api/router.py CHANGED
@@ -3,7 +3,10 @@ import logging
3
  import os
4
 
5
  from api import prediction
6
- from config.settings import API_KEY
 
 
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
@@ -34,5 +37,113 @@ async def verify_api_key(request: Request, call_next):
34
  response = await call_next(request)
35
  return response
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Inclure les routes des autres modules
38
  router.include_router(prediction.router, tags=["Prediction"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
 
5
  from api import prediction
6
+ from config.settings import API_KEY, MANAGEMENT_API_KEY
7
+ from db.models import fetch_model_by_id
8
+ from models.loader import model_pipelines, _load_single_model_pipeline, get_model
9
+ from models.schemas import PredictionRequest, PredictionResponse
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
37
  response = await call_next(request)
38
  return response
39
 
40
+ # Dépendance pour la Sécurité de l'API de Gestion
41
+ async def verify_management_api_key(x_api_key: str = Header(...)):
42
+ """Vérifie si la clé API fournie correspond à celle configurée."""
43
+ if not MANAGEMENT_API_KEY:
44
+ logger.warning("MANAGEMENT_API_KEY is not set. Management endpoints are unsecured!")
45
+ # Décider si on bloque ou on autorise sans clé définie
46
+ # Pour la sécurité, il vaut mieux bloquer par défaut
47
+ raise HTTPException(status_code=500, detail="Management API key not configured on server.")
48
+ if x_api_key != MANAGEMENT_API_KEY:
49
+ logger.warning(f"Invalid or missing API key attempt for management endpoint.")
50
+ raise HTTPException(status_code=401, detail="Invalid or missing API Key")
51
+ return True # Clé valide
52
+
53
  # Inclure les routes des autres modules
54
  router.include_router(prediction.router, tags=["Prediction"])
55
+
56
+ # Nouvel Endpoint de Gestion
57
+ @router.post(
58
+ "/manage/load_model/{model_db_id}",
59
+ summary="Load a specific model into memory",
60
+ dependencies=[Depends(verify_management_api_key)] # Sécurise l'endpoint
61
+ )
62
+ async def load_single_model(model_db_id: Any): # L'ID peut être int ou str
63
+ """Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
64
+ logger.info(f"Received request to load model with DB ID: {model_db_id}")
65
+
66
+ # 1. Vérifier si le modèle est déjà chargé
67
+ if model_db_id in model_pipelines:
68
+ logger.info(f"Model ID {model_db_id} is already loaded.")
69
+ return {"status": "success", "message": f"Model ID {model_db_id} is already loaded."}
70
+
71
+ # 2. Récupérer les informations du modèle depuis la DB
72
+ try:
73
+ model_data = await fetch_model_by_id(model_db_id)
74
+ if not model_data:
75
+ logger.error(f"Model ID {model_db_id} not found in database.")
76
+ raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
77
+ except Exception as e:
78
+ logger.exception(f"Database error fetching model ID {model_db_id}: {e}")
79
+ raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id}.")
80
+
81
+ # 3. Charger le modèle
82
+ try:
83
+ logger.info(f"Attempting to load model ID {model_db_id} ('{model_data.get('name', 'N/A')}') into memory...")
84
+ pipeline = await _load_single_model_pipeline(model_data)
85
+
86
+ # 4. Ajouter au dictionnaire des modèles chargés
87
+ model_pipelines[model_db_id] = pipeline
88
+ logger.info(f"Successfully loaded and added model ID {model_db_id} to running pipelines.")
89
+ return {"status": "success", "message": f"Model ID {model_db_id} loaded successfully."}
90
+
91
+ except Exception as e:
92
+ logger.exception(f"Failed to load model ID {model_db_id}: {e}")
93
+ # Ne pas laisser un pipeline potentiellement corrompu dans le dictionnaire
94
+ if model_db_id in model_pipelines:
95
+ del model_pipelines[model_db_id]
96
+ raise HTTPException(status_code=500, detail=f"Failed to load model ID {model_db_id}. Check server logs for details.")
97
+
98
+ @router.post(
99
+ "/manage/update_model/{model_db_id}",
100
+ summary="Reload/Update a specific model already in memory",
101
+ dependencies=[Depends(verify_management_api_key)] # Sécurise l'endpoint
102
+ )
103
+ async def update_single_model(model_db_id: Any):
104
+ """Retélécharge et met à jour un modèle spécifique qui est déjà chargé en mémoire."""
105
+ logger.info(f"Received request to update model with DB ID: {model_db_id}")
106
+
107
+ # 1. Vérifier si le modèle est actuellement chargé
108
+ if model_db_id not in model_pipelines:
109
+ logger.error(f"Attempted to update model ID {model_db_id}, but it is not loaded.")
110
+ raise HTTPException(
111
+ status_code=404,
112
+ detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first."
113
+ )
114
+
115
+ # 2. Récupérer les informations du modèle depuis la DB (pour s'assurer qu'elles sont à jour si besoin)
116
+ try:
117
+ model_data = await fetch_model_by_id(model_db_id)
118
+ if not model_data:
119
+ # Ceci indiquerait une incohérence si le modèle est dans model_pipelines mais pas dans la DB
120
+ logger.error(f"Inconsistency: Model ID {model_db_id} loaded but not found in database during update.")
121
+ raise HTTPException(status_code=500, detail=f"Inconsistency: Model ID {model_db_id} not found in database.")
122
+ except Exception as e:
123
+ logger.exception(f"Database error fetching model ID {model_db_id} during update: {e}")
124
+ raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id} for update.")
125
+
126
+ # 3. Recharger le modèle
127
+ try:
128
+ logger.info(f"Attempting to reload model ID {model_db_id} ('{model_data.get('name', 'N/A')}') from source...")
129
+ # Supprimer l'ancien modèle de la mémoire avant de charger le nouveau pour libérer des ressources GPU/CPU si possible
130
+ # Attention : ceci pourrait causer une brève indisponibilité du modèle pendant le rechargement.
131
+ # Une stratégie alternative serait de charger le nouveau d'abord, puis de remplacer.
132
+ if model_db_id in model_pipelines:
133
+ del model_pipelines[model_db_id]
134
+ # Potentiellement forcer le nettoyage de la mémoire GPU ici si nécessaire (torch.cuda.empty_cache() - à utiliser avec prudence)
135
+ logger.debug(f"Removed old instance of model ID {model_db_id} from memory before update.")
136
+
137
+ pipeline = await _load_single_model_pipeline(model_data)
138
+
139
+ # 4. Mettre à jour le dictionnaire avec le nouveau pipeline
140
+ model_pipelines[model_db_id] = pipeline
141
+ logger.info(f"Successfully updated model ID {model_db_id} in running pipelines.")
142
+ return {"status": "success", "message": f"Model ID {model_db_id} updated successfully."}
143
+
144
+ except Exception as e:
145
+ logger.exception(f"Failed to reload model ID {model_db_id}: {e}")
146
+ # Si le rechargement échoue, l'ancien modèle a déjà été supprimé.
147
+ # Il faut soit tenter de recharger l'ancien, soit le laisser déchargé.
148
+ # Pour l'instant, on le laisse déchargé et on signale l'erreur.
149
+ raise HTTPException(status_code=500, detail=f"Failed to reload model ID {model_db_id}. Model is now unloaded. Check server logs.")
config/settings.py CHANGED
@@ -1,6 +1,10 @@
1
  import os
2
  import logging
3
  import torch
 
 
 
 
4
 
5
  # Configuration de base des logs
6
  logging.basicConfig(level=logging.INFO)
@@ -22,6 +26,11 @@ HF_TOKEN = get_env_or_fail("api_read")
22
  RESOURCE_GROUP = get_env_or_fail("RESOURCE_GROUP")
23
  DATABASE_URL = get_env_or_fail("DATABASE_URL")
24
 
 
 
 
 
 
25
  # Log des paramètres importants (sans détails sensibles)
26
  logger.info(f"RESOURCE_GROUP set to: {RESOURCE_GROUP}")
27
 
 
1
  import os
2
  import logging
3
  import torch
4
+ from dotenv import load_dotenv
5
+
6
+ # Charger les variables d'environnement depuis le fichier .env
7
+ load_dotenv()
8
 
9
  # Configuration de base des logs
10
  logging.basicConfig(level=logging.INFO)
 
26
  RESOURCE_GROUP = get_env_or_fail("RESOURCE_GROUP")
27
  DATABASE_URL = get_env_or_fail("DATABASE_URL")
28
 
29
+ # Configuration Gestion
30
+ MANAGEMENT_API_KEY = os.getenv("MANAGEMENT_API_KEY")
31
+ if not MANAGEMENT_API_KEY:
32
+ print("Warning: MANAGEMENT_API_KEY environment variable is not set. Management endpoints will be inaccessible.")
33
+
34
  # Log des paramètres importants (sans détails sensibles)
35
  logger.info(f"RESOURCE_GROUP set to: {RESOURCE_GROUP}")
36
 
db/models.py CHANGED
@@ -46,3 +46,50 @@ async def fetch_models_for_group(resource_group: str) -> List[Dict[str, Any]]:
46
  if conn and not conn.is_closed():
47
  await conn.close()
48
  logger.debug("Database connection closed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  if conn and not conn.is_closed():
47
  await conn.close()
48
  logger.debug("Database connection closed")
49
+
50
+
51
+ async def fetch_model_by_id(model_id: str) -> Dict[str, Any] | None:
52
+ """Récupérer les détails d'un modèle spécifique par son ID de base de données.
53
+
54
+ Args:
55
+ model_id: L'ID du modèle dans la base de données (peut être int ou str selon le schéma).
56
+
57
+ Returns:
58
+ Un dictionnaire contenant les informations du modèle si trouvé, sinon None.
59
+
60
+ Raises:
61
+ Exception: Si une erreur se produit lors de la connexion ou de la requête.
62
+ """
63
+ conn = None
64
+ try:
65
+ conn = await asyncpg.connect(DATABASE_URL)
66
+ logger.debug(f"Successfully connected to database to fetch model ID: {model_id}")
67
+
68
+ # Récupérer le modèle spécifique par son ID
69
+ query = """
70
+ SELECT
71
+ model_id,
72
+ name,
73
+ display_name,
74
+ hf_repo_id,
75
+ hf_subfolder,
76
+ hf_filename
77
+ FROM models
78
+ WHERE model_id = $1
79
+ """
80
+ row = await conn.fetchrow(query, model_id)
81
+
82
+ if row:
83
+ logger.info(f"Found model with ID '{model_id}': {row['name']}")
84
+ return dict(row)
85
+ else:
86
+ logger.warning(f"No model found with ID '{model_id}'")
87
+ return None
88
+
89
+ except Exception as e:
90
+ logger.error(f"Database error fetching model ID {model_id}: {e}", exc_info=True)
91
+ raise
92
+ finally:
93
+ if conn and not conn.is_closed():
94
+ await conn.close()
95
+ logger.debug(f"Database connection closed after fetching model ID: {model_id}")
models/loader.py CHANGED
@@ -1,7 +1,7 @@
1
- import torch
2
  import logging
3
- from typing import Dict, List, Any
4
  from huggingface_hub import hf_hub_download
 
5
 
6
  from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
7
  from architecture.resnet import ResNet
@@ -12,62 +12,109 @@ logger = logging.getLogger(__name__)
12
  torch.set_num_threads(NUM_THREADS)
13
 
14
  # Instance de base pour le modèle ResNet
15
- base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
 
16
 
17
- # Dictionnaire global pour stocker les modèles chargés
18
- model_pipelines = {}
 
19
 
20
- async def load_models(models_data: List[Dict[str, Any]]) -> Dict[str, Any]:
21
- """Charger les modèles depuis Hugging Face à partir des données de la base de données.
22
 
23
  Args:
24
- models_data: Liste de dictionnaires contenant les informations des modèles
25
 
26
  Returns:
27
- Dictionnaire des modèles chargés
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  Raises:
30
- RuntimeError: Si aucun modèle n'est trouvé ou ne peut être chargé
31
  """
32
- logger.info(f"Attempting to load {len(models_data)} models...")
33
 
34
  if not models_data:
35
- error_msg = "No models found. API cannot start without models."
36
  logger.error(error_msg)
37
- raise RuntimeError(error_msg)
 
38
 
39
  loaded_count = 0
 
40
  for model_data in models_data:
 
41
  try:
42
- model_name = model_data['hf_filename']
43
- logger.info(f"Loading model: {model_name} (repo: {model_data['hf_repo_id']}, subfolder: {model_data['hf_subfolder']})")
44
-
45
- model_weight = hf_hub_download(
46
- repo_id=model_data['hf_repo_id'],
47
- subfolder=model_data['hf_subfolder'],
48
- filename=model_name,
49
- token=HF_TOKEN,
50
- )
51
-
52
- # Créer une nouvelle instance pour chaque modèle pour tenir ses poids spécifiques
53
- model = base_model.__class__("resnet152", num_output_neurons=2).to(DEVICE)
54
- model.load_state_dict(
55
- torch.load(model_weight, weights_only=True, map_location=DEVICE)
56
- )
57
- model.eval()
58
- model_pipelines[model_name] = model
59
  loaded_count += 1
60
  except Exception as e:
61
- logger.error(f"Error loading model {model_data.get('hf_filename', 'N/A')}: {e}", exc_info=True)
62
-
63
- logger.info(f"Model loading finished. Successfully loaded {loaded_count}/{len(models_data)} models.")
64
-
65
- if loaded_count == 0:
66
- error_msg = "Failed to load any models. API cannot start without models."
67
- logger.error(error_msg)
68
- raise RuntimeError(error_msg)
69
-
70
- return model_pipelines
 
71
 
72
  def get_model(model_name: str):
73
  """Récupérer un modèle chargé par son nom.
 
 
1
  import logging
2
+ import torch
3
  from huggingface_hub import hf_hub_download
4
+ from typing import List, Dict, Any
5
 
6
  from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
7
  from architecture.resnet import ResNet
 
12
  torch.set_num_threads(NUM_THREADS)
13
 
14
  # Instance de base pour le modèle ResNet
15
+ # Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle
16
+ # base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
17
 
18
+ # Dictionnaire global pour stocker les modèles chargés (pipelines)
19
+ # Clé: ID du modèle (provenant de la DB), Valeur: Pipeline/Modèle chargé
20
+ model_pipelines: Dict[Any, Any] = {}
21
 
22
+ async def _load_single_model_pipeline(model_data: Dict[str, Any]) -> Any:
23
+ """Charge un seul pipeline de modèle à partir de ses données.
24
 
25
  Args:
26
+ model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.).
27
 
28
  Returns:
29
+ Le pipeline/modèle chargé.
30
+
31
+ Raises:
32
+ Exception: Si le chargement échoue.
33
+ """
34
+ model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé
35
+ model_name = model_data['hf_filename']
36
+ repo_id = model_data['hf_repo_id']
37
+ subfolder = model_data['hf_subfolder']
38
+
39
+ logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})")
40
+
41
+ try:
42
+ model_weight_path = hf_hub_download(
43
+ repo_id=repo_id,
44
+ subfolder=subfolder,
45
+ filename=model_name,
46
+ token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement
47
+ )
48
+
49
+ logger.debug(f"Model weights downloaded to: {model_weight_path}")
50
+
51
+ # Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique
52
+ # Assurez-vous que ResNet et ses arguments sont corrects
53
+ model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
54
+
55
+ # Charger les poids
56
+ # Attention: la méthode de chargement dépend du format des poids (state_dict, etc.)
57
+ state_dict = torch.load(model_weight_path, map_location=DEVICE)
58
+
59
+ # Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model')
60
+ if isinstance(state_dict, dict) and 'state_dict' in state_dict:
61
+ state_dict = state_dict['state_dict']
62
+ elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun
63
+ state_dict = state_dict['model']
64
+
65
+ # Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP)
66
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
67
+
68
+ model.load_state_dict(state_dict)
69
+ model.eval() # Mettre le modèle en mode évaluation
70
+
71
+ logger.info(f"Successfully loaded model ID {model_id}: {model_name}")
72
+ return model # Retourner le modèle chargé (ou un pipeline si vous en créez un)
73
+
74
+ except Exception as e:
75
+ logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True)
76
+ raise # Propage l'exception pour que l'appelant puisse la gérer
77
+
78
+
79
+ async def load_models(models_data: List[Dict[str, Any]]) -> None:
80
+ """Charger les modèles depuis Hugging Face et les stocker dans model_pipelines.
81
+
82
+ Args:
83
+ models_data: Liste de dictionnaires contenant les informations des modèles.
84
 
85
  Raises:
86
+ RuntimeError: Si aucun modèle n'est trouvé.
87
  """
88
+ logger.info(f"Attempting to load {len(models_data)} models into memory...")
89
 
90
  if not models_data:
91
+ error_msg = "No models data provided. Cannot load models."
92
  logger.error(error_msg)
93
+ # On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles.
94
+ return
95
 
96
  loaded_count = 0
97
+ failed_models = []
98
  for model_data in models_data:
99
+ model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent
100
  try:
101
+ # Utilise la nouvelle fonction pour charger un seul modèle
102
+ pipeline = await _load_single_model_pipeline(model_data)
103
+ # Stocke le pipeline chargé dans le dictionnaire global
104
+ model_pipelines[model_id] = pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  loaded_count += 1
106
  except Exception as e:
107
+ # Log l'échec mais continue avec les autres modèles
108
+ logger.error(f"Failed to load model ID {model_id}: {e}")
109
+ failed_models.append(model_data.get('name', f'ID {model_id}'))
110
+
111
+ logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}")
112
+ if failed_models:
113
+ logger.warning(f"Failed to load the following models: {', '.join(failed_models)}")
114
+
115
+ # Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global
116
+ # return model_pipelines # Ancienne logique
117
+
118
 
119
  def get_model(model_name: str):
120
  """Récupérer un modèle chargé par son nom.