nan committed on
Commit
1ffab4f
·
1 Parent(s): bb15721

feat: replace the output_format with a boolean flag

Browse files
Files changed (1) hide show
  1. modeling_jina_embeddings_v4.py +11 -26
modeling_jina_embeddings_v4.py CHANGED
@@ -30,11 +30,6 @@ class PromptType(str, Enum):
30
  passage = "passage"
31
 
32
 
33
- class VectorOutputFormat(str, Enum):
34
- SINGLE = "single"
35
- MULTIPLE = "multiple"
36
-
37
-
38
  PREFIX_DICT = {"query": "Query", "passage": "Passage"}
39
 
40
 
@@ -325,7 +320,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
325
  task_label: Union[str, List[str]],
326
  processor_fn: Callable,
327
  desc: str,
328
- output_format: Union[str, VectorOutputFormat] = VectorOutputFormat.SINGLE,
329
  return_numpy: bool = False,
330
  batch_size: int = 32,
331
  truncate_dim: Optional[int] = None,
@@ -345,8 +340,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
345
  device_type=torch.device(self.device).type, dtype=torch.bfloat16
346
  ):
347
  embeddings = self(**batch, task_label=task_label)
348
- output_format_str = output_format.value if isinstance(output_format, VectorOutputFormat) else output_format
349
- if output_format_str == VectorOutputFormat.SINGLE.value:
350
  embeddings = embeddings.single_vec_emb
351
  if truncate_dim is not None:
352
  embeddings = embeddings[:, :truncate_dim]
@@ -363,7 +357,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
363
 
364
  def _validate_encoding_params(
365
  self,
366
- output_format: Optional[Union[str, VectorOutputFormat]] = None,
367
  truncate_dim: Optional[int] = None,
368
  prompt_name: Optional[str] = None,
369
  ) -> Dict[str, Any]:
@@ -380,17 +374,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
380
  else PREFIX_DICT["query"]
381
  )
382
 
383
- output_format = output_format or VectorOutputFormat.SINGLE
384
- if isinstance(output_format, VectorOutputFormat):
385
- encode_kwargs["output_format"] = output_format.value
386
- else:
387
- try:
388
- output_format_enum = VectorOutputFormat(output_format)
389
- encode_kwargs["output_format"] = output_format_enum.value
390
- except ValueError:
391
- raise ValueError(
392
- f"Invalid output_format: {output_format}. Must be one of {[v.value for v in VectorOutputFormat]}."
393
- )
394
 
395
  truncate_dim = truncate_dim or self.config.truncate_dim
396
  if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
@@ -423,7 +408,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
423
  task: Optional[str] = None,
424
  max_length: int = 8192,
425
  batch_size: int = 8,
426
- output_format: Optional[Union[str, VectorOutputFormat]] = VectorOutputFormat.SINGLE,
427
  return_numpy: bool = False,
428
  truncate_dim: Optional[int] = None,
429
  prompt_name: Optional[str] = None,
@@ -435,7 +420,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
435
  texts: text or list of text strings to encode
436
  max_length: Maximum token length for text processing
437
  batch_size: Number of texts to process at once
438
- output_format: Type of embedding vector to generate (VectorOutputFormat.SINGLE or VectorOutputFormat.MULTIPLE)
439
  return_numpy: Whether to return numpy arrays instead of torch tensors
440
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
441
  prompt_name: Type of text being encoded ('query' or 'passage')
