feat: replace the output_format with a boolean flag
modeling_jina_embeddings_v4.py (+11 −26)
@@ -30,11 +30,6 @@ class PromptType(str, Enum):
     passage = "passage"
 
 
-class VectorOutputFormat(str, Enum):
-    SINGLE = "single"
-    MULTIPLE = "multiple"
-
-
 PREFIX_DICT = {"query": "Query", "passage": "Passage"}
 
 
@@ -325,7 +320,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-
+        return_multivector: bool = False,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -345,8 +340,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
             ):
                 embeddings = self(**batch, task_label=task_label)
-
-                if output_format_str == VectorOutputFormat.SINGLE.value:
+                if not return_multivector:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
@@ -363,7 +357,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-
+        return_multivector: Optional[bool] = None,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -380,17 +374,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-
-        if isinstance(output_format, VectorOutputFormat):
-            encode_kwargs["output_format"] = output_format.value
-        else:
-            try:
-                output_format_enum = VectorOutputFormat(output_format)
-                encode_kwargs["output_format"] = output_format_enum.value
-            except ValueError:
-                raise ValueError(
-                    f"Invalid output_format: {output_format}. Must be one of {[v.value for v in VectorOutputFormat]}."
-                )
+        return_multivector = return_multivector or False
+        encode_kwargs["return_multivector"] = return_multivector
 
         truncate_dim = truncate_dim or self.config.truncate_dim
         if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
@@ -423,7 +408,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-
+        return_multivector: bool = False,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
@@ -435,7 +420,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             texts: text or list of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
-
+            return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
            prompt_name: Type of text being encoded ('query' or 'passage')
@@ -445,7 +430,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         """
         prompt_name = prompt_name or "query"
         encode_kwargs = self._validate_encoding_params(
-
+            return_multivector=return_multivector, truncate_dim=truncate_dim, prompt_name=prompt_name
         )
 
         task = self._validate_task(task)
@@ -490,7 +475,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
-
+        return_multivector: bool = False,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
@@ -501,7 +486,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Args:
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
-
+            return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
@@ -514,7 +499,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             self.processor.image_processor.max_pixels = (
                 max_pixels  # change during encoding
             )
-        encode_kwargs = self._validate_encoding_params(
+        encode_kwargs = self._validate_encoding_params(return_multivector=return_multivector, truncate_dim=truncate_dim)
        task = self._validate_task(task)
 
         # Convert single image to list
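For context, a minimal usage sketch of the flag that replaces the removed VectorOutputFormat enum. The public method name (encode_text) and the "retrieval"/"query" task and prompt values are taken from the jina-embeddings-v4 model card, not from the hunks above, so treat them as assumptions:

# Usage sketch; method name and task/prompt values assumed from the model card.
from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

# Default path: one pooled vector per text; truncate_dim applies (Matryoshka dims).
single_vecs = model.encode_text(
    texts=["greenhouse gas emissions by sector"],
    task="retrieval",
    prompt_name="query",
    truncate_dim=512,
)

# New flag: multi-vector output (per-token embeddings) instead of one pooled vector.
multi_vecs = model.encode_text(
    texts=["greenhouse gas emissions by sector"],
    task="retrieval",
    prompt_name="query",
    return_multivector=True,
)

Note that _validate_encoding_params forwards truncate_dim in both cases, but _process_batches only truncates on the single-vector branch (if not return_multivector:), so truncate_dim appears to have no effect on multi-vector output.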