alexfremont commited on
Commit
8a8fe7c
·
1 Parent(s): 8580232

Improve model unloading with explicit GPU memory cleanup and CUDA cache clearing

Browse files
Files changed (1) hide show
  1. api/management.py +28 -2
api/management.py CHANGED
@@ -100,12 +100,38 @@ async def delete_single_model(model_db_id: Any):
100
  # Récupérer le nom du fichier pour le logging avant de supprimer
101
  filename = model_pipelines[model_db_id]['metadata'].get('hf_filename', 'unknown')
102
 
103
- # Supprimer le modèle de la mémoire
 
 
 
 
 
 
 
 
 
 
 
 
104
  del model_pipelines[model_db_id]
105
 
106
- # Force le garbage collector pour libérer la mémoire
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  import gc
108
  gc.collect()
 
109
 
110
  logger.info(f"Successfully unloaded model ID {model_db_id} ({filename}) from memory")
111
  return {"message": f"Model {model_db_id} successfully unloaded from memory"}
 
100
  # Récupérer le nom du fichier pour le logging avant de supprimer
101
  filename = model_pipelines[model_db_id]['metadata'].get('hf_filename', 'unknown')
102
 
103
+ # Récupérer une référence au modèle avant de le supprimer
104
+ model_data = model_pipelines[model_db_id]
105
+ pipeline = model_data.get('pipeline')
106
+
107
+ # Détacher explicitement le modèle du GPU si applicable
108
+ if hasattr(pipeline, 'to') and hasattr(pipeline, 'cpu'):
109
+ try:
110
+ pipeline.to('cpu')
111
+ logger.info(f"Modèle {model_db_id} détaché du GPU")
112
+ except Exception as e:
113
+ logger.warning(f"Impossible de détacher le modèle du GPU: {e}")
114
+
115
+ # Supprimer le modèle du dictionnaire
116
  del model_pipelines[model_db_id]
117
 
118
+ # Supprimer explicitement les références
119
+ del model_data
120
+ del pipeline
121
+
122
+ # Vider le cache PyTorch si disponible
123
+ try:
124
+ import torch
125
+ if torch.cuda.is_available():
126
+ torch.cuda.empty_cache()
127
+ logger.info("Cache CUDA vidé")
128
+ except (ImportError, AttributeError) as e:
129
+ logger.debug(f"Impossible de vider le cache CUDA: {e}")
130
+
131
+ # Force le garbage collector plusieurs fois pour libérer la mémoire
132
  import gc
133
  gc.collect()
134
+ gc.collect() # Parfois un second appel aide à libérer plus de mémoire
135
 
136
  logger.info(f"Successfully unloaded model ID {model_db_id} ({filename}) from memory")
137
  return {"message": f"Model {model_db_id} successfully unloaded from memory"}