alexfremont committed · Commit e109700 · 1 Parent(s): 6f7e5fa

Refactor API architecture with modular design and database integration

api/__init__.py ADDED
File without changes
api/prediction.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ from fastapi import APIRouter, HTTPException
+ from fastapi.responses import JSONResponse
+ from PIL import Image
+ from io import BytesIO
+ import logging
+ import httpx
+ import asyncio
+ from torchvision import transforms
+
+ from schemas.requests import BatchPredictRequest
+ from models.loader import get_model
+ from steps.preprocess import process_image
+ from config.settings import IMAGE_SIZE, NUM_THREADS
+
+ logger = logging.getLogger(__name__)
+ router = APIRouter()
+
+ @router.post("/batch_predict")
+ async def batch_predict(request: BatchPredictRequest):
+     """Run predictions over a batch of images."""
+     model_name = request.modelName
+
+     try:
+         # Look up the requested model
+         model = get_model(model_name)
+
+         semaphore = asyncio.Semaphore(NUM_THREADS)  # Cap the number of concurrent tasks
+
+         async def process_single_image(image_url):
+             async with semaphore:
+                 try:
+                     async with httpx.AsyncClient() as client:
+                         response = await client.get(image_url)
+                         image = Image.open(BytesIO(response.content))
+                 except Exception:
+                     logger.error(f"Error downloading image from {image_url}")
+                     return {"imageUrl": image_url, "error": "Invalid image URL"}
+
+                 # Preprocess the image
+                 processed_image = process_image(image, size=IMAGE_SIZE)
+
+                 # Convert to a batched tensor
+                 image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)
+
+                 # Inference
+                 with torch.no_grad():
+                     outputs = model(image_tensor)
+                     probabilities = torch.nn.functional.softmax(outputs, dim=1)
+                     predicted_probabilities = probabilities.numpy().tolist()
+                     confidence = round(predicted_probabilities[0][1], 2)
+
+                 return {"imageUrl": image_url, "confidence": confidence}
+
+         # Launch the tasks in parallel
+         tasks = [process_single_image(url) for url in request.imageUrls]
+         results = await asyncio.gather(*tasks)
+
+         return JSONResponse(content={"results": results})
+
+     except KeyError:
+         raise HTTPException(status_code=404, detail="Model not found")
+     except Exception as e:
+         logger.error(f"Batch prediction error: {e}", exc_info=True)
+         raise HTTPException(status_code=500, detail="Batch prediction failed")
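For reference, a client call against the new endpoint might look like the sketch below. The host, port, key, and image URL are placeholder assumptions; the model name reuses one of the filenames formerly hard-coded in main.py.

import httpx

# Hypothetical values: adjust host, port, and x-api-key to your deployment.
response = httpx.post(
    "http://localhost:8000/batch_predict",
    headers={"x-api-key": "my-secret-key"},
    json={
        "modelName": "flamme-L1.pth",
        "imageUrls": ["https://example.com/lamp.jpg"],
    },
)
print(response.json())  # e.g. {"results": [{"imageUrl": "...", "confidence": 0.92}]}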
api/router.py ADDED
@@ -0,0 +1,23 @@
+ from fastapi import APIRouter, Request, HTTPException
+ import logging
+
+ from api import prediction
+ from config.settings import API_KEY
+
+ logger = logging.getLogger(__name__)
+
+ # Main router
+ router = APIRouter()
+
+ # Authentication middleware
+ async def verify_api_key(request: Request, call_next):
+     """Middleware that checks the API key in the request headers."""
+     api_key = request.headers.get("x-api-key")
+     if api_key is None or api_key not in API_KEY.split(','):
+         logger.warning(f"Unauthorized API access attempt from {request.client.host}")
+         raise HTTPException(status_code=403, detail="Unauthorized")
+     response = await call_next(request)
+     return response
+
+ # Include the routes from the other modules
+ router.include_router(prediction.router, tags=["Prediction"])
architecture/__init__.py ADDED
File without changes
config/__init__.py ADDED
File without changes
config/settings.py ADDED
@@ -0,0 +1,31 @@
+ import os
+ import logging
+ import torch
+
+ # Basic logging configuration
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Required environment variables
+ def get_env_or_fail(var_name: str) -> str:
+     """Fetch an environment variable, or fail if it is not set."""
+     value = os.environ.get(var_name)
+     if not value:
+         error_msg = f"{var_name} environment variable is not set or empty. API cannot start."
+         logger.error(error_msg)
+         raise RuntimeError(error_msg)
+     return value
+
+ # API configuration
+ API_KEY = get_env_or_fail("api_key")
+ HF_TOKEN = get_env_or_fail("api_read")
+ RESOURCE_GROUP = get_env_or_fail("RESOURCE_GROUP")
+ DATABASE_URL = get_env_or_fail("DATABASE_URL")
+
+ # Log the important settings (no sensitive values)
+ logger.info(f"RESOURCE_GROUP set to: {RESOURCE_GROUP}")
+
+ # Other constants
+ IMAGE_SIZE = 256
+ DEVICE = torch.device("cpu")  # Switch to "cuda" if a GPU is available
+ NUM_THREADS = 8  # Number of threads for PyTorch
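As a minimal sketch, the environment this settings module expects could be stubbed as below; every value is a placeholder assumption, and importing config.settings with any variable unset raises RuntimeError.

import os

os.environ["api_key"] = "key1,key2"        # comma-separated list of accepted API keys
os.environ["api_read"] = "hf_xxx"          # Hugging Face read token
os.environ["RESOURCE_GROUP"] = "lamp"      # resource group whose models get loaded
os.environ["DATABASE_URL"] = "postgresql://user:pass@host:5432/db"

from config import settings  # validation happens at import time
print(settings.RESOURCE_GROUP)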
db/__init__.py ADDED
File without changes
db/models.py ADDED
@@ -0,0 +1,48 @@
+ import asyncpg
+ import logging
+ from typing import List, Dict, Any
+
+ from config.settings import DATABASE_URL
+
+ logger = logging.getLogger(__name__)
+
+ async def fetch_models_for_group(resource_group: str) -> List[Dict[str, Any]]:
+     """Fetch model details from the database for a given resource group.
+
+     Args:
+         resource_group: Identifier of the resource group
+
+     Returns:
+         List of dictionaries describing the models
+
+     Raises:
+         Exception: If the connection or the query fails
+     """
+     conn = None
+     try:
+         conn = await asyncpg.connect(DATABASE_URL)
+         logger.info("Successfully connected to database")
+
+         # Fetch the models that belong to this group
+         query = """
+             SELECT
+                 model_id,
+                 name,
+                 display_name,
+                 hf_repo_id,
+                 hf_subfolder,
+                 hf_filename
+             FROM models
+             WHERE hf_resource_group = $1
+         """
+         rows = await conn.fetch(query, resource_group)
+         logger.info(f"Found {len(rows)} models for group '{resource_group}'")
+
+         return [dict(row) for row in rows]
+     except Exception as e:
+         logger.error(f"Database error: {e}", exc_info=True)
+         raise
+     finally:
+         if conn and not conn.is_closed():
+             await conn.close()
+             logger.debug("Database connection closed")
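A standalone usage sketch of the new helper, assuming DATABASE_URL points at a database with a models table matching the query above; the group name is a placeholder.

import asyncio
from db.models import fetch_models_for_group

# Print the Hugging Face coordinates of every model in one resource group.
rows = asyncio.run(fetch_models_for_group("lamp"))
for row in rows:
    print(row["hf_repo_id"], row["hf_subfolder"], row["hf_filename"])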
main.py CHANGED
@@ -1,223 +1,43 @@
- import os
- from fastapi import FastAPI, HTTPException
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel
- from transformers import pipeline
- from torchvision import transforms
- from PIL import Image
- import requests
- from io import BytesIO
- from steps.preprocess import process_image
- from huggingface_hub import hf_hub_download
- from architecture.resnet import ResNet
- import torch
  import logging
- from typing import List
- import httpx
- import asyncio
+ from fastapi import FastAPI

- app = FastAPI()
+ from api.router import router, verify_api_key
+ from db.models import fetch_models_for_group
+ from models.loader import load_models
+ from config.settings import RESOURCE_GROUP

- image_size = 256
- hf_token = os.environ.get("api_read")
- VALID_API_KEYS = os.environ.get("api_key")
- INSTANCE_GROUP = os.environ.get("INSTANCE_GROUP")
- if INSTANCE_GROUP:
-     logging.info(f"INSTANCE_GROUP={INSTANCE_GROUP}")
- else:
-     logging.warning("INSTANCE_GROUP not set; all models will be loaded")
-
-
- @app.middleware("http")
- async def verify_api_key(request, call_next):
-     api_key = request.headers.get("x-api-key")
-     if api_key is None or api_key not in VALID_API_KEYS:
-         raise HTTPException(status_code=403, detail="Unauthorized")
-     response = await call_next(request)
-     return response
-
-
- models_locations = [
-     # {
-     #     "repo_id": "TamisAI/category-lamp",
-     #     "subfolder": "maison-jansen/palmtree-152-0005-32-256",
-     #     "filename": "palmtree-jansen.pth",
-     # },
-     {
-         "repo_id": "TamisAI/category-lamp",
-         "subfolder": "maison-charles/corail-152-0001-32-256-L1",
-         "filename": "maison-charles-corail-L1.pth",
-     },
-     {
-         "repo_id": "TamisAI/category-lamp",
-         "subfolder": "michel-armand/flamme-152-0001A-32-256-L1",
-         "filename": "flamme-L1.pth",
-     },
- ]
-
- device = torch.device("cpu")
-
-
- # Data model for prediction requests
- class PredictRequest(BaseModel):
-     imageUrl: str
-     modelName: str
-
-
- torch.set_num_threads(8)
-
- # Dictionary holding the model pipelines
- model_pipelines = {}
-
- # Create a single instance of the ResNet model
- base_model = ResNet("resnet152", num_output_neurons=2).to(device)
+ # Basic logging configuration
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ # Create the FastAPI application
+ app = FastAPI(
+     title="Tamis AI Inference API",
+     description="API for object-classification model inference",
+     version="0.1.0",
+ )
+
+ # Register the authentication middleware
+ app.middleware("http")(verify_api_key)
+
+ # Include the routes
+ app.include_router(router)

  @app.on_event("startup")
- async def load_models():
-     # Load the models at startup
-     print(f"Loading models...{len(models_locations)}")
-
-     for model_location in models_locations:
-         try:
-             print(f"Loading model: {model_location['filename']}")
-             model_weight = hf_hub_download(
-                 repo_id=model_location["repo_id"],
-                 subfolder=model_location["subfolder"],
-                 filename=model_location["filename"],
-                 token=hf_token,
-             )
-             model = base_model.__class__("resnet152", num_output_neurons=2).to(device)
-             model.load_state_dict(
-                 torch.load(model_weight, weights_only=True, map_location=device)
-             )
-             model.eval()
-             model_pipelines[model_location["filename"]] = model
-         except Exception as e:
-             print(f"Error loading model {model_location['filename']}: {e}")
-     print(f"Models loaded. {len(model_pipelines)}")
-
-
- @app.post("/predict")
- async def predict(request: PredictRequest):
-     image_url = request.imageUrl
-     model_name = request.modelName
-
-     # Download the image from the URL
-     try:
-         response = requests.get(image_url)
-         image = Image.open(BytesIO(response.content))
-     except Exception as e:
-         raise HTTPException(status_code=400, detail="Invalid image URL")
-
-     # Check that the model is loaded
-     if model_name not in model_pipelines:
-         raise HTTPException(status_code=404, detail="Model not found")
-
-     # Preprocess the image
-     processed_image = process_image(image, size=image_size)
-
-     # Convert to tensor
-     image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)
-
-     model = model_pipelines[model_name]
-
-     # Perform inference
-     with torch.no_grad():
-         outputs = model(image_tensor)
-         probabilities = torch.nn.functional.softmax(outputs, dim=1)
-         predicted_probabilities = probabilities.numpy().tolist()
-         confidence = round(predicted_probabilities[0][1], 2)
-     logging.info("confidence: %s", confidence)
-     # Return the probabilities as JSON
-     return JSONResponse(content={"confidence": confidence})
-
-
- class BatchPredictRequest(BaseModel):
-     imageUrls: List[str]
-     modelName: str
-
-
- # @app.post("/batch_predict")
- # async def batch_predict(request: BatchPredictRequest):
- #     model_name = request.modelName
- #     results = []
-
- #     # Verify if the model is loaded
- #     if model_name not in model_pipelines:
- #         raise HTTPException(status_code=404, detail="Model not found")
-
- #     model = model_pipelines[model_name]
-
- #     # Asynchronously process each image
- #     async with httpx.AsyncClient() as client:
- #         for image_url in request.imageUrls:
- #             try:
- #                 response = await client.get(image_url)
- #                 image = Image.open(BytesIO(response.content))
- #             except Exception as e:
- #                 results.append({"imageUrl": image_url, "error": "Invalid image URL"})
- #                 continue
-
- #             # Preprocess the image
- #             processed_image = process_image(image, size=image_size)
-
- #             # Convert to tensor
- #             image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)
-
- #             # Perform inference
- #             with torch.no_grad():
- #                 outputs = model(image_tensor)
- #                 probabilities = torch.nn.functional.softmax(outputs, dim=1)
- #                 predicted_probabilities = probabilities.numpy().tolist()
- #                 confidence = round(predicted_probabilities[0][1], 2)
-
- #             results.append({"imageUrl": image_url, "confidence": confidence})
-
- #         # Return the results as JSON
- #         return JSONResponse(content={"results": results})
-
-
- @app.post("/batch_predict")
- async def batch_predict(request: BatchPredictRequest):
-     model_name = request.modelName
-
-     # Verify that the model is loaded
-     if model_name not in model_pipelines:
-         raise HTTPException(status_code=404, detail="Model not found")
-
-     model = model_pipelines[model_name]
-     semaphore = asyncio.Semaphore(
-         8
-     )  # Limit to 8 concurrent tasks to avoid overloading the machine
-
-     async def process_single_image(image_url):
-         async with semaphore:
-             try:
-                 async with httpx.AsyncClient() as client:
-                     response = await client.get(image_url)
-                     image = Image.open(BytesIO(response.content))
-             except Exception:
-                 return {"imageUrl": image_url, "error": "Invalid image URL"}
-
-             # Preprocess the image
-             processed_image = process_image(image, size=image_size)
-
-             # Convert to tensor
-             image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)
-
-             # Perform inference
-             with torch.no_grad():
-                 outputs = model(image_tensor)
-                 probabilities = torch.nn.functional.softmax(outputs, dim=1)
-                 predicted_probabilities = probabilities.numpy().tolist()
-                 confidence = round(predicted_probabilities[0][1], 2)
-
-             return {"imageUrl": image_url, "confidence": confidence}
-
-     # Launch tasks in parallel
-     tasks = [process_single_image(url) for url in request.imageUrls]
-     results = await asyncio.gather(*tasks)
-
-     # Return the results as JSON
-     return JSONResponse(content={"results": results})
+ async def startup():
+     """Initialize the API: load the models from the database."""
+     logger.info("Starting API initialization...")
+
+     # Load the model records from the database
+     models_data = await fetch_models_for_group(RESOURCE_GROUP)
+     await load_models(models_data)
+
+     logger.info("API initialization complete.")
+
+ @app.get("/health")
+ async def health_check():
+     """Endpoint for checking the API status."""
+     return {"status": "healthy"}
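Assuming a standard ASGI server (uvicorn is not visible in the shown portion of requirements.txt, so this is an assumption), the refactored app would be started with something like:

uvicorn main:app --host 0.0.0.0 --port 8000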
models/__init__.py ADDED
File without changes
models/loader.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ import logging
+ from typing import Dict, List, Any
+ from huggingface_hub import hf_hub_download
+
+ from config.settings import DEVICE, HF_TOKEN, NUM_THREADS
+ from architecture.resnet import ResNet
+
+ logger = logging.getLogger(__name__)
+
+ # PyTorch configuration
+ torch.set_num_threads(NUM_THREADS)
+
+ # Base instance of the ResNet model
+ base_model = ResNet("resnet152", num_output_neurons=2).to(DEVICE)
+
+ # Global dictionary holding the loaded models
+ model_pipelines = {}
+
+ async def load_models(models_data: List[Dict[str, Any]]) -> Dict[str, Any]:
+     """Load models from Hugging Face based on the database records.
+
+     Args:
+         models_data: List of dictionaries describing the models
+
+     Returns:
+         Dictionary of the loaded models
+
+     Raises:
+         RuntimeError: If no model is found or none can be loaded
+     """
+     logger.info(f"Attempting to load {len(models_data)} models...")
+
+     if not models_data:
+         error_msg = "No models found. API cannot start without models."
+         logger.error(error_msg)
+         raise RuntimeError(error_msg)
+
+     loaded_count = 0
+     for model_data in models_data:
+         try:
+             model_name = model_data['hf_filename']
+             logger.info(f"Loading model: {model_name} (repo: {model_data['hf_repo_id']}, subfolder: {model_data['hf_subfolder']})")
+
+             model_weight = hf_hub_download(
+                 repo_id=model_data['hf_repo_id'],
+                 subfolder=model_data['hf_subfolder'],
+                 filename=model_name,
+                 token=HF_TOKEN,
+             )
+
+             # Create a fresh instance per model to hold its specific weights
+             model = base_model.__class__("resnet152", num_output_neurons=2).to(DEVICE)
+             model.load_state_dict(
+                 torch.load(model_weight, weights_only=True, map_location=DEVICE)
+             )
+             model.eval()
+             model_pipelines[model_name] = model
+             loaded_count += 1
+         except Exception as e:
+             logger.error(f"Error loading model {model_data.get('hf_filename', 'N/A')}: {e}", exc_info=True)
+
+     logger.info(f"Model loading finished. Successfully loaded {loaded_count}/{len(models_data)} models.")
+
+     if loaded_count == 0:
+         error_msg = "Failed to load any models. API cannot start without models."
+         logger.error(error_msg)
+         raise RuntimeError(error_msg)
+
+     return model_pipelines
+
+ def get_model(model_name: str):
+     """Retrieve a loaded model by name.
+
+     Args:
+         model_name: Name of the model to retrieve
+
+     Returns:
+         The loaded model
+
+     Raises:
+         KeyError: If the model is not found
+     """
+     if model_name not in model_pipelines:
+         logger.error(f"Model {model_name} not found in loaded models")
+         raise KeyError(f"Model {model_name} not found")
+
+     return model_pipelines[model_name]
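For illustration, the loader can be driven by hand with a record shaped like the database rows; the repo, subfolder, and filename below come from the models formerly hard-coded in main.py, and the environment variables from config/settings.py are assumed to be set.

import asyncio
from models.loader import load_models, get_model

record = {
    "hf_repo_id": "TamisAI/category-lamp",
    "hf_subfolder": "maison-charles/corail-152-0001-32-256-L1",
    "hf_filename": "maison-charles-corail-L1.pth",
}
asyncio.run(load_models([record]))  # downloads the weights and caches the model
model = get_model("maison-charles-corail-L1.pth")  # raises KeyError if absent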
requirements.txt CHANGED
@@ -7,4 +7,5 @@ torchvision
  huggingface_hub
  torch
  numpy
- httpx
+ httpx
+ asyncpg
schemas/__init__.py ADDED
File without changes
schemas/requests.py ADDED
@@ -0,0 +1,7 @@
+ from pydantic import BaseModel
+ from typing import List
+
+ class BatchPredictRequest(BaseModel):
+     """Request model for predictions over multiple images."""
+     imageUrls: List[str]
+     modelName: str
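An example payload the schema accepts; the URLs are placeholders and the model name reuses a filename from the old main.py.

from schemas.requests import BatchPredictRequest

req = BatchPredictRequest(
    imageUrls=["https://example.com/a.jpg", "https://example.com/b.jpg"],
    modelName="flamme-L1.pth",
)
print(req.modelName)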
steps/__init__.py ADDED
File without changes
steps/preprocess.py CHANGED
@@ -1,4 +1,3 @@
- import os
  from PIL import Image
  import numpy as np
utils/__init__.py ADDED
File without changes