nan committed
Commit 085e2ed · 2 parents: e723064, 1036c04

Merge branch 'main' into pr/21

Files changed (1)
  1. modeling_jina_embeddings_v4.py +21 -13
modeling_jina_embeddings_v4.py CHANGED
@@ -416,9 +416,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         return task
 
-    def encode_texts(
+    def encode_text(
         self,
-        texts: List[str],
+        texts: Union[str, List[str]],
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
@@ -426,12 +426,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
         """
         Encodes a list of texts into embeddings.
 
         Args:
-            texts: List of text strings to encode
+            texts: text or list of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
             vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
@@ -440,7 +440,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             prompt_name: Type of text being encoded ('query' or 'passage')
 
         Returns:
-            List of text embeddings as tensors or numpy arrays
+            List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
         """
         prompt_name = prompt_name or "query"
         encode_kwargs = self._validate_encoding_params(
@@ -455,6 +455,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             prefix=encode_kwargs.pop("prefix"),
         )
 
+        if isinstance(texts, str):
+            texts = [texts]
+
         embeddings = self._process_batches(
             data=texts,
             processor_fn=processor_fn,
@@ -465,7 +468,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             **encode_kwargs,
         )
 
-        return embeddings
+        return embeddings if len(texts) > 1 else embeddings[0]
 
     def _load_images_if_needed(
         self, images: List[Union[str, Image.Image]]
@@ -481,21 +484,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             loaded_images.append(image)
         return loaded_images
 
-    def encode_images(
+    def encode_image(
         self,
-        images: List[Union[str, Image.Image]],
+        images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
         vector_type: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
         """
-        Encodes a list of images into embeddings.
+        Encodes a list of images or a single image into embedding(s).
 
         Args:
-            images: List of PIL images, URLs, or local file paths to encode
+            images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
             vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
             return_numpy: Whether to return numpy arrays instead of torch tensors
@@ -503,7 +506,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             max_pixels: Maximum number of pixels to process per image
 
         Returns:
-            List of image embeddings as tensors or numpy arrays
+            List of image embeddings as tensors or numpy arrays when encoding multiple images, or single image embedding as tensor when encoding a single image
         """
         if max_pixels:
             default_max_pixels = self.processor.image_processor.max_pixels
@@ -512,6 +515,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
         task = self._validate_task(task)
+
+        # Convert single image to list
+        if isinstance(images, (str, Image.Image)):
+            images = [images]
+
         images = self._load_images_if_needed(images)
         embeddings = self._process_batches(
             data=images,
@@ -526,7 +534,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         if max_pixels:
             self.processor.image_processor.max_pixels = default_max_pixels
 
-        return embeddings
+        return embeddings if len(images) > 1 else embeddings[0]
 
     @classmethod
     def from_pretrained(
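
After this change, encode_text (renamed from encode_texts) accepts either a single string or a list of strings and unwraps the result accordingly. A minimal usage sketch, assuming the checkpoint id and task name below (both are illustrative, not taken from this commit):

from transformers import AutoModel

# trust_remote_code is required because this class lives in the repo, not in transformers
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

# list input: returns a list with one embedding tensor per text
embs = model.encode_text(texts=["a blue cat", "a red cat"], task="retrieval")

# single string input: now returns a bare tensor instead of a one-element list
emb = model.encode_text(texts="a blue cat", task="retrieval")

Note that the unwrapping keys off len(texts) rather than the original input type, so a one-element list also comes back as a bare tensor rather than a one-element list.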
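
encode_image (renamed from encode_images) gets the same treatment: a single PIL image, URL, or local path yields one tensor, a list yields a list. A sketch under the same assumptions (the file name and URL are placeholders):

from PIL import Image

img = Image.open("photo.jpg")  # placeholder local file

# single image in, single embedding tensor out
emb = model.encode_image(images=img, task="retrieval")

# a list of PIL images, URLs, or paths still returns a list of tensors
embs = model.encode_image(
    images=["https://example.com/a.png", img],
    task="retrieval",
)

Returning a bare embedding for a single input mirrors the common convention of embedding libraries such as sentence-transformers, at the cost of a return type that depends on the input.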