fix: resolve the bug when `return_numpy` is False
Browse files
modeling_jina_embeddings_v4.py
CHANGED
@@ -334,6 +334,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
334 |
shuffle=False,
|
335 |
collate_fn=processor_fn,
|
336 |
)
|
|
|
|
|
337 |
results = []
|
338 |
self.eval()
|
339 |
for batch in tqdm(dataloader, desc=desc):
|
@@ -344,23 +346,18 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
344 |
):
|
345 |
embeddings = self(**batch, task_label=task_label)
|
346 |
attention_mask = embeddings.attention_mask
|
347 |
-
if not
|
348 |
embeddings = embeddings.single_vec_emb
|
349 |
if truncate_dim is not None:
|
350 |
embeddings = embeddings[:, :truncate_dim]
|
351 |
else:
|
352 |
embeddings = embeddings.multi_vec_emb
|
353 |
-
if return_multivector
|
354 |
-
# Get valid token mask from attention_mask
|
355 |
valid_tokens = attention_mask.bool()
|
356 |
-
# Remove padding by selecting only valid tokens for each sequence
|
357 |
embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
|
358 |
-
# Stack back into tensor with variable sequence lengths
|
359 |
results.append(embeddings)
|
360 |
else:
|
361 |
results.append(
|
362 |
-
# If return_numpy is True, move embeddings to CPU for numpy conversion
|
363 |
-
# Otherwise, unbind the tensor into a list of individual tensors along dim=0
|
364 |
embeddings.cpu()
|
365 |
if return_numpy
|
366 |
else list(torch.unbind(embeddings))
|
@@ -450,6 +447,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
450 |
)
|
451 |
|
452 |
return_list = isinstance(texts, list)
|
|
|
|
|
|
|
|
|
|
|
|
|
453 |
|
454 |
if isinstance(texts, str):
|
455 |
texts = [texts]
|
@@ -498,7 +501,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
498 |
images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
|
499 |
batch_size: Number of images to process at once
|
500 |
return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
|
501 |
-
return_numpy: Whether to return numpy arrays instead of torch tensors
|
502 |
truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
|
503 |
max_pixels: Maximum number of pixels to process per image
|
504 |
|
@@ -515,6 +518,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
515 |
|
516 |
return_list = isinstance(images, list)
|
517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
518 |
# Convert single image to list
|
519 |
if isinstance(images, (str, Image.Image)):
|
520 |
images = [images]
|
|
|
334 |
shuffle=False,
|
335 |
collate_fn=processor_fn,
|
336 |
)
|
337 |
+
if return_multivector and len(data) > 1:
|
338 |
+
assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
|
339 |
results = []
|
340 |
self.eval()
|
341 |
for batch in tqdm(dataloader, desc=desc):
|
|
|
346 |
):
|
347 |
embeddings = self(**batch, task_label=task_label)
|
348 |
attention_mask = embeddings.attention_mask
|
349 |
+
if return_multivector and not return_numpy:
|
350 |
embeddings = embeddings.single_vec_emb
|
351 |
if truncate_dim is not None:
|
352 |
embeddings = embeddings[:, :truncate_dim]
|
353 |
else:
|
354 |
embeddings = embeddings.multi_vec_emb
|
355 |
+
if return_multivector:
|
|
|
356 |
valid_tokens = attention_mask.bool()
|
|
|
357 |
embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
|
|
|
358 |
results.append(embeddings)
|
359 |
else:
|
360 |
results.append(
|
|
|
|
|
361 |
embeddings.cpu()
|
362 |
if return_numpy
|
363 |
else list(torch.unbind(embeddings))
|
|
|
447 |
)
|
448 |
|
449 |
return_list = isinstance(texts, list)
|
450 |
+
|
451 |
+
# If return_multivector is True and encoding multiple texts, ignore return_numpy
|
452 |
+
if return_multivector and return_list and len(texts) > 1:
|
453 |
+
if return_numpy:
|
454 |
+
print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
|
455 |
+
return_numpy = False
|
456 |
|
457 |
if isinstance(texts, str):
|
458 |
texts = [texts]
|
|
|
501 |
images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
|
502 |
batch_size: Number of images to process at once
|
503 |
return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
|
504 |
+
return_numpy: Whether to return numpy arrays instead of torch tensors. If `return_multivector` is `True` and more than one image is encoded, this parameter is ignored.
|
505 |
truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
|
506 |
max_pixels: Maximum number of pixels to process per image
|
507 |
|
|
|
518 |
|
519 |
return_list = isinstance(images, list)
|
520 |
|
521 |
+
# If return_multivector is True and encoding multiple images, ignore return_numpy
|
522 |
+
if return_multivector and return_list and len(images) > 1:
|
523 |
+
if return_numpy:
|
524 |
+
print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`")
|
525 |
+
return_numpy = False
|
526 |
+
|
527 |
# Convert single image to list
|
528 |
if isinstance(images, (str, Image.Image)):
|
529 |
images = [images]
|