@@ -445,7 +430,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
445
  """
446
  prompt_name = prompt_name or "query"
447
  encode_kwargs = self._validate_encoding_params(
448
- output_format, truncate_dim, prompt_name
449
  )
450
 
451
  task = self._validate_task(task)
@@ -490,7 +475,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
490
  images: Union[str, Image.Image, List[Union[str, Image.Image]]],
491
  task: Optional[str] = None,
492
  batch_size: int = 8,
493
- output_format: Optional[Union[str, VectorOutputFormat]] = VectorOutputFormat.SINGLE,
494
  return_numpy: bool = False,
495
  truncate_dim: Optional[int] = None,
496
  max_pixels: Optional[int] = None,
@@ -501,7 +486,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
501
  Args:
502
  images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
503
  batch_size: Number of images to process at once
504
- output_format: Type of embedding vector to generate (VectorOutputFormat.SINGLE or VectorOutputFormat.MULTIPLE)
505
  return_numpy: Whether to return numpy arrays instead of torch tensors
506
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
507
  max_pixels: Maximum number of pixels to process per image
@@ -514,7 +499,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
514
  self.processor.image_processor.max_pixels = (
515
  max_pixels # change during encoding
516
  )
517
- encode_kwargs = self._validate_encoding_params(output_format, truncate_dim)
518
  task = self._validate_task(task)
519
 
520
  # Convert single image to list
 
30
  passage = "passage"
31
 
32
 
 
 
 
 
 
33
  PREFIX_DICT = {"query": "Query", "passage": "Passage"}
34
 
35
 
 
320
  task_label: Union[str, List[str]],
321
  processor_fn: Callable,
322
  desc: str,
323
+ return_multivector: bool = False,
324
  return_numpy: bool = False,
325
  batch_size: int = 32,
326
  truncate_dim: Optional[int] = None,
 
340
  device_type=torch.device(self.device).type, dtype=torch.bfloat16
341
  ):
342
  embeddings = self(**batch, task_label=task_label)
343
+ if not return_multivector:
 
344
  embeddings = embeddings.single_vec_emb
345
  if truncate_dim is not None:
346
  embeddings = embeddings[:, :truncate_dim]
 
357
 
358
  def _validate_encoding_params(
359
  self,
360
+ return_multivector: Optional[bool] = None,
361
  truncate_dim: Optional[int] = None,
362
  prompt_name: Optional[str] = None,
363
  ) -> Dict[str, Any]:
 
374
  else PREFIX_DICT["query"]
375
  )
376
 
377
+ return_multivector = return_multivector or False
378
+ encode_kwargs["return_multivector"] = return_multivector
 
 
 
 
 
 
 
 
 
379
 
380
  truncate_dim = truncate_dim or self.config.truncate_dim
381
  if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
 
408
  task: Optional[str] = None,
409
  max_length: int = 8192,
410
  batch_size: int = 8,
411
+ return_multivector: bool = False,
412
  return_numpy: bool = False,
413
  truncate_dim: Optional[int] = None,
414
  prompt_name: Optional[str] = None,
 
420
  texts: text or list of text strings to encode
421
  max_length: Maximum token length for text processing
422
  batch_size: Number of texts to process at once
423
+ return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
424
  return_numpy: Whether to return numpy arrays instead of torch tensors
425
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
426
  prompt_name: Type of text being encoded ('query' or 'passage')
 
430
  """
431
  prompt_name = prompt_name or "query"
432
  encode_kwargs = self._validate_encoding_params(
433
+ return_multivector=return_multivector, truncate_dim=truncate_dim, prompt_name=prompt_name
434
  )
435
 
436
  task = self._validate_task(task)
 
475
  images: Union[str, Image.Image, List[Union[str, Image.Image]]],
476
  task: Optional[str] = None,
477
  batch_size: int = 8,
478
+ return_multivector: bool = False,
479
  return_numpy: bool = False,
480
  truncate_dim: Optional[int] = None,
481
  max_pixels: Optional[int] = None,
 
486
  Args:
487
  images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
488
  batch_size: Number of images to process at once
489
+ return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
490
  return_numpy: Whether to return numpy arrays instead of torch tensors
491
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
492
  max_pixels: Maximum number of pixels to process per image
 
499
  self.processor.image_processor.max_pixels = (
500
  max_pixels # change during encoding
501
  )
502
+ encode_kwargs = self._validate_encoding_params(return_multivector=return_multivector, truncate_dim=truncate_dim)
503
  task = self._validate_task(task)
504
 
505
  # Convert single image to list