fix-image-pooling (#9)
- fix: image pooling (725b8ba6ba8cff17579843ca46e5eb21f7d5ea37)
- chore: remove prints (660fe4c4d743be00c7bdcb17c740414c21c53374)

modeling_jina_embeddings_v4.py CHANGED (+10 -12)
@@ -216,22 +216,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Project the hidden states to single-vector embeddings.
         """
         if self._input_has_image(input_ids[0]):  # got document image
-            …
-                .unsqueeze(0)
-            )
+            img_start_positions = torch.where(input_ids == self.config.vision_start_token_id)[1]
+            img_end_positions = torch.where(input_ids == self.config.vision_end_token_id)[1]
+
+            batch_size, seq_len = input_ids.shape
+            position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
+            image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (position_indices <= img_end_positions.unsqueeze(1))
+
+            masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
+            pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)

         else:  # got query text
             pooled_output = torch.sum(
                 hidden_states * attention_mask.unsqueeze(-1), dim=1
             ) / torch.sum(attention_mask, dim=1, keepdim=True)
+
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)

@@ -317,7 +316,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             embeddings = embeddings[:, :truncate_dim]
         else:
             embeddings = embeddings.multi_vec_emb
-
         results.append(
             embeddings.cpu()
             if return_numpy
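The surviving fragments of the removed branch (a trailing .unsqueeze(0) and a closing parenthesis) suggest it pooled a single image span and expanded the result back to batch shape, whereas the patched branch builds a per-sample mask so every batch element is averaged over its own image region. Below is a minimal, self-contained sketch of the new pooling logic; the token ids and tensor sizes are invented, and it assumes exactly one vision_start/vision_end pair per sequence, which is also what the torch.where(...)[1] indexing in the patch relies on:

import torch

# Made-up stand-ins for config.vision_start_token_id / config.vision_end_token_id.
VISION_START, VISION_END = 151652, 151653

def pool_image_tokens(hidden_states: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # torch.where returns (row_indices, col_indices); [1] keeps the sequence
    # position of each marker, one entry per batch row when every row has one image.
    img_start_positions = torch.where(input_ids == VISION_START)[1]  # (batch,)
    img_end_positions = torch.where(input_ids == VISION_END)[1]      # (batch,)

    batch_size, seq_len = input_ids.shape
    position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
    # True exactly on each sample's own image-token span.
    image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (
        position_indices <= img_end_positions.unsqueeze(1)
    )

    # Masked mean over the image span, computed per sample.
    masked = hidden_states * image_mask.unsqueeze(-1)
    return masked.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)

# Two sequences whose image spans sit at different offsets.
input_ids = torch.tensor([
    [1, VISION_START, 7, 7, VISION_END, 2, 0, 0],
    [1, 5, 5, VISION_START, 7, 7, 7, VISION_END],
])
hidden_states = torch.randn(2, 8, 4)
print(pool_image_tokens(hidden_states, input_ids).shape)  # torch.Size([2, 4])

If a row contained zero or several image spans, the start and end index tensors would no longer line up one entry per sample, so this masking presumes the single-image document case.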
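For comparison, the query-text branch kept as context in the hunk applies the same masked-mean idea over the attention mask, and both branches then pass through single_vector_projector followed by L2 normalization, so downstream cosine similarity reduces to a dot product. A runnable sketch of that tail end, with invented dimensions (the real sizes come from the model config):

import torch

hidden_size, out_dim = 4, 3                      # hypothetical sizes
single_vector_projector = torch.nn.Linear(hidden_size, out_dim)

hidden_states = torch.randn(2, 8, hidden_size)   # (batch, seq, hidden)
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 0, 0, 0],                    # 5 real tokens, 3 padding
    [1, 1, 1, 1, 1, 1, 1, 1],
])

# Masked mean: zero out padding, divide by the count of real tokens per row.
pooled_output = torch.sum(
    hidden_states * attention_mask.unsqueeze(-1), dim=1
) / torch.sum(attention_mask, dim=1, keepdim=True)

single_vec_emb = single_vector_projector(pooled_output)
emb = torch.nn.functional.normalize(single_vec_emb, dim=-1)
print(emb.norm(dim=-1))  # roughly 1.0 per row; cosine similarity is then emb @ emb.T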