nan committed on
Commit
e723064
·
1 Parent(s): 7f10796

feat: use enum for the vector type

Browse files
Files changed (1) hide show
  1. modeling_jina_embeddings_v4.py +23 -14
modeling_jina_embeddings_v4.py CHANGED
@@ -30,8 +30,12 @@ class PromptType(str, Enum):
30
  passage = "passage"
31
 
32
 
 
 
 
 
 
33
  PREFIX_DICT = {"query": "Query", "passage": "Passage"}
34
- VECTOR_TYPES = ["single", "multi"]
35
 
36
 
37
  class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
@@ -320,7 +324,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
320
  task_label: Union[str, List[str]],
321
  processor_fn: Callable,
322
  desc: str,
323
- vector_type: str = "single",
324
  return_numpy: bool = False,
325
  batch_size: int = 32,
326
  truncate_dim: Optional[int] = None,
@@ -340,7 +344,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
340
  device_type=torch.device(self.device).type, dtype=torch.bfloat16
341
  ):
342
  embeddings = self(**batch, task_label=task_label)
343
- if vector_type == "single":
 
344
  embeddings = embeddings.single_vec_emb
345
  if truncate_dim is not None:
346
  embeddings = embeddings[:, :truncate_dim]
@@ -357,7 +362,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
357
 
358
  def _validate_encoding_params(
359
  self,
360
- vector_type: Optional[str] = None,
361
  truncate_dim: Optional[int] = None,
362
  prompt_name: Optional[str] = None,
363
  ) -> Dict[str, Any]:
@@ -374,13 +379,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
374
  else PREFIX_DICT["query"]
375
  )
376
 
377
- vector_type = vector_type or "single"
378
- if vector_type not in VECTOR_TYPES:
379
- raise ValueError(
380
- f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
381
- )
382
  else:
383
- encode_kwargs["vector_type"] = vector_type
 
 
 
 
 
 
384
 
385
  truncate_dim = truncate_dim or self.config.truncate_dim
386
  if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
@@ -413,7 +422,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
413
  task: Optional[str] = None,
414
  max_length: int = 8192,
415
  batch_size: int = 8,
416
- vector_type: Optional[str] = None,
417
  return_numpy: bool = False,
418
  truncate_dim: Optional[int] = None,
419
  prompt_name: Optional[str] = None,
@@ -425,7 +434,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
425
  texts: List of text strings to encode
426
  max_length: Maximum token length for text processing
427
  batch_size: Number of texts to process at once
428
- vector_type: Type of embedding vector to generate ('single' or 'multi')
429
  return_numpy: Whether to return numpy arrays instead of torch tensors
430
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
431
  prompt_name: Type of text being encoded ('query' or 'passage')
@@ -477,7 +486,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
477
  images: List[Union[str, Image.Image]],
478
  task: Optional[str] = None,
479
  batch_size: int = 8,
480
- vector_type: Optional[str] = None,
481
  return_numpy: bool = False,
482
  truncate_dim: Optional[int] = None,
483
  max_pixels: Optional[int] = None,
@@ -488,7 +497,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
488
  Args:
489
  images: List of PIL images, URLs, or local file paths to encode
490
  batch_size: Number of images to process at once
491
- vector_type: Type of embedding vector to generate ('single' or 'multi')
492
  return_numpy: Whether to return numpy arrays instead of torch tensors
493
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
494
  max_pixels: Maximum number of pixels to process per image
 
30
  passage = "passage"
31
 
32
 
33
+ class VectorType(str, Enum):
34
+ single = "single"
35
+ multi = "multi"
36
+
37
+
38
  PREFIX_DICT = {"query": "Query", "passage": "Passage"}
 
39
 
40
 
41
  class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
 
324
  task_label: Union[str, List[str]],
325
  processor_fn: Callable,
326
  desc: str,
327
+ vector_type: Union[str, VectorType] = VectorType.single,
328
  return_numpy: bool = False,
329
  batch_size: int = 32,
330
  truncate_dim: Optional[int] = None,
 
344
  device_type=torch.device(self.device).type, dtype=torch.bfloat16
345
  ):
346
  embeddings = self(**batch, task_label=task_label)
347
+ vector_type_str = vector_type.value if isinstance(vector_type, VectorType) else vector_type
348
+ if vector_type_str == VectorType.single.value:
349
  embeddings = embeddings.single_vec_emb
350
  if truncate_dim is not None:
351
  embeddings = embeddings[:, :truncate_dim]
 
362
 
363
  def _validate_encoding_params(
364
  self,
365
+ vector_type: Optional[Union[str, VectorType]] = None,
366
  truncate_dim: Optional[int] = None,
367
  prompt_name: Optional[str] = None,
368
  ) -> Dict[str, Any]:
 
379
  else PREFIX_DICT["query"]
380
  )
381
 
382
+ vector_type = vector_type or VectorType.single
383
+ if isinstance(vector_type, VectorType):
384
+ encode_kwargs["vector_type"] = vector_type.value
 
 
385
  else:
386
+ try:
387
+ vector_type_enum = VectorType(vector_type)
388
+ encode_kwargs["vector_type"] = vector_type_enum.value
389
+ except ValueError:
390
+ raise ValueError(
391
+ f"Invalid vector_type: {vector_type}. Must be one of {[v.value for v in VectorType]}."
392
+ )
393
 
394
  truncate_dim = truncate_dim or self.config.truncate_dim
395
  if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
 
422
  task: Optional[str] = None,
423
  max_length: int = 8192,
424
  batch_size: int = 8,
425
+ vector_type: Optional[Union[str, VectorType]] = None,
426
  return_numpy: bool = False,
427
  truncate_dim: Optional[int] = None,
428
  prompt_name: Optional[str] = None,
 
434
  texts: List of text strings to encode
435
  max_length: Maximum token length for text processing
436
  batch_size: Number of texts to process at once
437
+ vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
438
  return_numpy: Whether to return numpy arrays instead of torch tensors
439
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
440
  prompt_name: Type of text being encoded ('query' or 'passage')
 
486
  images: List[Union[str, Image.Image]],
487
  task: Optional[str] = None,
488
  batch_size: int = 8,
489
+ vector_type: Optional[Union[str, VectorType]] = None,
490
  return_numpy: bool = False,
491
  truncate_dim: Optional[int] = None,
492
  max_pixels: Optional[int] = None,
 
497
  Args:
498
  images: List of PIL images, URLs, or local file paths to encode
499
  batch_size: Number of images to process at once
500
+ vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
501
  return_numpy: Whether to return numpy arrays instead of torch tensors
502
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
503
  max_pixels: Maximum number of pixels to process per image