fix: remove the padding tokens when a list of multivectors is returned
modeling_jina_embeddings_v4.py
@@ -127,11 +127,13 @@ class JinaEmbeddingsV4ModelOutput:
         vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
         single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
         multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
+        attention_mask (torch.Tensor, optional): Attention mask.
     """
 
     vlm_last_hidden_states: Optional[torch.Tensor] = None
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None
+    attention_mask: Optional[torch.Tensor] = None
 
 
 class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
@@ -312,6 +314,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             ),
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
+            attention_mask=attention_mask,
         )
 
     def _process_batches(
@@ -340,12 +343,20 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
            ):
                 embeddings = self(**batch, task_label=task_label)
+                attention_mask = embeddings.attention_mask
                 if not return_multivector:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
                 else:
                     embeddings = embeddings.multi_vec_emb
+                if return_multivector:
+                    # Get valid token mask from attention_mask
+                    valid_tokens = attention_mask.bool()
+                    # Remove padding by selecting only valid tokens for each sequence
+                    embeddings = [emb[mask] for emb, mask in zip(embeddings, valid_tokens)]
+                    # Stack back into tensor with variable sequence lengths
+                    embeddings = torch.stack(embeddings)
                 results.append(
                     embeddings.cpu()
                     if return_numpy