nan committed on
Commit 4c2a7cb · verified · 1 Parent(s): 1036c04

feat-rename-vector-type-0622 (#21)


- feat: avoid the redundant words in the variables (7f10796af034d90842575c6a877c3ef8b8d0b212)
- feat: use enum for the vector type (e7230645cd96df2626c429031ef6d9761c595ab5)
- Merge branch 'main' into pr/21 (085e2ed8f55f14e4ea5a67596d41bf50026ee9f3)
- refactor: rename vector_type to output_format (96925c43b3978bb6de3d3ab0ebfb27701d625f1a)
- feat: rename the VectorType (669c42abab2468a13298a192ff96826e6d8394f1)
- feat: fix the default values (bb1572174c755b90eb888cb78c496db2c3a8ecf4)
- feat: replace the output_format with a boolean flag (1ffab4f0c4c3d022d3c4e3555fd7bcc362262c1f)
- feat: avoid validating return_multivector (fe4c51b73e21a2ac2f1ff293337a3cac82517e88)
- feat: return a list when the input is a list (f7df96abf5c4741c0e88f6b30b347bb7191f7596)
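Taken together, these commits replace the string-valued `vector_type` parameter with a single `return_multivector` boolean and make the encode methods return a list exactly when the input is a list. A minimal usage sketch of the resulting API follows; the method name `encode_text`, the repo id, and the `task` value are assumptions for illustration and do not appear in this diff:

# Minimal usage sketch of the API after this PR; not part of the diff.
# Assumptions: the method is exposed as `encode_text` on the Hugging Face
# repo `jinaai/jina-embeddings-v4`, loaded with trust_remote_code=True.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4", trust_remote_code=True
)

# Before: vector_type="single_vector" | "multi_vector" (a string, validated
# against VECTOR_TYPES and easy to mistype).
# After: one boolean flag; False (the default) yields single-vector
# embeddings, True yields per-token multi-vector embeddings.
emb = model.encode_text(
    texts=["greenland shark lifespan"],
    task="retrieval",        # assumed task name, not shown in this diff
    prompt_name="query",
    return_multivector=True,
)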

Files changed (1)
  1. modeling_jina_embeddings_v4.py +19 -24
modeling_jina_embeddings_v4.py CHANGED

@@ -31,7 +31,6 @@ class PromptType(str, Enum):
 
 
 PREFIX_DICT = {"query": "Query", "passage": "Passage"}
-VECTOR_TYPES = ["single_vector", "multi_vector"]
 
 
 class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
@@ -284,8 +283,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             attention_mask (torch.Tensor): The attention mask tensor.
         Returns:
             JinaEmbeddingsV4ModelOutput:
-                single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
-                multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
+                vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
+                single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
+                multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
         """
         # Forward pass through the VLM
         hidden_states = self.get_last_hidden_states(
@@ -320,7 +320,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-        vector_type: str = "single_vector",
+        return_multivector: bool = False,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -340,7 +340,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             device_type=torch.device(self.device).type, dtype=torch.bfloat16
         ):
             embeddings = self(**batch, task_label=task_label)
-            if vector_type == "single_vector":
+            if not return_multivector:
                 embeddings = embeddings.single_vec_emb
                 if truncate_dim is not None:
                     embeddings = embeddings[:, :truncate_dim]
@@ -357,7 +357,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-        vector_type: Optional[str] = None,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -374,14 +373,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-        vector_type = vector_type or "single_vector"
-        if vector_type not in VECTOR_TYPES:
-            raise ValueError(
-                f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
-            )
-        else:
-            encode_kwargs["vector_type"] = vector_type
-
         truncate_dim = truncate_dim or self.config.truncate_dim
         if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
             raise ValueError(
@@ -413,7 +404,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-        vector_type: Optional[str] = None,
+        return_multivector: bool = False,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
@@ -425,7 +416,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             texts: text or list of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
-            vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
+            return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             prompt_name: Type of text being encoded ('query' or 'passage')
@@ -434,9 +425,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
         """
         prompt_name = prompt_name or "query"
-        encode_kwargs = self._validate_encoding_params(
-            vector_type, truncate_dim, prompt_name
-        )
+        encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim, prompt_name=prompt_name)
 
         task = self._validate_task(task)
 
@@ -446,6 +435,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             prefix=encode_kwargs.pop("prefix"),
         )
 
+        return_list = isinstance(texts, list)
+
         if isinstance(texts, str):
             texts = [texts]
 
@@ -454,12 +445,13 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             processor_fn=processor_fn,
             desc="Encoding texts...",
             task_label=task,
+            return_multivector=return_multivector,
             return_numpy=return_numpy,
             batch_size=batch_size,
             **encode_kwargs,
         )
 
-        return embeddings if len(texts) > 1 else embeddings[0]
+        return embeddings if return_list else embeddings[0]
 
     def _load_images_if_needed(
         self, images: List[Union[str, Image.Image]]
@@ -480,7 +472,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
-        vector_type: Optional[str] = None,
+        return_multivector: bool = False,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
@@ -491,7 +483,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Args:
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
-            vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
+            return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
@@ -504,9 +496,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             self.processor.image_processor.max_pixels = (
                 max_pixels  # change during encoding
             )
-        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
+        encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
         task = self._validate_task(task)
 
+        return_list = isinstance(images, list)
+
         # Convert single image to list
         if isinstance(images, (str, Image.Image)):
             images = [images]
@@ -518,6 +512,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             desc="Encoding images...",
             task_label=task,
             batch_size=batch_size,
+            return_multivector=return_multivector,
             return_numpy=return_numpy,
             **encode_kwargs,
         )
@@ -525,7 +520,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         if max_pixels:
             self.processor.image_processor.max_pixels = default_max_pixels
 
-        return embeddings if len(images) > 1 else embeddings[0]
+        return embeddings if return_list else embeddings[0]
 
     @classmethod
     def from_pretrained(
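The `return_list` change also tightens the return-type contract: whether the caller gets a list back now depends on whether a list was passed in, not on its length. A short behavior sketch, reusing the assumed `model` and `encode_text` names from the sketch above:

one = model.encode_text(texts="a single query")    # bare str  -> one embedding
few = model.encode_text(texts=["a single query"])  # list of 1 -> list of 1
# Before this PR the dispatch was `embeddings if len(texts) > 1 else embeddings[0]`,
# so a one-element list was silently unwrapped to a bare embedding; now list
# inputs and list outputs stay in one-to-one correspondence.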