nan committed on
Commit
ef1876f
·
1 Parent(s): 4c2a7cb

fix: remove the padding tokens when a list of multivectors is returned

Browse files
Files changed (1) hide show
  1. modeling_jina_embeddings_v4.py +11 -0
modeling_jina_embeddings_v4.py CHANGED
@@ -127,11 +127,13 @@ class JinaEmbeddingsV4ModelOutput:
127
  vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
128
  single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
129
  multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
 
130
  """
131
 
132
  vlm_last_hidden_states: Optional[torch.Tensor] = None
133
  single_vec_emb: Optional[torch.Tensor] = None
134
  multi_vec_emb: Optional[torch.Tensor] = None
 
135
 
136
 
137
  class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
@@ -312,6 +314,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
312
  ),
313
  single_vec_emb=single_vec_emb,
314
  multi_vec_emb=multi_vec_emb,
 
315
  )
316
 
317
  def _process_batches(
@@ -340,12 +343,20 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
340
  device_type=torch.device(self.device).type, dtype=torch.bfloat16
341
  ):
342
  embeddings = self(**batch, task_label=task_label)
 
343
  if not return_multivector:
344
  embeddings = embeddings.single_vec_emb
345
  if truncate_dim is not None:
346
  embeddings = embeddings[:, :truncate_dim]
347
  else:
348
  embeddings = embeddings.multi_vec_emb
 
 
 
 
 
 
 
349
  results.append(
350
  embeddings.cpu()
351
  if return_numpy
 
127
  vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
128
  single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
129
  multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
130
+ attention_mask (torch.Tensor, optional): Attention mask.
131
  """
132
 
133
  vlm_last_hidden_states: Optional[torch.Tensor] = None
134
  single_vec_emb: Optional[torch.Tensor] = None
135
  multi_vec_emb: Optional[torch.Tensor] = None
136
+ attention_mask: Optional[torch.Tensor] = None
137
 
138
 
139
  class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
314
  ),
315
  single_vec_emb=single_vec_emb,
316
  multi_vec_emb=multi_vec_emb,
317
+ attention_mask=attention_mask,
318
  )
319
 
320
  def _process_batches(
 
343
  device_type=torch.device(self.device).type, dtype=torch.bfloat16
344
  ):
345
  embeddings = self(**batch, task_label=task_label)
346
+ attention_mask = embeddings.attention_mask
347
  if not return_multivector:
348
  embeddings = embeddings.single_vec_emb
349
  if truncate_dim is not None:
350
  embeddings = embeddings[:, :truncate_dim]
351
  else:
352
  embeddings = embeddings.multi_vec_emb
353
+ if return_multivector:
354
+ # Get valid token mask from attention_mask
355
+ valid_tokens = attention_mask.bool()
356
+ # Remove padding by selecting only valid tokens for each sequence
357
+ embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
358
+ # Stack back into tensor with variable sequence lengths
359
+ embeddings = torch.stack(embeddings)
360
  results.append(
361
  embeddings.cpu()
362
  if return_numpy