jupyterjazz committed
Commit 7bf3b86 · verified · 1 Parent(s): 70044fb

fix-image-pooling (#9)


- fix: image pooling (725b8ba6ba8cff17579843ca46e5eb21f7d5ea37)
- chore: remove prints (660fe4c4d743be00c7bdcb17c740414c21c53374)

Files changed (1)
  modeling_jina_embeddings_v4.py +10 -12
modeling_jina_embeddings_v4.py CHANGED

@@ -216,22 +216,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Project the hidden states to single-vector embeddings.
         """
         if self._input_has_image(input_ids[0]):  # got document image
-            img_start_pos = torch.where(
-                input_ids[0] == self.config.vision_start_token_id
-            )[0][0]
-            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[
-                0
-            ][0]
-            pooled_output = (
-                hidden_states[0][img_start_pos : img_end_pos + 1]
-                .mean(dim=0)
-                .unsqueeze(0)
-            )
+            img_start_positions = torch.where(input_ids == self.config.vision_start_token_id)[1]
+            img_end_positions = torch.where(input_ids == self.config.vision_end_token_id)[1]
+
+            batch_size, seq_len = input_ids.shape
+            position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
+            image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (position_indices <= img_end_positions.unsqueeze(1))
+
+            masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
+            pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)
 
         else:  # got query text
             pooled_output = torch.sum(
                 hidden_states * attention_mask.unsqueeze(-1), dim=1
             ) / torch.sum(attention_mask, dim=1, keepdim=True)
+
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)
 
@@ -317,7 +316,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             embeddings = embeddings[:, :truncate_dim]
         else:
             embeddings = embeddings.multi_vec_emb
-
         results.append(
             embeddings.cpu()
             if return_numpy
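
Note on the fix: the old pooling indexed only the first batch element (`input_ids[0]`, `hidden_states[0]`), so a batch of document images was pooled over the first example's image span alone. The new code builds a per-row boolean mask over the `vision_start`..`vision_end` token range and takes a masked mean across the whole batch. Below is a minimal, self-contained sketch of that masked pooling; the token ids, shapes, and tensors are placeholders for illustration, not values taken from this repository.

```python
import torch

# Placeholder vision marker ids and toy shapes -- illustration only.
VISION_START_ID, VISION_END_ID = 151652, 151653
batch_size, seq_len, hidden_dim = 2, 8, 4

# Dummy batch: each row contains one vision_start ... vision_end span.
input_ids = torch.full((batch_size, seq_len), 100)
input_ids[0, 1], input_ids[0, 5] = VISION_START_ID, VISION_END_ID
input_ids[1, 2], input_ids[1, 6] = VISION_START_ID, VISION_END_ID
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)

# torch.where on a 2-D boolean tensor returns (row_indices, col_indices);
# index [1] keeps the column position of each marker token, one per row.
img_start_positions = torch.where(input_ids == VISION_START_ID)[1]
img_end_positions = torch.where(input_ids == VISION_END_ID)[1]

# Boolean mask covering start..end inclusive for every row at once,
# replacing the old code's slice of hidden_states[0] only.
position_indices = torch.arange(seq_len).expand(batch_size, -1)
image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (
    position_indices <= img_end_positions.unsqueeze(1)
)

# Masked mean over the sequence dimension: zero out non-image positions,
# sum, and divide by the number of image-span tokens per row.
pooled_output = (hidden_states * image_mask.unsqueeze(-1)).sum(dim=1) / image_mask.sum(
    dim=1, keepdim=True
)
print(pooled_output.shape)  # torch.Size([2, 4])
```

Like the patched code, this sketch assumes exactly one image span per row; with multiple spans per sequence, the start and end position vectors returned by `torch.where` would no longer line up one-to-one with batch rows.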