Spaces:
Sleeping
Sleeping
Commit
·
0053356
1
Parent(s):
2c455c2
Add model management endpoints and database fetch functionality
Browse files- api/prediction.py +3 -2
- api/router.py +112 -1
- config/settings.py +9 -0
- db/models.py +47 -0
- models/loader.py +87 -40
api/prediction.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
from fastapi import APIRouter, HTTPException
|
3 |
from fastapi.responses import JSONResponse
|
4 |
from PIL import Image
|
5 |
from io import BytesIO
|
@@ -12,11 +12,12 @@ from schemas.requests import BatchPredictRequest
|
|
12 |
from models.loader import get_model
|
13 |
from steps.preprocess import process_image
|
14 |
from config.settings import IMAGE_SIZE, NUM_THREADS
|
|
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
router = APIRouter()
|
18 |
|
19 |
-
@router.post("/batch_predict")
|
20 |
async def batch_predict(request: BatchPredictRequest):
|
21 |
"""Endpoint pour prédire à partir de plusieurs images."""
|
22 |
model_name = request.modelName
|
|
|
1 |
import torch
|
2 |
+
from fastapi import APIRouter, HTTPException, Depends
|
3 |
from fastapi.responses import JSONResponse
|
4 |
from PIL import Image
|
5 |
from io import BytesIO
|
|
|
12 |
from models.loader import get_model
|
13 |
from steps.preprocess import process_image
|
14 |
from config.settings import IMAGE_SIZE, NUM_THREADS
|
15 |
+
from api.router import verify_api_key
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
router = APIRouter()
|
19 |
|
20 |
+
@router.post("/batch_predict", dependencies=[Depends(verify_api_key)])
|
21 |
async def batch_predict(request: BatchPredictRequest):
|
22 |
"""Endpoint pour prédire à partir de plusieurs images."""
|
23 |
model_name = request.modelName
|
api/router.py
CHANGED
@@ -3,7 +3,10 @@ import logging
|
|
3 |
import os
|
4 |
|
5 |
from api import prediction
|
6 |
-
from config.settings import API_KEY
|
|
|
|
|
|
|
7 |
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
@@ -34,5 +37,113 @@ async def verify_api_key(request: Request, call_next):
|
|
34 |
response = await call_next(request)
|
35 |
return response
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
# Inclure les routes des autres modules
|
38 |
router.include_router(prediction.router, tags=["Prediction"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import os
|
4 |
|
5 |
from api import prediction
|
6 |
+
from config.settings import API_KEY, MANAGEMENT_API_KEY
|
7 |
+
from db.models import fetch_model_by_id
|
8 |
+
from models.loader import model_pipelines, _load_single_model_pipeline, get_model
|
9 |
+
from models.schemas import PredictionRequest, PredictionResponse
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
|
|
37 |
response = await call_next(request)
|
38 |
return response
|
39 |
|
40 |
+
# Dépendance pour la Sécurité de l'API de Gestion
|
41 |
+
async def verify_management_api_key(x_api_key: str = Header(...)):
|
42 |
+
"""Vérifie si la clé API fournie correspond à celle configurée."""
|
43 |
+
if not MANAGEMENT_API_KEY:
|
44 |
+
logger.warning("MANAGEMENT_API_KEY is not set. Management endpoints are unsecured!")
|
45 |
+
# Décider si on bloque ou on autorise sans clé définie
|
46 |
+
# Pour la sécurité, il vaut mieux bloquer par défaut
|
47 |
+
raise HTTPException(status_code=500, detail="Management API key not configured on server.")
|
48 |
+
if x_api_key != MANAGEMENT_API_KEY:
|
49 |
+
logger.warning(f"Invalid or missing API key attempt for management endpoint.")
|
50 |
+
raise HTTPException(status_code=401, detail="Invalid or missing API Key")
|
51 |
+
return True # Clé valide
|
52 |
+
|
53 |
# Inclure les routes des autres modules
|
54 |
router.include_router(prediction.router, tags=["Prediction"])
|
55 |
+
|
56 |
+
# Nouvel Endpoint de Gestion
|
57 |
+
@router.post(
|
58 |
+
"/manage/load_model/{model_db_id}",
|
59 |
+
summary="Load a specific model into memory",
|
60 |
+
dependencies=[Depends(verify_management_api_key)] # Sécurise l'endpoint
|
61 |
+
)
|
62 |
+
async def load_single_model(model_db_id: Any): # L'ID peut être int ou str
|
63 |
+
"""Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
|
64 |
+
logger.info(f"Received request to load model with DB ID: {model_db_id}")
|
65 |
+
|
66 |
+
# 1. Vérifier si le modèle est déjà chargé
|
67 |
+
if model_db_id in model_pipelines:
|
68 |
+
logger.info(f"Model ID {model_db_id} is already loaded.")
|
69 |
+
return {"status": "success", "message": f"Model ID {model_db_id} is already loaded."}
|
70 |
+
|
71 |
+
# 2. Récupérer les informations du modèle depuis la DB
|
72 |
+
try:
|
73 |
+
model_data = await fetch_model_by_id(model_db_id)
|
74 |
+
if not model_data:
|
75 |
+
logger.error(f"Model ID {model_db_id} not found in database.")
|
76 |
+
raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
|
77 |
+
except Exception as e:
|
78 |
+
logger.exception(f"Database error fetching model ID {model_db_id}: {e}")
|
79 |
+
raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id}.")
|
80 |
+
|
81 |
+
# 3. Charger le modèle
|
82 |
+
try:
|
83 |
+
logger.info(f"Attempting to load model ID {model_db_id} ('{model_data.get('name', 'N/A')}') into memory...")
|
84 |
+
pipeline = await _load_single_model_pipeline(model_data)
|
85 |
+
|
86 |
+
# 4. Ajouter au dictionnaire des modèles chargés
|
87 |
+
model_pipelines[model_db_id] = pipeline
|
88 |
+
logger.info(f"Successfully loaded and added model ID {model_db_id} to running pipelines.")
|
89 |
+
return {"status": "success", "message": f"Model ID {model_db_id} loaded successfully."}
|
90 |
+
|
91 |
+
except Exception as e:
|
92 |
+
logger.exception(f"Failed to load model ID {model_db_id}: {e}")
|
93 |
+
# Ne pas laisser un pipeline potentiellement corrompu dans le dictionnaire
|
94 |
+
if model_db_id in model_pipelines:
|
95 |
+
del model_pipelines[model_db_id]
|
96 |
+
raise HTTPException(status_code=500, detail=f"Failed to load model ID {model_db_id}. Check server logs for details.")
|
97 |
+
|
98 |
+
@router.post(
|
99 |
+
"/manage/update_model/{model_db_id}",
|
100 |
+
summary="Reload/Update a specific model already in memory",
|
101 |
+
dependencies=[Depends(verify_management_api_key)] # Sécurise l'endpoint
|
102 |
+
)
|
103 |
+
async def update_single_model(model_db_id: Any):
|
104 |
+
"""Retélécharge et met à jour un modèle spécifique qui est déjà chargé en mémoire."""
|
105 |
+
logger.info(f"Received request to update model with DB ID: {model_db_id}")
|
106 |
+
|
107 |
+
# 1. Vérifier si le modèle est actuellement chargé
|
108 |
+
if model_db_id not in model_pipelines:
|
109 |
+
logger.error(f"Attempted to update model ID {model_db_id}, but it is not loaded.")
|
110 |
+
raise HTTPException(
|
111 |
+
status_code=404,
|
112 |
+
detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first."
|
113 |
+
)
|
114 |
+
|
115 |
+
# 2. Récupérer les informations du modèle depuis la DB (pour s'assurer qu'elles sont à jour si besoin)
|
116 |
+
try:
|
117 |
+
model_data = await fetch_model_by_id(model_db_id)
|
118 |
+
if not model_data:
|
119 |
+
# Ceci indiquerait une incohérence si le modèle est dans model_pipelines mais pas dans la DB
|
120 |
+
logger.error(f"Inconsistency: Model ID {model_db_id} loaded but not found in database during update.")
|
121 |
+
raise HTTPException(status_code=500, detail=f"Inconsistency: Model ID {model_db_id} not found in database.")
|
122 |
+
except Exception as e:
|
123 |
+
logger.exception(f"Database error fetching model ID {model_db_id} during update: {e}")
|
124 |
+
raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id} for update.")
|
125 |
+
|
126 |
+
# 3. Recharger le modèle
|
127 |
+
try:
|
128 |
+
logger.info(f"Attempting to reload model ID {model_db_id} ('{model_data.get('name', 'N/A')}') from source...")
|
129 |
+
# Supprimer l'ancien modèle de la mémoire avant de charger le nouveau pour libérer des ressources GPU/CPU si possible
|
130 |
+
# Attention : ceci pourrait causer une brève indisponibilité du modèle pendant le rechargement.
|
131 |
+
# Une stratégie alternative serait de charger le nouveau d'abord, puis de remplacer.
|
132 |
+
if model_db_id in model_pipelines:
|
133 |
+
del model_pipelines[model_db_id]
|
134 |
+
# Potentiellement forcer le nettoyage de la mémoire GPU ici si nécessaire (torch.cuda.empty_cache() - à utiliser avec prudence)
|
135 |
+
logger.debug(f"Removed old instance of model ID {model_db_id} from memory before update.")
|
136 |
+
|
137 |
+
pipeline = await _load_single_model_pipeline(model_data)
|
138 |
+
|
139 |
+
# 4. Mettre à jour le dictionnaire avec le nouveau pipeline
|
140 |
+
model_pipelines[model_db_id] = pipeline
|
141 |
+
logger.info(f"Successfully updated model ID {model_db_id} in running pipelines.")
|
142 |
+
return {"status": "success", "message": f"Model ID {model_db_id} updated successfully."}
|
143 |
+
|
144 |
+
except Exception as e:
|
145 |
+
logger.exception(f"Failed to reload model ID {model_db_id}: {e}")
|
146 |
+
# Si le rechargement échoue, l'ancien modèle a déjà été supprimé.
|
147 |
+
# Il faut soit tenter de recharger l'ancien, soit le laisser déchargé.
|
148 |
+
# Pour l'instant, on le laisse déchargé et on signale l'erreur.
|
149 |
+
raise HTTPException(status_code=500, detail=f"Failed to reload model ID {model_db_id}. Model is now unloaded. Check server logs.")
|
config/settings.py
CHANGED
@@ -1,6 +1,10 @@
|
|
1 |
import os
|
2 |
import logging
|
3 |
import torch
|
|
|
|
|
|
|
|
|
4 |
|
5 |
# Configuration de base des logs
|
6 |
logging.basicConfig(level=logging.INFO)
|
@@ -22,6 +26,11 @@ HF_TOKEN = get_env_or_fail("api_read")
|
|
22 |
RESOURCE_GROUP = get_env_or_fail("RESOURCE_GROUP")
|
23 |
DATABASE_URL = get_env_or_fail("DATABASE_URL")
|
24 |
|
|
|
|
|
|
|
|
|
|
|
25 |
# Log des paramètres importants (sans détails sensibles)
|
26 |
logger.info(f"RESOURCE_GROUP set to: {RESOURCE_GROUP}")
|
27 |
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
import torch
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
|
6 |
+
# Charger les variables d'environnement depuis le fichier .env
|
7 |
+
load_dotenv()
|
8 |
|
9 |
# Configuration de base des logs
|
10 |
logging.basicConfig(level=logging.INFO)
|
|
|
26 |
RESOURCE_GROUP = get_env_or_fail("RESOURCE_GROUP")
|
27 |
DATABASE_URL = get_env_or_fail("DATABASE_URL")
|
28 |
|
29 |
+
# Configuration Gestion
|
30 |
+
MANAGEMENT_API_KEY = os.getenv("MANAGEMENT_API_KEY")
|
31 |
+
if not MANAGEMENT_API_KEY:
|
32 |
+
print("Warning: MANAGEMENT_API_KEY environment variable is not set. Management endpoints will be inaccessible.")
|
33 |
+
|
34 |
# Log des paramètres importants (sans détails sensibles)
|
35 |
logger.info(f"RESOURCE_GROUP set to: {RESOURCE_GROUP}")
|
36 |
|
db/models.py
CHANGED
@@ -46,3 +46,50 @@ async def fetch_models_for_group(resource_group: str) -> List[Dict[str, Any]]:
|
|
46 |
if conn and not conn.is_closed():
|
47 |
await conn.close()
|
48 |
logger.debug("Database connection closed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
if conn and not conn.is_closed():
|
47 |
await conn.close()
|
48 |
logger.debug("Database connection closed")
|
49 |
+
|
50 |
+
|
51 |
+
async def fetch_model_by_id(model_id: str) -> Dict[str, Any] | None:
|
52 |
+
"""Récupérer les détails d'un modèle spécifique par son ID de base de données.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
model_id: L'ID du modèle dans la base de données (peut être int ou str selon le schéma).
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
Un dictionnaire contenant les informations du modèle si trouvé, sinon None.
|
59 |
+
|
60 |
+
Raises:
|
61 |
+
Exception: Si une erreur se produit lors de la connexion ou de la requête.
|
62 |
+
"""
|
63 |
+
conn = None
|
64 |
+
try:
|
65 |
+
conn = await asyncpg.connect(DATABASE_URL)
|
66 |
+
logger.debug(f"Successfully connected to database to fetch model ID: {model_id}")
|
67 |
+
|
68 |
+
# Récupérer le modèle spécifique par son ID
|
69 |
+
query = """
|
70 |
+
SELECT
|
71 |
+
model_id,
|
72 |
+
name,
|
73 |
+
display_name,
|
74 |
+
hf_repo_id,
|
75 |
+
hf_subfolder,
|
76 |
+
hf_filename
|
77 |
+
FROM models
|
78 |
+
WHERE model_id = $1
|
79 |
+
"""
|
80 |
+
row = await conn.fetchrow(query, model_id)
|
81 |
+
|
82 |
+
if row:
|
83 |
+
logger.info(f"Found model with ID '{model_id}': {row['name']}")
|
84 |
+
return dict(row)
|
85 |
+
else:
|
86 |
+
logger.warning(f"No model found with ID '{model_id}'")
|
87 |
+
return None
|
88 |
+
|
89 |
+
except Exception as e:
|
90 |
+
logger.error(f"Database error fetching model ID {model_id}: {e}", exc_info=True)
|
91 |
+
raise
|
92 |
+
finally:
|
93 |
+
if conn and not conn.is_closed():
|
94 |
+
await conn.close()
|
95 |
+
logger.debug(f"Database connection closed after fetching model ID: {model_id}")
|
models/loader.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
import torch
|
2 |
import logging
|
3 |
-
|
4 |
from huggingface_hub import hf_hub_download
|
|
|
5 |
|
6 |
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
|
7 |
from architecture.resnet import ResNet
|
@@ -12,62 +12,109 @@ logger = logging.getLogger(__name__)
|
|
12 |
torch.set_num_threads(NUM_THREADS)
|
13 |
|
14 |
# Instance de base pour le modèle ResNet
|
15 |
-
|
|
|
16 |
|
17 |
-
# Dictionnaire global pour stocker les modèles chargés
|
18 |
-
|
|
|
19 |
|
20 |
-
async def
|
21 |
-
"""
|
22 |
|
23 |
Args:
|
24 |
-
|
25 |
|
26 |
Returns:
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
Raises:
|
30 |
-
RuntimeError: Si aucun modèle n'est
|
31 |
"""
|
32 |
-
logger.info(f"Attempting to load {len(models_data)} models...")
|
33 |
|
34 |
if not models_data:
|
35 |
-
error_msg = "No models
|
36 |
logger.error(error_msg)
|
37 |
-
|
|
|
38 |
|
39 |
loaded_count = 0
|
|
|
40 |
for model_data in models_data:
|
|
|
41 |
try:
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
repo_id=model_data['hf_repo_id'],
|
47 |
-
subfolder=model_data['hf_subfolder'],
|
48 |
-
filename=model_name,
|
49 |
-
token=HF_TOKEN,
|
50 |
-
)
|
51 |
-
|
52 |
-
# Créer une nouvelle instance pour chaque modèle pour tenir ses poids spécifiques
|
53 |
-
model = base_model.__class__("resnet152", num_output_neurons=2).to(DEVICE)
|
54 |
-
model.load_state_dict(
|
55 |
-
torch.load(model_weight, weights_only=True, map_location=DEVICE)
|
56 |
-
)
|
57 |
-
model.eval()
|
58 |
-
model_pipelines[model_name] = model
|
59 |
loaded_count += 1
|
60 |
except Exception as e:
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
logger.
|
68 |
-
|
69 |
-
|
70 |
-
return model_pipelines
|
|
|
71 |
|
72 |
def get_model(model_name: str):
|
73 |
"""Récupérer un modèle chargé par son nom.
|
|
|
|
|
1 |
import logging
|
2 |
+
import torch
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
+
from typing import List, Dict, Any
|
5 |
|
6 |
from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
|
7 |
from architecture.resnet import ResNet
|
|
|
12 |
torch.set_num_threads(NUM_THREADS)
|
13 |
|
14 |
# Instance de base pour le modèle ResNet
|
15 |
+
# Note: Peut-être pas nécessaire de l'instancier ici si chaque chargement en crée une nouvelle
|
16 |
+
# base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
|
17 |
|
18 |
+
# Dictionnaire global pour stocker les modèles chargés (pipelines)
|
19 |
+
# Clé: ID du modèle (provenant de la DB), Valeur: Pipeline/Modèle chargé
|
20 |
+
model_pipelines: Dict[Any, Any] = {}
|
21 |
|
22 |
+
async def _load_single_model_pipeline(model_data: Dict[str, Any]) -> Any:
|
23 |
+
"""Charge un seul pipeline de modèle à partir de ses données.
|
24 |
|
25 |
Args:
|
26 |
+
model_data: Dictionnaire contenant les informations du modèle (hf_repo_id, etc.).
|
27 |
|
28 |
Returns:
|
29 |
+
Le pipeline/modèle chargé.
|
30 |
+
|
31 |
+
Raises:
|
32 |
+
Exception: Si le chargement échoue.
|
33 |
+
"""
|
34 |
+
model_id = model_data['model_id'] # Utiliser l'ID de la DB comme clé
|
35 |
+
model_name = model_data['hf_filename']
|
36 |
+
repo_id = model_data['hf_repo_id']
|
37 |
+
subfolder = model_data['hf_subfolder']
|
38 |
+
|
39 |
+
logger.info(f"Loading model ID {model_id}: {model_name} (repo: {repo_id}, subfolder: {subfolder})")
|
40 |
+
|
41 |
+
try:
|
42 |
+
model_weight_path = hf_hub_download(
|
43 |
+
repo_id=repo_id,
|
44 |
+
subfolder=subfolder,
|
45 |
+
filename=model_name,
|
46 |
+
token=HF_TOKEN, # Assurez-vous que HF_TOKEN est géré correctement
|
47 |
+
)
|
48 |
+
|
49 |
+
logger.debug(f"Model weights downloaded to: {model_weight_path}")
|
50 |
+
|
51 |
+
# Créer une nouvelle instance de modèle ResNet pour ce chargement spécifique
|
52 |
+
# Assurez-vous que ResNet et ses arguments sont corrects
|
53 |
+
model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
|
54 |
+
|
55 |
+
# Charger les poids
|
56 |
+
# Attention: la méthode de chargement dépend du format des poids (state_dict, etc.)
|
57 |
+
state_dict = torch.load(model_weight_path, map_location=DEVICE)
|
58 |
+
|
59 |
+
# Gérer les cas où les poids sont dans une sous-clé (ex: 'state_dict', 'model')
|
60 |
+
if isinstance(state_dict, dict) and 'state_dict' in state_dict:
|
61 |
+
state_dict = state_dict['state_dict']
|
62 |
+
elif isinstance(state_dict, dict) and 'model' in state_dict: # Autre cas commun
|
63 |
+
state_dict = state_dict['model']
|
64 |
+
|
65 |
+
# Adapter les clés si nécessaire (ex: supprimer le préfixe 'module.' de DataParallel/DDP)
|
66 |
+
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
67 |
+
|
68 |
+
model.load_state_dict(state_dict)
|
69 |
+
model.eval() # Mettre le modèle en mode évaluation
|
70 |
+
|
71 |
+
logger.info(f"Successfully loaded model ID {model_id}: {model_name}")
|
72 |
+
return model # Retourner le modèle chargé (ou un pipeline si vous en créez un)
|
73 |
+
|
74 |
+
except Exception as e:
|
75 |
+
logger.error(f"Failed to load model ID {model_id} ({model_name}): {e}", exc_info=True)
|
76 |
+
raise # Propage l'exception pour que l'appelant puisse la gérer
|
77 |
+
|
78 |
+
|
79 |
+
async def load_models(models_data: List[Dict[str, Any]]) -> None:
|
80 |
+
"""Charger les modèles depuis Hugging Face et les stocker dans model_pipelines.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
models_data: Liste de dictionnaires contenant les informations des modèles.
|
84 |
|
85 |
Raises:
|
86 |
+
RuntimeError: Si aucun modèle n'est trouvé.
|
87 |
"""
|
88 |
+
logger.info(f"Attempting to load {len(models_data)} models into memory...")
|
89 |
|
90 |
if not models_data:
|
91 |
+
error_msg = "No models data provided. Cannot load models."
|
92 |
logger.error(error_msg)
|
93 |
+
# On ne lève plus d'erreur ici, on logge juste. L'API démarrera sans modèles.
|
94 |
+
return
|
95 |
|
96 |
loaded_count = 0
|
97 |
+
failed_models = []
|
98 |
for model_data in models_data:
|
99 |
+
model_id = model_data.get('model_id', 'N/A') # Assurez-vous que model_id est présent
|
100 |
try:
|
101 |
+
# Utilise la nouvelle fonction pour charger un seul modèle
|
102 |
+
pipeline = await _load_single_model_pipeline(model_data)
|
103 |
+
# Stocke le pipeline chargé dans le dictionnaire global
|
104 |
+
model_pipelines[model_id] = pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
loaded_count += 1
|
106 |
except Exception as e:
|
107 |
+
# Log l'échec mais continue avec les autres modèles
|
108 |
+
logger.error(f"Failed to load model ID {model_id}: {e}")
|
109 |
+
failed_models.append(model_data.get('name', f'ID {model_id}'))
|
110 |
+
|
111 |
+
logger.info(f"Finished loading models. Successfully loaded: {loaded_count}/{len(models_data)}")
|
112 |
+
if failed_models:
|
113 |
+
logger.warning(f"Failed to load the following models: {', '.join(failed_models)}")
|
114 |
+
|
115 |
+
# Pas besoin de retourner les pipelines, ils sont dans le dictionnaire global
|
116 |
+
# return model_pipelines # Ancienne logique
|
117 |
+
|
118 |
|
119 |
def get_model(model_name: str):
|
120 |
"""Récupérer un modèle chargé par son nom.
|