Spaces:
Sleeping
Sleeping
import logging | |
import os | |
from contextlib import asynccontextmanager | |
from fastapi import FastAPI, Request | |
import gradio as gr | |
from gradio.routes import mount_gradio_app | |
from api.router import router | |
from db.models import fetch_models_for_group | |
from models.loader import load_models, model_pipelines | |
from config.settings import RESOURCE_GROUP, DATABASE_URL | |
from utils.system_monitor import get_memory_status, format_memory_status | |
# Configuration de base des logs | |
logging.basicConfig( | |
level=logging.INFO, | |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | |
) | |
logger = logging.getLogger(__name__) | |
# --- Événements Startup/Shutdown (Lifespan Manager) --- | |
async def lifespan(app: FastAPI): | |
# Code exécuté au démarrage | |
logger.info("Starting up API...") | |
logger.info(f"Attempting to connect to database: {DATABASE_URL[:DATABASE_URL.find('@')] + '@...'}") # Masquer les crédentiels | |
try: | |
models_to_load = await fetch_models_for_group(RESOURCE_GROUP) | |
if models_to_load: | |
await load_models(models_to_load) | |
logger.info("Initial models loaded successfully.") | |
else: | |
logger.warning(f"No models found for resource group '{RESOURCE_GROUP}'. API starting without preloaded models.") | |
except Exception as e: | |
logger.exception(f"Failed to load initial models during startup: {e}") | |
# Décider s'il faut empêcher le démarrage de l'API ou continuer sans modèles | |
# raise RuntimeError("Could not load initial models, API startup aborted.") from e | |
yield | |
# Code exécuté à l'arrêt | |
logger.info("Shutting down API...") | |
# Ajouter ici le code de nettoyage si nécessaire (ex: fermer connexions persistantes) | |
# Créer l'application FastAPI | |
app = FastAPI( | |
title="Tamis AI Inference API", | |
description="API pour l'inférence des modèles de classification d'objets", | |
version="0.1.0", | |
lifespan=lifespan # Correction: Utilisation de la fonction lifespan définie ci-dessus | |
) | |
# Ajout du Middleware ici | |
async def api_key_middleware(request: Request, call_next): | |
"""Middleware pour vérifier la clé API et exempter certaines routes.""" | |
# Skip if we're in debug mode or during startup | |
if os.environ.get("DEBUG") == "1": | |
logger.debug("DEBUG mode active, skipping API key check.") | |
return await call_next(request) | |
# Liste des chemins publics ou internes à exempter de la vérification de clé | |
public_paths = [ | |
'/docs', '/openapi.json', # Documentation Swagger/OpenAPI | |
'/health', # Health check endpoint | |
'/', # Racine (Interface Gradio) | |
'/assets/', # Assets Gradio | |
'/file=', # Fichiers Gradio | |
'/queue/', # Queue Gradio | |
'/startup-logs', # Logs HF Space | |
'/config', # Config Gradio/HF | |
'/info', # Info Gradio/HF | |
'/gradio', # Potentiel préfixe Gradio | |
'/favicon.ico' # Favicon | |
] | |
# Vérifie si le chemin commence par un des préfixes publics | |
is_public = any(request.url.path == p or (p.endswith('/') and request.url.path.startswith(p)) for p in public_paths) | |
if is_public: | |
logger.debug(f"Public path accessed: {request.url.path}, skipping API key check.") | |
response = await call_next(request) | |
return response | |
else: | |
# Pour toutes les autres routes, la vérification se fait via Depends() sur l'endpoint lui-même. | |
# Ce middleware ne fait donc plus de vérification active ici, | |
# il sert juste à logger et potentiellement à exempter certaines routes si besoin. | |
logger.debug(f"Protected path accessed: {request.url.path}. API key verification delegated to endpoint.") | |
response = await call_next(request) | |
return response | |
# Inclure les routes | |
app.include_router(router) | |
async def init_models(): | |
"""Charger les modèles au démarrage pour Gradio et FastAPI.""" | |
logger.info("Initializing models for Gradio and FastAPI...") | |
try: | |
models_data = await fetch_models_for_group(RESOURCE_GROUP) | |
await load_models(models_data) | |
logger.info("Models loaded successfully.") | |
except Exception as e: | |
logger.error(f"Failed to initialize models: {e}", exc_info=True) | |
# Decide if the app should fail to start or continue without models | |
# raise RuntimeError("Model initialization failed.") | |
# For now, let's log and continue, Gradio will show an empty list | |
pass | |
# Définir les fonctions pour Gradio qui récupèrent les modèles chargés | |
def get_loaded_models_list(): | |
"""Retourne la liste des métadonnées des modèles actuellement chargés.""" | |
# Extraire les métadonnées de chaque entrée dans model_pipelines | |
return [item['metadata'] for item in model_pipelines.values()] | |
def format_model_info(metadata_list): | |
"""Formate les informations des modèles pour un affichage plus convivial.""" | |
if not metadata_list: | |
return "Aucun modèle chargé actuellement." | |
# Créer une version formatée des informations des modèles | |
formatted_info = "" | |
for model in metadata_list: | |
# Utiliser le nom du fichier comme titre principal | |
formatted_info += f"### Modèle : {model.get('hf_filename', 'N/A')}\n" | |
# Mettre l'ID et la dernière mise à jour sur une même ligne | |
formatted_info += f"**ID:** {model.get('model_id', 'N/A')} | **Dernière mise à jour:** {model.get('updated_at', 'N/A')}\n\n" | |
return formatted_info | |
# Créer l'interface Gradio | |
gradio_app = gr.Blocks(title="Tamis AI - Modèles Chargés", theme=gr.themes.Soft()) | |
with gradio_app: | |
gr.Markdown("# 🤖 Tamis AI - Interface d'administration") | |
gr.Markdown("## Modèles actuellement chargés dans l'API") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
# Visualisation des modèles avec des cartes | |
with gr.Tab("Vue détaillée"): | |
markdown_output = gr.Markdown(value="Chargement des modèles...", elem_id="model_details") | |
# Affichage en tableau | |
with gr.Tab("Vue tableau"): | |
model_table = gr.Dataframe( | |
headers=["ID", "Fichier", "Dernière mise à jour"], | |
datatype=["str", "str", "str"], | |
elem_id="model_table" | |
) | |
# Vue JSON (pour référence) | |
with gr.Tab("Vue JSON (debug)"): | |
json_output = gr.JSON(label="Données brutes") | |
with gr.Column(scale=1): | |
refresh_btn = gr.Button("Rafraîchir", variant="primary") | |
status = gr.Textbox(label="Statut", value="Chargement des informations...", interactive=False, lines=8) | |
# Fonction pour mettre à jour tous les composants d'affichage | |
def update_all_displays(): | |
models = get_loaded_models_list() | |
formatted_text = format_model_info(models) | |
# Préparer les données pour le tableau | |
table_data = [] | |
for model in models: | |
table_data.append([ | |
model.get('model_id', 'N/A'), | |
model.get('hf_filename', 'N/A'), | |
model.get('updated_at', 'N/A') | |
]) | |
# Récupérer et formater les informations de mémoire | |
memory_status = get_memory_status(model_pipelines) | |
status_text = format_memory_status(memory_status) | |
return formatted_text, table_data, models, status_text | |
# Connecter les événements | |
refresh_btn.click( | |
fn=update_all_displays, | |
outputs=[markdown_output, model_table, json_output, status] | |
) | |
# Initialiser l'affichage de la mémoire dès le démarrage | |
memory_status = get_memory_status(model_pipelines) | |
status.value = format_memory_status(memory_status) | |
# Initialiser les affichages au chargement | |
gradio_app.load(fn=update_all_displays, outputs=[markdown_output, model_table, json_output, status]) | |
# Monter l'application Gradio à la racine dans FastAPI | |
app = mount_gradio_app( | |
app, gradio_app, path="/" | |
) | |
# Event startup to load models (ensure it runs *after* Gradio is mounted if needed) | |
# We call init_models inside startup | |
async def startup(): | |
"""Initialiser l'API : charger les modèles depuis la base de données.""" | |
await init_models() # Call the consolidated init function | |
async def health_check(): | |
"""Point d'entrée pour vérifier l'état de l'API.""" | |
return {"status": "healthy"} | |