Commit 1036c04 · verified · 1 Parent(s): 455d3b0

feat-unify-encode-function-0622 (#19)


- feat: make encode_texts and encode_images support single inputs (8e178a167d8f5ea41b0e6ab2a74775363abe32ae)
- feat: update the function names (a9d6eecadd0c92a8b60b12724a5e4387d16ee327)
- feat: return a single tensor when a single image is given (2dc412733ac7e28f74080fb48d0c29062156a429)
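
The net effect, as a minimal usage sketch: `encode_text` and `encode_image` now accept either a single item or a list, and return a single tensor when given a single input. The repo id and task name below are illustrative assumptions, not part of this commit; only the function names and the single-versus-list behavior come from the diff that follows.

```python
# Usage sketch of the unified encode API after this commit.
# Assumptions (not from the diff): the repo id and the task name.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",  # assumed repo id
    trust_remote_code=True,
)

# A list of texts still returns a list of embeddings, one per text.
batch_embeddings = model.encode_text(
    texts=["A blue cat", "A red cat"],
    task="retrieval",  # assumed task name
)

# A bare string now returns a single tensor instead of a one-element list.
single_embedding = model.encode_text(texts="A blue cat", task="retrieval")
```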

Files changed (1): modeling_jina_embeddings_v4.py (+21 −13)
modeling_jina_embeddings_v4.py CHANGED

@@ -407,9 +407,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         return task
 
-    def encode_texts(
+    def encode_text(
         self,
-        texts: List[str],
+        texts: Union[str, List[str]],
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
@@ -417,12 +417,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
         """
         Encodes a list of texts into embeddings.
 
         Args:
-            texts: List of text strings to encode
+            texts: a single text or a list of texts to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
             vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
@@ -431,7 +431,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             prompt_name: Type of text being encoded ('query' or 'passage')
 
         Returns:
-            List of text embeddings as tensors or numpy arrays
+            List of text embeddings as tensors or numpy arrays when encoding multiple texts, or a single embedding tensor when encoding a single text
         """
         prompt_name = prompt_name or "query"
         encode_kwargs = self._validate_encoding_params(
@@ -446,6 +446,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             prefix=encode_kwargs.pop("prefix"),
         )
 
+        if isinstance(texts, str):
+            texts = [texts]
+
         embeddings = self._process_batches(
             data=texts,
             processor_fn=processor_fn,
@@ -456,7 +459,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             **encode_kwargs,
         )
 
-        return embeddings
+        return embeddings if len(texts) > 1 else embeddings[0]
 
     def _load_images_if_needed(
         self, images: List[Union[str, Image.Image]]
@@ -472,21 +475,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             loaded_images.append(image)
         return loaded_images
 
-    def encode_images(
+    def encode_image(
         self,
-        images: List[Union[str, Image.Image]],
+        images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
         vector_type: Optional[str] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
         """
-        Encodes a list of images into embeddings.
+        Encodes a list of images or a single image into embedding(s).
 
         Args:
-            images: List of PIL images, URLs, or local file paths to encode
+            images: image(s) to encode, given as PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
             vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
             return_numpy: Whether to return numpy arrays instead of torch tensors
@@ -494,7 +497,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             max_pixels: Maximum number of pixels to process per image
 
         Returns:
-            List of image embeddings as tensors or numpy arrays
+            List of image embeddings as tensors or numpy arrays when encoding multiple images, or a single embedding tensor when encoding a single image
         """
         if max_pixels:
             default_max_pixels = self.processor.image_processor.max_pixels
@@ -503,6 +506,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             )
         encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
         task = self._validate_task(task)
+
+        # Convert a single image to a list
+        if isinstance(images, (str, Image.Image)):
+            images = [images]
+
         images = self._load_images_if_needed(images)
         embeddings = self._process_batches(
             data=images,
@@ -517,7 +525,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         if max_pixels:
             self.processor.image_processor.max_pixels = default_max_pixels
 
-        return embeddings
+        return embeddings if len(images) > 1 else embeddings[0]
 
     @classmethod
     def from_pretrained(
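
A companion sketch for the image path, reusing `model` from the sketch above; the file names are hypothetical. Per `_load_images_if_needed` in the diff, entries may be PIL images, URLs, or local paths.

```python
# Companion sketch for encode_image; file names are hypothetical.
from PIL import Image

pil_image = Image.open("cat.png")  # hypothetical local file

# A single PIL image (or a single path/URL string) returns one tensor.
single_embedding = model.encode_image(images=pil_image, task="retrieval")

# A list returns a list of embeddings, one per image.
batch_embeddings = model.encode_image(images=["cat.png", "dog.png"], task="retrieval")
```

Note the unwrapping check is `len(images) > 1` (likewise `len(texts) > 1`), so a one-element list also collapses to a single tensor rather than a one-element list; callers that always want a list back should branch on their input length themselves.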