inference-api-g1 / main.py
alexfremont's picture
Remove periodic memory status updates and related helper function
1c12662
import logging
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import gradio as gr
from gradio.routes import mount_gradio_app
from api.router import router
from db.models import fetch_models_for_group
from models.loader import load_models, model_pipelines
from config.settings import RESOURCE_GROUP, DATABASE_URL
from utils.system_monitor import get_memory_status, format_memory_status
# Configuration de base des logs
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# --- Événements Startup/Shutdown (Lifespan Manager) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
# Code exécuté au démarrage
logger.info("Starting up API...")
logger.info(f"Attempting to connect to database: {DATABASE_URL[:DATABASE_URL.find('@')] + '@...'}") # Masquer les crédentiels
try:
models_to_load = await fetch_models_for_group(RESOURCE_GROUP)
if models_to_load:
await load_models(models_to_load)
logger.info("Initial models loaded successfully.")
else:
logger.warning(f"No models found for resource group '{RESOURCE_GROUP}'. API starting without preloaded models.")
except Exception as e:
logger.exception(f"Failed to load initial models during startup: {e}")
# Décider s'il faut empêcher le démarrage de l'API ou continuer sans modèles
# raise RuntimeError("Could not load initial models, API startup aborted.") from e
yield
# Code exécuté à l'arrêt
logger.info("Shutting down API...")
# Ajouter ici le code de nettoyage si nécessaire (ex: fermer connexions persistantes)
# Créer l'application FastAPI
app = FastAPI(
title="Tamis AI Inference API",
description="API pour l'inférence des modèles de classification d'objets",
version="0.1.0",
lifespan=lifespan # Correction: Utilisation de la fonction lifespan définie ci-dessus
)
# Ajout du Middleware ici
@app.middleware("http")
async def api_key_middleware(request: Request, call_next):
"""Middleware pour vérifier la clé API et exempter certaines routes."""
# Skip if we're in debug mode or during startup
if os.environ.get("DEBUG") == "1":
logger.debug("DEBUG mode active, skipping API key check.")
return await call_next(request)
# Liste des chemins publics ou internes à exempter de la vérification de clé
public_paths = [
'/docs', '/openapi.json', # Documentation Swagger/OpenAPI
'/health', # Health check endpoint
'/', # Racine (Interface Gradio)
'/assets/', # Assets Gradio
'/file=', # Fichiers Gradio
'/queue/', # Queue Gradio
'/startup-logs', # Logs HF Space
'/config', # Config Gradio/HF
'/info', # Info Gradio/HF
'/gradio', # Potentiel préfixe Gradio
'/favicon.ico' # Favicon
]
# Vérifie si le chemin commence par un des préfixes publics
is_public = any(request.url.path == p or (p.endswith('/') and request.url.path.startswith(p)) for p in public_paths)
if is_public:
logger.debug(f"Public path accessed: {request.url.path}, skipping API key check.")
response = await call_next(request)
return response
else:
# Pour toutes les autres routes, la vérification se fait via Depends() sur l'endpoint lui-même.
# Ce middleware ne fait donc plus de vérification active ici,
# il sert juste à logger et potentiellement à exempter certaines routes si besoin.
logger.debug(f"Protected path accessed: {request.url.path}. API key verification delegated to endpoint.")
response = await call_next(request)
return response
# Inclure les routes
app.include_router(router)
async def init_models():
"""Charger les modèles au démarrage pour Gradio et FastAPI."""
logger.info("Initializing models for Gradio and FastAPI...")
try:
models_data = await fetch_models_for_group(RESOURCE_GROUP)
await load_models(models_data)
logger.info("Models loaded successfully.")
except Exception as e:
logger.error(f"Failed to initialize models: {e}", exc_info=True)
# Decide if the app should fail to start or continue without models
# raise RuntimeError("Model initialization failed.")
# For now, let's log and continue, Gradio will show an empty list
pass
# Définir les fonctions pour Gradio qui récupèrent les modèles chargés
def get_loaded_models_list():
"""Retourne la liste des métadonnées des modèles actuellement chargés."""
# Extraire les métadonnées de chaque entrée dans model_pipelines
return [item['metadata'] for item in model_pipelines.values()]
def format_model_info(metadata_list):
"""Formate les informations des modèles pour un affichage plus convivial."""
if not metadata_list:
return "Aucun modèle chargé actuellement."
# Créer une version formatée des informations des modèles
formatted_info = ""
for model in metadata_list:
# Utiliser le nom du fichier comme titre principal
formatted_info += f"### Modèle : {model.get('hf_filename', 'N/A')}\n"
# Mettre l'ID et la dernière mise à jour sur une même ligne
formatted_info += f"**ID:** {model.get('model_id', 'N/A')} | **Dernière mise à jour:** {model.get('updated_at', 'N/A')}\n\n"
return formatted_info
# Créer l'interface Gradio
gradio_app = gr.Blocks(title="Tamis AI - Modèles Chargés", theme=gr.themes.Soft())
with gradio_app:
gr.Markdown("# 🤖 Tamis AI - Interface d'administration")
gr.Markdown("## Modèles actuellement chargés dans l'API")
with gr.Row():
with gr.Column(scale=2):
# Visualisation des modèles avec des cartes
with gr.Tab("Vue détaillée"):
markdown_output = gr.Markdown(value="Chargement des modèles...", elem_id="model_details")
# Affichage en tableau
with gr.Tab("Vue tableau"):
model_table = gr.Dataframe(
headers=["ID", "Fichier", "Dernière mise à jour"],
datatype=["str", "str", "str"],
elem_id="model_table"
)
# Vue JSON (pour référence)
with gr.Tab("Vue JSON (debug)"):
json_output = gr.JSON(label="Données brutes")
with gr.Column(scale=1):
refresh_btn = gr.Button("Rafraîchir", variant="primary")
status = gr.Textbox(label="Statut", value="Chargement des informations...", interactive=False, lines=8)
# Fonction pour mettre à jour tous les composants d'affichage
def update_all_displays():
models = get_loaded_models_list()
formatted_text = format_model_info(models)
# Préparer les données pour le tableau
table_data = []
for model in models:
table_data.append([
model.get('model_id', 'N/A'),
model.get('hf_filename', 'N/A'),
model.get('updated_at', 'N/A')
])
# Récupérer et formater les informations de mémoire
memory_status = get_memory_status(model_pipelines)
status_text = format_memory_status(memory_status)
return formatted_text, table_data, models, status_text
# Connecter les événements
refresh_btn.click(
fn=update_all_displays,
outputs=[markdown_output, model_table, json_output, status]
)
# Initialiser l'affichage de la mémoire dès le démarrage
memory_status = get_memory_status(model_pipelines)
status.value = format_memory_status(memory_status)
# Initialiser les affichages au chargement
gradio_app.load(fn=update_all_displays, outputs=[markdown_output, model_table, json_output, status])
# Monter l'application Gradio à la racine dans FastAPI
app = mount_gradio_app(
app, gradio_app, path="/"
)
# Event startup to load models (ensure it runs *after* Gradio is mounted if needed)
# We call init_models inside startup
@app.on_event("startup")
async def startup():
"""Initialiser l'API : charger les modèles depuis la base de données."""
await init_models() # Call the consolidated init function
@app.get("/health")
async def health_check():
"""Point d'entrée pour vérifier l'état de l'API."""
return {"status": "healthy"}