nan committed on
Commit
2dc4127
·
1 Parent(s): a9d6eec

feat: return a single tensor when a single image is given

Browse files
Files changed (1) hide show
  1. modeling_jina_embeddings_v4.py +7 -7
modeling_jina_embeddings_v4.py CHANGED
@@ -417,7 +417,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
417
  return_numpy: bool = False,
418
  truncate_dim: Optional[int] = None,
419
  prompt_name: Optional[str] = None,
420
- ) -> List[torch.Tensor]:
421
  """
422
  Encodes a list of texts into embeddings.
423
 
@@ -431,7 +431,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
431
  prompt_name: Type of text being encoded ('query' or 'passage')
432
 
433
  Returns:
434
- List of text embeddings as tensors or numpy arrays
435
  """
436
  prompt_name = prompt_name or "query"
437
  encode_kwargs = self._validate_encoding_params(
@@ -459,7 +459,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
459
  **encode_kwargs,
460
  )
461
 
462
- return embeddings
463
 
464
  def _load_images_if_needed(
465
  self, images: List[Union[str, Image.Image]]
@@ -484,9 +484,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
484
  return_numpy: bool = False,
485
  truncate_dim: Optional[int] = None,
486
  max_pixels: Optional[int] = None,
487
- ) -> List[torch.Tensor]:
488
  """
489
- Encodes a list of images into embeddings.
490
 
491
  Args:
492
  images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
@@ -497,7 +497,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
497
  max_pixels: Maximum number of pixels to process per image
498
 
499
  Returns:
500
- List of image embeddings as tensors or numpy arrays
501
  """
502
  if max_pixels:
503
  default_max_pixels = self.processor.image_processor.max_pixels
@@ -525,7 +525,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
525
  if max_pixels:
526
  self.processor.image_processor.max_pixels = default_max_pixels
527
 
528
- return embeddings
529
 
530
  @classmethod
531
  def from_pretrained(
 
417
  return_numpy: bool = False,
418
  truncate_dim: Optional[int] = None,
419
  prompt_name: Optional[str] = None,
420
+ ) -> Union[List[torch.Tensor], torch.Tensor]:
421
  """
422
  Encodes a list of texts into embeddings.
423
 
 
431
  prompt_name: Type of text being encoded ('query' or 'passage')
432
 
433
  Returns:
434
+ List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
435
  """
436
  prompt_name = prompt_name or "query"
437
  encode_kwargs = self._validate_encoding_params(
 
459
  **encode_kwargs,
460
  )
461
 
462
+ return embeddings if len(texts) > 1 else embeddings[0]
463
 
464
  def _load_images_if_needed(
465
  self, images: List[Union[str, Image.Image]]
 
484
  return_numpy: bool = False,
485
  truncate_dim: Optional[int] = None,
486
  max_pixels: Optional[int] = None,
487
+ ) -> Union[List[torch.Tensor], torch.Tensor]:
488
  """
489
+ Encodes a list of images or a single image into embedding(s).
490
 
491
  Args:
492
  images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
 
497
  max_pixels: Maximum number of pixels to process per image
498
 
499
  Returns:
500
+ List of image embeddings as tensors or numpy arrays when encoding multiple images, or single image embedding as tensor when encoding a single image
501
  """
502
  if max_pixels:
503
  default_max_pixels = self.processor.image_processor.max_pixels
 
525
  if max_pixels:
526
  self.processor.image_processor.max_pixels = default_max_pixels
527
 
528
+ return embeddings if len(images) > 1 else embeddings[0]
529
 
530
  @classmethod
531
  def from_pretrained(