lamhieu committed
Commit 716eebd · 1 Parent(s): 5264f8d

perf: improve request handling with async mode

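In summary: the router's handlers now await the embeddings service directly and defer analytics recording to Starlette's `BackgroundTasks`, while the service replaces blocking `requests` calls with a shared `httpx.AsyncClient` and offloads blocking encoding and PIL I/O to threads via `asyncio.to_thread`. A minimal client-side sketch of what this buys, assuming the router is mounted under `/v1` and served at `localhost:7860` (both deployment assumptions, not part of this commit):

```python
# Hypothetical client sketch (not part of this commit). With async handlers,
# concurrent requests can overlap on the server instead of queuing behind
# blocking work.
import asyncio

import httpx  # httpx is added to requirements.txt by this commit


async def embed(client: httpx.AsyncClient, text: str) -> dict:
    # One POST to the embeddings endpoint; the model ID comes from the
    # options listed in EmbeddingRequest below.
    resp = await client.post(
        "http://localhost:7860/v1/embeddings",  # assumed host/port
        json={"model": "multilingual-e5-small", "input": text},
    )
    resp.raise_for_status()
    return resp.json()


async def main() -> None:
    async with httpx.AsyncClient(timeout=30) as client:
        # Fire eight requests concurrently and await them all.
        results = await asyncio.gather(
            *(embed(client, f"sentence {i}") for i in range(8))
        )
        print(f"received {len(results)} responses")


if __name__ == "__main__":
    asyncio.run(main())
```

Because the handlers no longer block the event loop, all eight requests can be in flight at once on a single worker.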
lightweight_embeddings/router.py CHANGED
@@ -2,8 +2,8 @@ from __future__ import annotations
 
 import logging
 import os
-from typing import Dict, List, Union
 from datetime import datetime
+from typing import Dict, List, Union
 
 from fastapi import APIRouter, BackgroundTasks, HTTPException
 from pydantic import BaseModel, Field
@@ -27,45 +27,45 @@ router = APIRouter(
 
 class EmbeddingRequest(BaseModel):
     """
-    Input to /v1/embeddings
+    Request model for generating embeddings.
     """
 
     model: str = Field(
         default=TextModelType.MULTILINGUAL_E5_SMALL.value,
         description=(
             "Which model ID to use? "
-            "Text: ['multilingual-e5-small', 'multilingual-e5-base', 'multilingual-e5-large', 'snowflake-arctic-embed-l-v2.0', 'paraphrase-multilingual-MiniLM-L12-v2', 'paraphrase-multilingual-mpnet-base-v2', 'bge-m3']. "
-            "Image: ['siglip-base-patch16-256-multilingual']."
+            "Text options: ['multilingual-e5-small', 'multilingual-e5-base', 'multilingual-e5-large', "
+            "'snowflake-arctic-embed-l-v2.0', 'paraphrase-multilingual-MiniLM-L12-v2', "
+            "'paraphrase-multilingual-mpnet-base-v2', 'bge-m3']. "
+            "Image option: ['siglip-base-patch16-256-multilingual']."
         ),
     )
     input: Union[str, List[str]] = Field(
-        ..., description="Text(s) or Image URL(s)/path(s)."
+        ..., description="Text(s) or image URL(s)/path(s)."
    )
 
 
 class RankRequest(BaseModel):
     """
-    Input to /v1/rank
+    Request model for ranking candidates.
     """
 
     model: str = Field(
         default=TextModelType.MULTILINGUAL_E5_SMALL.value,
         description=(
             "Model ID for the queries. "
-            "Text or Image model, e.g. 'siglip-base-patch16-256-multilingual' for images."
+            "Can be a text or image model (e.g. 'siglip-base-patch16-256-multilingual' for images)."
         ),
     )
     queries: Union[str, List[str]] = Field(
-        ..., description="Query text or image(s) depending on the model type."
-    )
-    candidates: List[str] = Field(
-        ..., description="Candidate texts to rank. Must be text."
+        ..., description="Query text(s) or image(s) depending on the model type."
     )
+    candidates: List[str] = Field(..., description="Candidate texts to rank.")
 
 
 class EmbeddingResponse(BaseModel):
     """
-    Response of /v1/embeddings
+    Response model for embeddings.
     """
 
     object: str
@@ -76,7 +76,7 @@ class EmbeddingResponse(BaseModel):
 
 class RankResponse(BaseModel):
     """
-    Response of /v1/rank
+    Response model for ranking results.
     """
 
     probabilities: List[List[float]]
@@ -84,7 +84,9 @@ class RankResponse(BaseModel):
 
 
 class StatsBucket(BaseModel):
-    """Helper model for daily/weekly/monthly/yearly stats"""
+    """
+    Model for daily/weekly/monthly/yearly stats.
+    """
 
     total: Dict[str, int]
     daily: Dict[str, int]
@@ -94,12 +96,15 @@
 
 
 class StatsResponse(BaseModel):
-    """Analytics stats response model, including both access and token counts"""
+    """
+    Analytics stats response model, including both access and token counts.
+    """
 
     access: StatsBucket
     tokens: StatsBucket
 
 
+# Initialize the embeddings service and analytics.
 service_config = ModelConfig()
 embeddings_service = EmbeddingsService(config=service_config)
 
@@ -115,16 +120,16 @@ async def create_embeddings(
     request: EmbeddingRequest, background_tasks: BackgroundTasks
 ):
     """
-    Generates embeddings for the given input (text or image).
+    Generate embeddings for the given text or image inputs.
     """
     try:
         modality = detect_model_kind(request.model)
         embeddings = await embeddings_service.generate_embeddings(
-            inputs=request.input,
             model=request.model,
+            inputs=request.input,
         )
 
-        # Estimate tokens for text only
+        # Estimate tokens if using a text model.
         total_tokens = 0
         if modality == ModelKind.TEXT:
            total_tokens = embeddings_service.estimate_tokens(request.input)
@@ -148,6 +153,7 @@ async def create_embeddings(
             }
         )
 
+        # Record analytics in the background.
         background_tasks.add_task(
             analytics.access, request.model, resp["usage"]["total_tokens"]
         )
@@ -166,7 +172,7 @@ async def create_embeddings(
 @router.post("/rank", response_model=RankResponse, tags=["rank"])
 async def rank_candidates(request: RankRequest, background_tasks: BackgroundTasks):
     """
-    Ranks candidate texts against the given queries (which can be text or image).
+    Rank candidate texts against the given queries.
     """
     try:
         results = await embeddings_service.rank(
@@ -175,6 +181,7 @@ async def rank_candidates(request: RankRequest, background_tasks: BackgroundTask
             candidates=request.candidates,
         )
 
+        # Record analytics in the background.
         background_tasks.add_task(
             analytics.access, request.model, results["usage"]["total_tokens"]
         )
@@ -192,14 +199,18 @@ async def rank_candidates(request: RankRequest, background_tasks: BackgroundTask
 
 @router.get("/stats", response_model=StatsResponse, tags=["stats"])
 async def get_stats():
-    """Get usage statistics for all models, including access and tokens."""
+    """
+    Retrieve usage statistics for all models, including access counts and token usage.
+    """
     try:
         day_key = datetime.utcnow().strftime("%Y-%m-%d")
         week_key = f"{datetime.utcnow().year}-W{datetime.utcnow().strftime('%U')}"
         month_key = datetime.utcnow().strftime("%Y-%m")
         year_key = datetime.utcnow().strftime("%Y")
 
-        stats_data = await analytics.stats()  # { "access": {...}, "tokens": {...} }
+        stats_data = (
+            await analytics.stats()
+        )  # Expected to return a dict with 'access' and 'tokens' keys
 
         return {
             "access": {
lightweight_embeddings/service.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import logging
 from enum import Enum
 from typing import List, Union, Dict, Optional, NamedTuple, Any
@@ -9,7 +10,7 @@ from io import BytesIO
 from hashlib import md5
 from cachetools import LRUCache
 
-import requests
+import httpx
 import numpy as np
 import torch
 from PIL import Image
@@ -45,9 +46,7 @@ class ImageModelType(str, Enum):
 
 class ModelInfo(NamedTuple):
     """
-    This container maps an enum to:
-      - model_id: Hugging Face model ID (or local path)
-      - onnx_file: Path to ONNX file (if available)
+    Container mapping a model type to its model identifier and optional ONNX file.
     """
 
     model_id: str
@@ -69,7 +68,7 @@ class ModelConfig:
     @property
     def text_model_info(self) -> ModelInfo:
         """
-        Returns ModelInfo for the configured text_model_type.
+        Return model information for the configured text model.
         """
         text_configs = {
             TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
@@ -110,7 +109,7 @@ class ModelConfig:
     @property
     def image_model_info(self) -> ModelInfo:
         """
-        Returns ModelInfo for the configured image_model_type.
+        Return model information for the configured image model.
         """
         image_configs = {
             ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
@@ -121,14 +120,20 @@
 
 
 class ModelKind(str, Enum):
+    """
+    Indicates the type of model: text or image.
+    """
+
     TEXT = "text"
     IMAGE = "image"
 
 
 def detect_model_kind(model_id: str) -> ModelKind:
     """
-    Detect whether model_id belongs to a text or an image model.
-    Raises ValueError if the model is not recognized.
+    Detect whether the model identifier corresponds to a text or image model.
+
+    Raises:
+        ValueError: If the model identifier is unrecognized.
     """
     if model_id in [m.value for m in TextModelType]:
         return ModelKind.TEXT
@@ -145,32 +150,38 @@ def detect_model_kind(model_id: str) -> ModelKind:
 class EmbeddingsService:
     """
     Service for generating text/image embeddings and performing similarity ranking.
-    Batch size has been removed. Single or multiple inputs are handled uniformly.
+    Asynchronous methods are used to maximize throughput and avoid blocking the event loop.
     """
 
     def __init__(self, config: Optional[ModelConfig] = None):
+        """
+        Initialize the service by setting up model caches, device configuration,
+        and asynchronous HTTP client.
+        """
         self.lru_cache = LRUCache(maxsize=10_000)
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.config = config or ModelConfig()
 
-        # Dictionaries to hold preloaded models
+        # Dictionaries to hold preloaded models.
         self.text_models: Dict[TextModelType, SentenceTransformer] = {}
         self.image_models: Dict[ImageModelType, AutoModel] = {}
         self.image_processors: Dict[ImageModelType, AutoProcessor] = {}
 
-        # Load all relevant models on init
+        # Create a persistent asynchronous HTTP client.
+        self.async_http_client = httpx.AsyncClient(timeout=10)
+
+        # Preload all models.
         self._load_all_models()
 
     def _load_all_models(self) -> None:
         """
-        Pre-load all known text and image models for quick switching.
+        Pre-load all text and image models to minimize latency at request time.
        """
        try:
-            # Preload text models
+            # Preload text models.
             for t_model_type in TextModelType:
                 info = ModelConfig(text_model_type=t_model_type).text_model_info
                 logger.info("Loading text model: %s", info.model_id)
-
                 if info.onnx_file:
                     logger.info("Using ONNX file: %s", info.onnx_file)
                     self.text_models[t_model_type] = SentenceTransformer(
@@ -190,16 +201,15 @@ class EmbeddingsService:
                         trust_remote_code=True,
                     )
 
-            # Preload image models
+            # Preload image models.
             for i_model_type in ImageModelType:
                 model_id = ModelConfig(
                     image_model_type=i_model_type
                 ).image_model_info.model_id
                 logger.info("Loading image model: %s", model_id)
-
                 model = AutoModel.from_pretrained(model_id).to(self.device)
+                model.eval()  # Set the model to evaluation mode.
                 processor = AutoProcessor.from_pretrained(model_id)
-
                 self.image_models[i_model_type] = model
                 self.image_processors[i_model_type] = processor
 
@@ -212,8 +222,10 @@ class EmbeddingsService:
     @staticmethod
     def _validate_text_list(input_text: Union[str, List[str]]) -> List[str]:
         """
-        Convert text input into a non-empty list of strings.
-        Raises ValueError if the input is invalid.
+        Validate and convert text input into a non-empty list of strings.
+
+        Raises:
+            ValueError: If the input is invalid.
         """
         if isinstance(input_text, str):
             if not input_text.strip():
@@ -233,8 +245,10 @@ class EmbeddingsService:
     @staticmethod
     def _validate_image_list(input_images: Union[str, List[str]]) -> List[str]:
         """
-        Convert image input into a non-empty list of image paths/URLs.
-        Raises ValueError if the input is invalid.
+        Validate and convert image input into a non-empty list of image paths/URLs.
+
+        Raises:
+            ValueError: If the input is invalid.
         """
         if isinstance(input_images, str):
             if not input_images.strip():
@@ -251,24 +265,51 @@ class EmbeddingsService:
 
         return input_images
 
-    def _process_image(self, path_or_url: str) -> Dict[str, torch.Tensor]:
+    async def _fetch_image(self, path_or_url: str) -> Image.Image:
         """
-        Loads and processes a single image from local path or URL.
-        Returns a dictionary of tensors ready for the model.
+        Asynchronously fetch an image from a URL or load from a local path.
+
+        Args:
+            path_or_url: The URL or file path of the image.
+
+        Returns:
+            A PIL Image in RGB mode.
+
+        Raises:
+            ValueError: If image fetching or processing fails.
         """
         try:
             if path_or_url.startswith("http"):
-                resp = requests.get(path_or_url, timeout=10)
-                resp.raise_for_status()
-                img = Image.open(BytesIO(resp.content)).convert("RGB")
+                # Asynchronously fetch the image bytes.
+                response = await self.async_http_client.get(path_or_url)
+                response.raise_for_status()
+                # Offload the blocking I/O (PIL image opening) to a thread.
+                img = await asyncio.to_thread(Image.open, BytesIO(response.content))
             else:
-                img = Image.open(Path(path_or_url)).convert("RGB")
-
-            processor = self.image_processors[self.config.image_model_type]
-            processed_data = processor(images=img, return_tensors="pt").to(self.device)
-            return processed_data
+                # Offload file I/O to a thread.
+                img = await asyncio.to_thread(Image.open, Path(path_or_url))
+            return img.convert("RGB")
         except Exception as e:
-            raise ValueError(f"Error processing image '{path_or_url}': {str(e)}") from e
+            raise ValueError(f"Error fetching image '{path_or_url}': {str(e)}") from e
+
+    async def _process_image(self, path_or_url: str) -> Dict[str, torch.Tensor]:
+        """
+        Asynchronously load and process a single image.
+
+        Args:
+            path_or_url: The image URL or local path.
+
+        Returns:
+            A dictionary of processed tensors ready for model input.
+
+        Raises:
+            ValueError: If image processing fails.
+        """
+        img = await self._fetch_image(path_or_url)
+        processor = self.image_processors[self.config.image_model_type]
+        # Note: Processor may perform CPU-intensive work; if needed, offload to thread.
+        processed_data = processor(images=img, return_tensors="pt").to(self.device)
+        return processed_data
 
     def _generate_text_embeddings(
         self,
@@ -276,8 +317,14 @@ class EmbeddingsService:
         texts: List[str],
     ) -> np.ndarray:
         """
-        Generates text embeddings using the SentenceTransformer-based model.
-        Utilizes an LRU cache for single-input scenarios.
+        Generate text embeddings using the SentenceTransformer model.
+        Single-text requests are cached using an LRU cache.
+
+        Returns:
+            A NumPy array of text embeddings.
+
+        Raises:
+            RuntimeError: If text embedding generation fails.
         """
         try:
             if len(texts) == 1:
@@ -285,48 +332,54 @@ class EmbeddingsService:
                 key = md5(f"{model_id}:{single_text}".encode("utf-8")).hexdigest()[:8]
                 if key in self.lru_cache:
                     return self.lru_cache[key]
-
                 model = self.text_models[model_id]
                 emb = model.encode([single_text])
                 self.lru_cache[key] = emb
                 return emb
 
-            # For multiple texts, no LRU cache is used
             model = self.text_models[model_id]
             return model.encode(texts)
-
         except Exception as e:
             raise RuntimeError(
                 f"Error generating text embeddings with model '{model_id}': {e}"
             ) from e
 
-    def _generate_image_embeddings(
+    async def _async_generate_image_embeddings(
         self,
         model_id: ImageModelType,
         images: List[str],
     ) -> np.ndarray:
         """
-        Generates image embeddings using the CLIP-like transformer model.
-        Handles single or multiple images uniformly (no batch size parameter).
+        Asynchronously generate image embeddings.
+
+        This method concurrently processes multiple images and offloads
+        the blocking model inference to a separate thread.
+
+        Returns:
+            A NumPy array of image embeddings.
+
+        Raises:
+            RuntimeError: If image embedding generation fails.
         """
         try:
-            model = self.image_models[model_id]
-            # Collect processed inputs in a single batch
-            processed_tensors = []
-            for img_path in images:
-                processed_tensors.append(self._process_image(img_path))
-
-            # Keys should be the same for all processed outputs
+            # Concurrently process all images.
+            processed_tensors = await asyncio.gather(
+                *[self._process_image(img_path) for img_path in images]
+            )
+            # Assume all processed outputs have the same keys.
             keys = processed_tensors[0].keys()
-            # Concatenate along the batch dimension
             combined = {
                 k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys
            }
 
-            with torch.no_grad():
-                embeddings = model.get_image_features(**combined)
-            return embeddings.cpu().numpy()
+            def infer():
+                with torch.no_grad():
+                    embeddings = self.image_models[model_id].get_image_features(
+                        **combined
+                    )
+                return embeddings.cpu().numpy()
 
+            return await asyncio.to_thread(infer)
         except Exception as e:
             raise RuntimeError(
                 f"Error generating image embeddings with model '{model_id}': {e}"
@@ -338,19 +391,28 @@ class EmbeddingsService:
         inputs: Union[str, List[str]],
     ) -> np.ndarray:
         """
-        Asynchronously generates embeddings for either text or image based on the model type.
+        Asynchronously generate embeddings for text or image inputs based on model type.
+
+        Args:
+            model: The model identifier.
+            inputs: The text or image input(s).
+
+        Returns:
+            A NumPy array of embeddings.
         """
         modality = detect_model_kind(model)
-
         if modality == ModelKind.TEXT:
             text_model_id = TextModelType(model)
             text_list = self._validate_text_list(inputs)
-            return self._generate_text_embeddings(text_model_id, text_list)
-
+            return await asyncio.to_thread(
+                self._generate_text_embeddings, text_model_id, text_list
+            )
         elif modality == ModelKind.IMAGE:
             image_model_id = ImageModelType(model)
             image_list = self._validate_image_list(inputs)
-            return self._generate_image_embeddings(image_model_id, image_list)
+            return await self._async_generate_image_embeddings(
+                image_model_id, image_list
+            )
 
     async def rank(
         self,
@@ -359,35 +421,32 @@ class EmbeddingsService:
         candidates: Union[str, List[str]],
     ) -> Dict[str, Any]:
         """
-        Ranks text `candidates` given `queries`, which can be text or images.
-        Always returns a dictionary of { probabilities, cosine_similarities, usage }.
+        Asynchronously rank candidate texts/images against the provided queries.
+        Embeddings for queries and candidates are generated concurrently.
 
-        Note: This implementation uses the same model for both queries and candidates.
-        For true cross-modal ranking, you might need separate models or a shared model.
+        Returns:
+            A dictionary containing probabilities, cosine similarities, and usage statistics.
         """
         modality = detect_model_kind(model)
-
-        # Convert the string model to the appropriate enum
         if modality == ModelKind.TEXT:
             model_enum = TextModelType(model)
         else:
             model_enum = ImageModelType(model)
 
-        # 1) Generate embeddings for queries
-        query_embeds = await self.generate_embeddings(model_enum.value, queries)
-
-        # 2) Generate embeddings for candidates (assumed text if queries are text;
-        #    or if queries are images, also use the image model for candidates).
-        candidate_embeds = await self.generate_embeddings(model_enum.value, candidates)
+        # Concurrently generate embeddings.
+        query_task = asyncio.create_task(self.generate_embeddings(model, queries))
+        candidate_task = asyncio.create_task(
+            self.generate_embeddings(model, candidates)
+        )
+        query_embeds, candidate_embeds = await asyncio.gather(
+            query_task, candidate_task
+        )
 
-        # 3) Compute cosine similarity
+        # Compute cosine similarity.
         sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
-
-        # 4) Apply logit scale + softmax to obtain probabilities
         scaled = np.exp(self.config.logit_scale) * sim_matrix
         probs = self.softmax(scaled)
 
-        # 5) Estimate token usage if we're dealing with text
         if modality == ModelKind.TEXT:
             query_tokens = self.estimate_tokens(queries)
             candidate_tokens = self.estimate_tokens(candidates)
@@ -408,32 +467,41 @@ class EmbeddingsService:
 
     def estimate_tokens(self, input_data: Union[str, List[str]]) -> int:
         """
-        Estimates token count using the SentenceTransformer tokenizer.
-        Only applicable if the current configured model is a text model.
+        Estimate the token count for the given text input using the SentenceTransformer tokenizer.
+
+        Returns:
+            The total number of tokens.
        """
        texts = self._validate_text_list(input_data)
        model = self.text_models[self.config.text_model_type]
        tokenized = model.tokenize(texts)
-        # Summing over the lengths of input_ids for each example
        return sum(len(ids) for ids in tokenized["input_ids"])
 
     @staticmethod
     def softmax(scores: np.ndarray) -> np.ndarray:
         """
-        Applies the standard softmax function along the last dimension.
+        Compute the softmax over the last dimension of the input array.
+
+        Returns:
+            The softmax probabilities.
         """
-        # Stabilize scores by subtracting max
         exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
         return exps / np.sum(exps, axis=-1, keepdims=True)
 
     @staticmethod
     def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
         """
-        Computes the pairwise cosine similarity between all rows of a and b.
-        a: (N, D)
-        b: (M, D)
-        Return: (N, M) matrix of cosine similarities
+        Compute the pairwise cosine similarity between all rows of arrays a and b.
+
+        Returns:
+            A (N x M) matrix of cosine similarities.
         """
         a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
         b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
         return np.dot(a_norm, b_norm.T)
+
+    async def close(self) -> None:
+        """
+        Close the asynchronous HTTP client.
+        """
+        await self.async_http_client.aclose()
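Note on the service changes: the constructor now owns a persistent `httpx.AsyncClient`, and a `close()` coroutine is added to release it, but no caller of `close()` appears in this diff. One option is a FastAPI lifespan hook; the sketch below reuses the `router` and `embeddings_service` names defined in router.py, while the application object and the `/v1` prefix are assumptions:

```python
# Hypothetical wiring (not shown in this commit): something must call
# EmbeddingsService.close(), or the shared httpx.AsyncClient is never released.
from contextlib import asynccontextmanager

from fastapi import FastAPI

# router.py (above) defines both `router` and the `embeddings_service` singleton.
from lightweight_embeddings.router import embeddings_service, router


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield  # the application serves requests while suspended here
    await embeddings_service.close()  # release the AsyncClient on shutdown


app = FastAPI(lifespan=lifespan)
app.include_router(router, prefix="/v1")  # assumed prefix
```

A design note: because `generate_embeddings` now offloads text encoding with `asyncio.to_thread`, the two tasks created in `rank()` can make progress in parallel threads instead of serializing on the event loop.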
requirements.txt CHANGED
@@ -5,6 +5,7 @@ requests
 pydantic
 cachetools
 pandas
+httpx
 sentence-transformers[onnx]==3.3.1
 sentencepiece==0.2.0
 torch==2.4.0
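One loose end: the hunk context (`@@ -5,6 +5,7 @@ requests`) shows that `requests` is still listed near the top of requirements.txt even though service.py no longer imports it. Whether it can be dropped depends on the modules not touched by this commit; a quick hypothetical check:

```python
# Hypothetical helper (not part of the commit): list modules that still import
# `requests`, to decide whether the old dependency can be removed as well.
import pathlib

for path in pathlib.Path("lightweight_embeddings").rglob("*.py"):
    if "import requests" in path.read_text(encoding="utf-8"):
        print(path)
```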