feat: use enum for the vector type
modeling_jina_embeddings_v4.py (+23, -14)
@@ -30,8 +30,12 @@ class PromptType(str, Enum):
     passage = "passage"
 
 
+class VectorType(str, Enum):
+    single = "single"
+    multi = "multi"
+
+
 PREFIX_DICT = {"query": "Query", "passage": "Passage"}
-VECTOR_TYPES = ["single", "multi"]
 
 
 class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
@@ -320,7 +324,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-        vector_type: str = "single",
+        vector_type: Union[str, VectorType] = VectorType.single,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -340,7 +344,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             device_type=torch.device(self.device).type, dtype=torch.bfloat16
         ):
             embeddings = self(**batch, task_label=task_label)
-            if vector_type == "single":
+            vector_type_str = vector_type.value if isinstance(vector_type, VectorType) else vector_type
+            if vector_type_str == VectorType.single.value:
                 embeddings = embeddings.single_vec_emb
                 if truncate_dim is not None:
                     embeddings = embeddings[:, :truncate_dim]
@@ -357,7 +362,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-        vector_type: Optional[str] = None,
+        vector_type: Optional[Union[str, VectorType]] = None,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -374,13 +379,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-        vector_type = vector_type or "single"
-        if vector_type not in VECTOR_TYPES:
-            raise ValueError(
-                f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
-            )
+        vector_type = vector_type or VectorType.single
+        if isinstance(vector_type, VectorType):
+            encode_kwargs["vector_type"] = vector_type.value
         else:
-            encode_kwargs["vector_type"] = vector_type
+            try:
+                vector_type_enum = VectorType(vector_type)
+                encode_kwargs["vector_type"] = vector_type_enum.value
+            except ValueError:
+                raise ValueError(
+                    f"Invalid vector_type: {vector_type}. Must be one of {[v.value for v in VectorType]}."
+                )
 
         truncate_dim = truncate_dim or self.config.truncate_dim
         if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
@@ -413,7 +422,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-        vector_type: Optional[str] = None,
+        vector_type: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
@@ -425,7 +434,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             texts: List of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
-            vector_type: Type of embedding vector to generate ('single' or 'multi')
+            vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             prompt_name: Type of text being encoded ('query' or 'passage')
@@ -477,7 +486,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         images: List[Union[str, Image.Image]],
         task: Optional[str] = None,
         batch_size: int = 8,
-        vector_type: Optional[str] = None,
+        vector_type: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
@@ -488,7 +497,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Args:
             images: List of PIL images, URLs, or local file paths to encode
             batch_size: Number of images to process at once
-            vector_type: Type of embedding vector to generate ('single' or 'multi')
+            vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
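Because VectorType subclasses both str and Enum, callers can keep passing plain strings while the model normalizes them through the enum constructor, exactly as _validate_encoding_params does above. A minimal standalone sketch of that normalization pattern, using a hypothetical helper name normalize_vector_type that is not part of the model code:

from enum import Enum
from typing import Optional, Union


class VectorType(str, Enum):
    single = "single"
    multi = "multi"


def normalize_vector_type(vector_type: Optional[Union[str, VectorType]]) -> str:
    # Mirror the validation in the diff: default to single-vector embeddings,
    # accept either the enum member or its string value, reject anything else.
    vector_type = vector_type or VectorType.single
    if isinstance(vector_type, VectorType):
        return vector_type.value
    try:
        return VectorType(vector_type).value
    except ValueError:
        raise ValueError(
            f"Invalid vector_type: {vector_type}. Must be one of {[v.value for v in VectorType]}."
        )


print(normalize_vector_type(VectorType.multi))  # "multi"
print(normalize_vector_type("single"))          # "single"
print(normalize_vector_type(None))              # defaults to "single"
normalize_vector_type("dense")                  # raises ValueError with the message above

Keeping the enum a str subclass also preserves backward compatibility: VectorType.single == "single" evaluates to True, so existing callers that pass the bare strings "single" or "multi" continue to work unchanged.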