Commit 8e178a1 · Parent: 455d3b0

feat: make encode_texts and encode_images support single inputs

Files changed (1):
  modeling_jina_embeddings_v4.py (+12 -4)
modeling_jina_embeddings_v4.py CHANGED

@@ -409,7 +409,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def encode_texts(
         self,
-        texts: List[str],
+        texts: Union[str, List[str]],
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
@@ -422,7 +422,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Encodes a list of texts into embeddings.
 
         Args:
-            texts: List of text strings to encode
+            texts: text or list of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
             vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
@@ -446,6 +446,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             prefix=encode_kwargs.pop("prefix"),
         )
 
+        if isinstance(texts, str):
+            texts = [texts]
+
         embeddings = self._process_batches(
             data=texts,
             processor_fn=processor_fn,
@@ -474,7 +477,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def encode_images(
         self,
-        images: List[Union[str, Image.Image]],
+        images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
         vector_type: Optional[str] = None,
@@ -486,7 +489,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Encodes a list of images into embeddings.
 
         Args:
-            images: List of PIL images, URLs, or local file paths to encode
+            images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
             vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
             return_numpy: Whether to return numpy arrays instead of torch tensors
@@ -503,6 +506,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
         task = self._validate_task(task)
+
+        # Convert single image to list
+        if isinstance(images, (str, Image.Image)):
+            images = [images]
+
         images = self._load_images_if_needed(images)
         embeddings = self._process_batches(
             data=images,
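
For context, a minimal usage sketch of the changed call sites, assuming the model is loaded through transformers with trust_remote_code=True. The repo id jinaai/jina-embeddings-v4, the task value "retrieval", and the example URL are illustrative assumptions, not taken from this diff:

from transformers import AutoModel

# Assumed model id and loading path; adjust to your checkpoint.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
)

# Lists keep working exactly as before this commit.
batch = model.encode_texts(
    texts=["first document", "second document"],
    task="retrieval",  # assumed task name, for illustration only
)

# New: a bare string is wrapped into a one-element list internally.
single = model.encode_texts(texts="just one document", task="retrieval")

# New: encode_images likewise accepts a single PIL image, URL, or file path.
img = model.encode_images(
    images="https://example.com/cat.jpg",  # hypothetical URL
    task="retrieval",
)

Since the wrapping happens before _process_batches in both methods, a single input should still come back batch-shaped, i.e. as a length-one collection of embeddings rather than a bare vector.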