nan committed on
Commit
6bb8cf2
·
1 Parent(s): 205b18f

fix: fix the bug when return_numpy is false

Browse files
Files changed (1) hide show
  1. modeling_jina_embeddings_v4.py +17 -8
modeling_jina_embeddings_v4.py CHANGED
@@ -334,6 +334,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
334
  shuffle=False,
335
  collate_fn=processor_fn,
336
  )
 
 
337
  results = []
338
  self.eval()
339
  for batch in tqdm(dataloader, desc=desc):
@@ -344,23 +346,18 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
344
  ):
345
  embeddings = self(**batch, task_label=task_label)
346
  attention_mask = embeddings.attention_mask
347
- if not return_multivector:
348
  embeddings = embeddings.single_vec_emb
349
  if truncate_dim is not None:
350
  embeddings = embeddings[:, :truncate_dim]
351
  else:
352
  embeddings = embeddings.multi_vec_emb
353
- if return_multivector and not return_numpy:
354
- # Get valid token mask from attention_mask
355
  valid_tokens = attention_mask.bool()
356
- # Remove padding by selecting only valid tokens for each sequence
357
  embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
358
- # Stack back into tensor with variable sequence lengths
359
  results.append(embeddings)
360
  else:
361
  results.append(
362
- # If return_numpy is True, move embeddings to CPU for numpy conversion
363
- # Otherwise, unbind the tensor into a list of individual tensors along dim=0
364
  embeddings.cpu()
365
  if return_numpy
366
  else list(torch.unbind(embeddings))
@@ -450,6 +447,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
450
  )
451
 
452
  return_list = isinstance(texts, list)
 
 
 
 
 
 
453
 
454
  if isinstance(texts, str):
455
  texts = [texts]
@@ -498,7 +501,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
498
  images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
499
  batch_size: Number of images to process at once
500
  return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
501
- return_numpy: Whether to return numpy arrays instead of torch tensors
502
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
503
  max_pixels: Maximum number of pixels to process per image
504
 
@@ -515,6 +518,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
515
 
516
  return_list = isinstance(images, list)
517
 
 
 
 
 
 
 
518
  # Convert single image to list
519
  if isinstance(images, (str, Image.Image)):
520
  images = [images]
 
334
  shuffle=False,
335
  collate_fn=processor_fn,
336
  )
337
+ if return_multivector and len(data) > 1:
338
+ assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
339
  results = []
340
  self.eval()
341
  for batch in tqdm(dataloader, desc=desc):
 
346
  ):
347
  embeddings = self(**batch, task_label=task_label)
348
  attention_mask = embeddings.attention_mask
349
+ if return_multivector and not return_numpy:
350
  embeddings = embeddings.single_vec_emb
351
  if truncate_dim is not None:
352
  embeddings = embeddings[:, :truncate_dim]
353
  else:
354
  embeddings = embeddings.multi_vec_emb
355
+ if return_multivector:
 
356
  valid_tokens = attention_mask.bool()
 
357
  embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
 
358
  results.append(embeddings)
359
  else:
360
  results.append(
 
 
361
  embeddings.cpu()
362
  if return_numpy
363
  else list(torch.unbind(embeddings))
 
447
  )
448
 
449
  return_list = isinstance(texts, list)
450
+
451
+ # If return_multivector is True and encoding multiple texts, ignore return_numpy
452
+ if return_multivector and return_list and len(texts) > 1:
453
+ if return_numpy:
454
+ print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
455
+ return_numpy = False
456
 
457
  if isinstance(texts, str):
458
  texts = [texts]
 
501
  images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
502
  batch_size: Number of images to process at once
503
  return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
504
+ return_numpy: Whether to return numpy arrays instead of torch tensors. If `return_multivector` is `True` and more than one image is encoded, this parameter is ignored.
505
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
506
  max_pixels: Maximum number of pixels to process per image
507
 
 
518
 
519
  return_list = isinstance(images, list)
520
 
521
+ # If return_multivector is True and encoding multiple images, ignore return_numpy
522
+ if return_multivector and return_list and len(images) > 1:
523
+ if return_numpy:
524
+ print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`")
525
+ return_numpy = False
526
+
527
  # Convert single image to list
528
  if isinstance(images, (str, Image.Image)):
529
  images = [images]