jupyterjazz committed (verified)
Commit 998398d · Parent(s): f9712eb

refine-code (#3)

- refactor: processor, config, model (0820da1aa617e7b945fe9a49a94a0b2f1d34dea0)

config.json CHANGED
@@ -1,11 +1,11 @@
 {
   "_name_or_path": "jinaai/jina-embeddings-v4",
   "architectures": [
-    "ColQwen25Duo"
+    "JinaEmbeddingsV4Model"
   ],
   "auto_map": {
-    "AutoConfig": "configuration_colqwen_duo.ColQwen25DuoConfig",
-    "AutoModel": "modeling_colqwen_duo.JinaEmbeddingsV4Model"
+    "AutoConfig": "configuration_jina_embeddings_v4.JinaEmbeddingsV4Config",
+    "AutoModel": "modeling_jina_embeddings_v4.JinaEmbeddingsV4Model"
   },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
configuration_colqwen_duo.py → configuration_jina_embeddings_v4.py RENAMED
@@ -2,9 +2,9 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig

 from typing import Optional

-class ColQwen25DuoConfig(Qwen2_5_VLConfig):
+class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
     """
-    Configuration for the ColQwenDuo model.
+    Configuration for the JinaEmbeddingsV4 model.
     """

     def __init__(
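
The hunk above cuts off at `def __init__(`, so the constructor body is not shown. Because the modeling code in this commit reads `single_vector_projector_dim` and `multi_vector_projector_dim` from the config, a plausible shape for the renamed class is sketched below; the default values and the full argument list are assumptions, not taken from the diff.

```python
# Hedged sketch of the renamed config's constructor; attribute names come from
# the modeling code in this commit, defaults and the complete signature are assumed.
from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig


class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
    """Configuration for the JinaEmbeddingsV4 model (sketch)."""

    def __init__(
        self,
        single_vector_projector_dim: int = 2048,  # assumed default
        multi_vector_projector_dim: int = 128,    # assumed default
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.single_vector_projector_dim = single_vector_projector_dim
        self.multi_vector_projector_dim = multi_vector_projector_dim
```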
modeling_colqwen_duo.py → modeling_jina_embeddings_v4.py RENAMED
@@ -2,11 +2,9 @@ import os
 import math
 import numpy as np

-from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
-from typing_extensions import Unpack
-from peft import LoraConfig, PeftModel
+from peft import PeftModel
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
@@ -17,170 +15,24 @@ from tqdm import tqdm
 from enum import Enum
 from peft.utils.hotswap import hotswap_adapter

-from transformers import BatchEncoding, BatchFeature
-
-from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers import BatchFeature

 from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration

-from transformers.processing_utils import (
-    AllKwargsForChatTemplate,
-    ImageInput,
-    PreTokenizedInput,
-    TextInput,
-    VideoInput,
-)
-
 from huggingface_hub import snapshot_download

-from .configuration_colqwen_duo import ColQwen25DuoConfig
-
-
-def get_torch_device() -> str:
-    """
-    Returns the device (string) to be used by PyTorch.
-
-    `device` arg defaults to "auto" which will use:
-        - "cuda:0" if available
-        - else "mps" if available
-        - else "cpu".
-    """
-
-    if torch.cuda.is_available():
-        device = "cuda:0"
-    elif torch.backends.mps.is_available():  # for Apple Silicon
-        device = "mps"
-    else:
-        device = "cpu"
-
-    return device
+from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config


 class PromptType(str, Enum):
     query = "query"
     passage = "passage"

-
-class BaseVisualRetrieverProcessor(ABC):
-    """
-    Base class for visual retriever processors.
-    """
-
-    @abstractmethod
-    def process_images(
-        self,
-        images: List[Image.Image],
-    ) -> Union[BatchFeature, BatchEncoding]:
-        pass
-
-    @abstractmethod
-    def process_texts(
-        self,
-        texts: List[str],
-        max_length: int = 50,
-        suffix: Optional[str] = None,
-        prefix: Optional[str] = None,
-    ) -> Union[BatchFeature, BatchEncoding]:
-        pass
-
-    @abstractmethod
-    def score(
-        self,
-        qs: List[torch.Tensor],
-        ps: List[torch.Tensor],
-        device: Optional[Union[str, torch.device]] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        pass
-
-    @staticmethod
-    def score_single_vector(
-        qs: List[torch.Tensor],
-        ps: List[torch.Tensor],
-        device: Optional[Union[str, torch.device]] = None,
-    ) -> torch.Tensor:
-        """
-        Compute the dot product score for the given single-vector query and passage embeddings.
-        """
-        device = device or get_torch_device()
-
-        if len(qs) == 0:
-            raise ValueError("No queries provided")
-        if len(ps) == 0:
-            raise ValueError("No passages provided")
-
-        qs_stacked = torch.stack(qs).to(device)
-        ps_stacked = torch.stack(ps).to(device)
-
-        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
-        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
-
-        scores = scores.to(torch.float32)
-        return scores
-
-    @staticmethod
-    def score_multi_vector(
-        qs: List[torch.Tensor],
-        ps: List[torch.Tensor],
-        batch_size: int = 128,
-        device: Optional[Union[str, torch.device]] = None,
-    ) -> torch.Tensor:
-        """
-        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
-        """
-        device = device or get_torch_device()
-
-        if len(qs) == 0:
-            raise ValueError("No queries provided")
-        if len(ps) == 0:
-            raise ValueError("No passages provided")
-
-        scores_list: List[torch.Tensor] = []
-
-        for i in range(0, len(qs), batch_size):
-            scores_batch = []
-            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
-                device
-            )
-            for j in range(0, len(ps), batch_size):
-                ps_batch = torch.nn.utils.rnn.pad_sequence(
-                    ps[j : j + batch_size], batch_first=True, padding_value=0
-                ).to(device)
-                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
-            scores_batch = torch.cat(scores_batch, dim=1).cpu()
-            scores_list.append(scores_batch)
-
-        scores = torch.cat(scores_list, dim=0)
-        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
-
-        scores = scores.to(torch.float32)
-        return scores
-
-
-class QwenVLProcessor(ABC):
-
-    def __call__(
-        self,
-        images: Optional[ImageInput] = None,
-        text: Optional[Union[TextInput, PreTokenizedInput, List[PreTokenizedInput]]] = None,
-        videos: Optional[VideoInput] = None,
-        **kwargs,
-    ) -> BatchFeature:
-        return super().__call__(images=images, text=text, videos=videos, **kwargs)  # type: ignore
-
-    def apply_chat_template(
-        self,
-        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
-        chat_template: Optional[str] = None,
-        **kwargs: Unpack[AllKwargsForChatTemplate],
-    ) -> str:
-        return super().apply_chat_template(conversation=conversation, chat_template=chat_template, **kwargs)  # type: ignore
-
-
-class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor):
-
-    assistant_prefix_len: int = 58  # length of prefix created by
-    # super().apply_chat_template(conversation=conversation, chat_template=chat_template, **kwargs)
+class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
+        self.assistant_prefix_len = 58
+        self.text_max_length = 8192

     @staticmethod
     def round_by_factor(number: float, factor: int) -> int:
@@ -236,12 +88,12 @@ class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor
     def process_texts(
         self,
         texts: List[str],
-        max_length: int = 8192,
-        suffix: Optional[str] = None,
+        max_length: Optional[int] = None,
         prefix: Optional[str] = None,
         padding: Optional[str] = None,
     ) -> BatchFeature:

+        max_length = self.text_max_length if max_length is None else min(max_length, self.text_max_length)
         padded_texts: List[str] = []

         for text in texts:
@@ -260,42 +112,8 @@ class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor
         return text_batch


-class ColQwenDuoProcessorBase(QwenVLEmbeddingProcessorBase):
-    """
-    Processor for ColQwenDuo. Mirrors the `ColQwen2Processor` class.
-    """
-
-    def score(
-        self,
-        qs: List[torch.Tensor],
-        ps: List[torch.Tensor],
-        vector_type: str,
-        device: Optional[Union[str, torch.device]] = None,
-        truncate: Optional[int] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        """
-        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
-        """
-        if truncate:
-            qs = [q[..., :truncate] for q in qs]
-            ps = [p[..., :truncate] for p in ps]
-
-        if vector_type == "single_vector":
-            return self.score_single_vector(qs, ps, device=device)
-        elif vector_type == "multi_vector":
-            return self.score_multi_vector(qs, ps, device=device, **kwargs)
-        else:
-            raise ValueError('vector_type must be one of the following: [`single_vector`, `multi_vector`]')
-
-
-class ColQwen25DuoProcessor(ColQwenDuoProcessorBase, Qwen2_5_VLProcessor):
-    def __init__(self, *args, **kwargs) -> None:
-        Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
-
-
 @dataclass
-class HybridModelOutput:
+class JinaEmbeddingsV4ModelOutput:
     """
     Base class for the Hybrid Model outputs.
     Args:
@@ -308,149 +126,20 @@ class HybridModelOutput:
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None

-class EncodeMixin:
-    """
-    Interface to encode data for MTEB and ViDoRe evaluations.
-    """
-
-    def _process_batches(
-        self,
-        data: List[Union[str, Image.Image]],
-        processor_fn: Callable,
-        desc: str,
-        vector_type: Optional[str] = None,
-        return_numpy: bool = False,
-        **kwargs,
-    ) -> Union[np.ndarray, List[torch.Tensor]]:
-        dataloader = DataLoader(
-            dataset=data,
-            batch_size=kwargs.get("batch_size", 32),
-            shuffle=False,
-            collate_fn=processor_fn,
-        )
-        results = []
-        self.eval()
-        for batch in tqdm(dataloader, desc=desc):
-            with torch.no_grad():
-                batch = {k: v.to(self.device) for k, v in batch.items()}
-                with torch.autocast(device_type=torch.device(self.device).type):
-                    embeddings = self(**batch)
-                    if isinstance(embeddings, HybridModelOutput) and (vector_type == "single_vector"):
-                        embeddings = embeddings.single_vec_emb
-                    elif isinstance(embeddings, HybridModelOutput) and (vector_type == "multi_vector"):
-                        embeddings = embeddings.multi_vec_emb
-                    elif not vector_type and isinstance(embeddings, HybridModelOutput):
-                        embeddings = embeddings.single_vec_emb  # get single-vectors for text2text tasks by default
-                    results.append(embeddings.cpu() if return_numpy else list(torch.unbind(embeddings)))
-        if return_numpy:
-            return np.concatenate([result.numpy() for result in results], axis=0)
-        return [item for sublist in results for item in sublist]
-
-    def encode(
-        self,
-        sentences: List[str],
-        max_length: int = 8192,
-        batch_size: int = 8,
-        prefixes: Optional[List[str]] = None,
-        desc: Optional[str] = None,
-        vector_type: Optional[str] = None,
-        padding: Optional[str] = None,
-        prompt_type: Optional[PromptType] = None,
-        **kwargs,
-    ) -> np.ndarray:
-        prefix = None
-        if isinstance(prefixes, list) and len(prefixes) > 0:
-            if prompt_type:
-                desc = f"MTEB: Encode {prompt_type.value}..."
-                prefix = prefixes[0] if prompt_type.value == "query" else prefixes[1]
-            else:
-                prefix = prefixes[0]
-        processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix=prefix, padding=padding)
-        desc = desc or "MTEB: Encode texts..."
-        return self._process_batches(
-            data=sentences,
-            processor_fn=processor_fn,
-            desc=desc,
-            vector_type=vector_type,
-            batch_size=batch_size,
-            **kwargs,
-        )
-
-    def encode_texts(
-        self,
-        queries: List[str],
-        max_length: int = 8192,
-        batch_size: int = 8,
-        vector_type: Optional[str] = None,
-        desc: Optional[str] = None,
-        **kwargs,
-    ) -> List[torch.Tensor]:
-        processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix="Query")
-        return self._process_batches(
-            data=queries,
-            processor_fn=processor_fn,
-            desc=desc or "Encode queries...",
-            vector_type=vector_type,
-            batch_size=batch_size,
-            **kwargs,
-        )
-
-    def encode_images(
-        self,
-        documents: List[Image.Image],
-        batch_size: int = 8,
-        vector_type: Optional[str] = None,
-        desc: Optional[str] = None,
-        **kwargs,
-    ) -> List[torch.Tensor]:
-        return self._process_batches(
-            data=documents,
-            processor_fn=self.processor.process_images,
-            desc=desc or "Encode documents...",
-            vector_type=vector_type,
-            batch_size=batch_size,
-            **kwargs,
-        )
-
-class QwenVLModel(ABC):

-    def get_rope_index(
-        self,
-        input_ids: torch.LongTensor,
-        image_grid_thw: Union[torch.LongTensor, None],
-        attention_mask: torch.Tensor,
-    ) -> tuple[torch.LongTensor, torch.Tensor]:
-        return super().get_rope_index(  # type: ignore
-            input_ids=input_ids,
-            image_grid_thw=image_grid_thw,
-            attention_mask=attention_mask,
-        )
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.Tensor,
-        position_ids: torch.LongTensor,
-        rope_deltas: torch.Tensor,
-        output_hidden_states: bool,
-        use_cache: bool,
-        **kwargs,
-    ) -> Qwen2VLCausalLMOutputWithPast:
-        return super().forward(  # type: ignore
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            rope_deltas=rope_deltas,
-            output_hidden_states=output_hidden_states,
-            use_cache=use_cache,
-            **kwargs,
-        )
-
-
-class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel):
+class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
+    config_class = JinaEmbeddingsV4Config
     main_input_name: ClassVar[str] = "doc_input_ids"

-    def get_vlm_last_hidden_states(
+    def __init__(self, config: JinaEmbeddingsV4Config):
+        Qwen2_5_VLForConditionalGeneration.__init__(self, config)
+        self._init_projection_layers(config)
+        self.post_init()
+        self.processor = JinaEmbeddingsV4Processor.from_pretrained(self.name_or_path, trust_remote_code=True)
+        self.single_vector_projector_dim = config.single_vector_projector_dim
+        self.multi_vector_projector_dim = config.multi_vector_projector_dim
+
+    def get_last_hidden_states(
         self,
         input_ids: torch.LongTensor,
         attention_mask: torch.Tensor,
@@ -460,19 +149,20 @@ class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel):
         offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
         kwargs["pixel_values"] = torch.cat([pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0)

-        position_ids, rope_deltas = self.get_rope_index(
+        position_ids, rope_deltas = super().get_rope_index(  # type: ignore
             input_ids=input_ids,
             image_grid_thw=kwargs.get("image_grid_thw", None),
             attention_mask=attention_mask,
         )

+        kwargs['output_hidden_states'] = True
+
         outputs = super().forward(
             input_ids,
             attention_mask,
             **kwargs,
             position_ids=position_ids,
             rope_deltas=rope_deltas,
-            output_hidden_states=True,
             use_cache=False,
         )

@@ -482,35 +172,6 @@ class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel):

         return hidden_states[-1]

-
-class AbstractHybridModel(ABC):
-    """
-    Abstract class for a hybrid model (single-vector and multi-vector embeddings).
-    """
-
-    @property
-    def single_vector_projector_dim(self) -> int:
-        return self.config.single_vector_projector_dim
-
-    @property
-    def multi_vector_projector_dim(self) -> int:
-        return self.config.multi_vector_projector_dim
-
-    @abstractmethod
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.Tensor,
-        output_vlm_last_hidden_states: bool = False,
-        *args,
-        **kwargs,
-    ) -> HybridModelOutput:
-        """
-        Forward pass through the model. Returns both single-vector and multi-vector embeddings.
-        Must be implemented by subclasses.
-        """
-        pass
-
     def _init_projection_layers(self, config) -> None:
         """
         Initializes projection layers.
@@ -528,14 +189,6 @@ class AbstractHybridModel(ABC):
             out_features=self.config.multi_vector_projector_dim,
         )

-    @staticmethod
-    def _delete_redundant_forward_kwargs(kwargs: Dict[str, Any]) -> None:
-        """
-        Delete redundant kwargs before passing them to the forward method. In-place operation.
-        """
-        for key in ["input_ids", "attention_mask", "output_hidden_states"]:
-            kwargs.pop(key, None)
-
     def project_to_single_vector_embeddings(
         self,
         hidden_states: torch.Tensor,
@@ -545,48 +198,15 @@ class AbstractHybridModel(ABC):
         """
         Project the hidden states to single-vector embeddings.
         """
+        if self._input_has_image(input_ids[0]):  # got document image
+            img_start_pos = torch.where(input_ids[0] == self.config.vision_start_token_id)[0][0]
+            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[0][0]
+            pooled_output = hidden_states[0][img_start_pos:img_end_pos + 1].mean(dim=0).unsqueeze(0)

-        pooling_method = self.config.single_vector_pool_strategy
-
-        if pooling_method == "mean" and input_ids is None:
-            print("Warning: `input_ids` is None. Using `legacy-mean` pooling strategy instead.")
-            pooling_method = "legacy-mean"
-
-        if pooling_method == "last-token":
-            pooled_output = hidden_states[:, -1, :]
-        elif pooling_method == "mean":
-            if self._input_has_image(input_ids[0]):  # got document image(s)
-                # getting start and end positions of image tokens; torch.where returns
-                # (1) a list of indices of input sequences
-                # (shape corresponds to the total number of images in the batch)
-                # (2) a list of positions of image tokens in the input sequence
-                # (shape corresponds to the total number of images in the batch)
-                input_seq_idx, img_start_pos = torch.where(
-                    input_ids == self.config.vision_start_token_id
-                )  # (total number of images), (total number of images)
-                _, img_end_pos = torch.where(
-                    input_ids == self.config.vision_end_token_id
-                )  # (total number of images), (total number of images)
-                means = []
-                for i in range(input_seq_idx.shape[0]):
-                    vector_pos = input_seq_idx[i]
-                    start = img_start_pos[i]
-                    end = img_end_pos[i]
-                    mean_value = hidden_states[vector_pos][start : end + 1].mean(dim=0)
-                    means.append(mean_value)
-                pooled_output = torch.stack(means)
-
-            else:  # got query text
-                pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
-                    attention_mask, dim=1, keepdim=True
-                )
-
-        elif pooling_method == "legacy-mean":
+        else:  # got query text
             pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
                 attention_mask, dim=1, keepdim=True
             )
-        else:
-            raise ValueError(f"Invalid pooling strategy: {pooling_method}")
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)

@@ -605,30 +225,25 @@ class AbstractHybridModel(ABC):
     def _input_has_image(self, input_ids):
         return self.config.vision_start_token_id in input_ids

-class ColQwenDuoBase(AbstractHybridModel, QwenVLEmbeddingBase):
-
     def forward(
         self,
         input_ids: torch.LongTensor,
         attention_mask: torch.Tensor,
         output_vlm_last_hidden_states: bool = False,
         **kwargs,
-    ) -> HybridModelOutput:
+    ) -> JinaEmbeddingsV4ModelOutput:
         """
-        Forward pass through ColQwenDuo. Returns both single-vector and multi-vector embeddings.
+        Forward pass through QwenVL25Embeddings. Returns both single-vector and multi-vector embeddings.
         Args:
             input_ids (torch.LongTensor): The input tokens tensor.
             attention_mask (torch.LongTensor): The attention mask tensor.
         Returns:
-            HybridModelOutput:
+            JinaEmbeddingsV4ModelOutput:
                 single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
                 multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
         """
-        # Delete redundant kwargs
-        self._delete_redundant_forward_kwargs(kwargs)
-
         # Forward pass through the VLM
-        hidden_states = self.get_vlm_last_hidden_states(
+        hidden_states = self.get_last_hidden_states(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )  # (batch_size, seq_length, hidden_size)

@@ -636,16 +251,85 @@ class ColQwenDuoBase(AbstractHybridModel, QwenVLEmbeddingBase):
         single_vec_emb = self.project_to_single_vector_embeddings(hidden_states, attention_mask, input_ids=input_ids)
         multi_vec_emb = self.project_to_multi_vector_embeddings(hidden_states, attention_mask)

-        return HybridModelOutput(
+        return JinaEmbeddingsV4ModelOutput(
             vlm_last_hidden_states=hidden_states if output_vlm_last_hidden_states else None,
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
         )
+
+    def _process_batches(
+        self,
+        data: List[Union[str, Image.Image]],
+        processor_fn: Callable,
+        desc: str,
+        vector_type: Optional[str] = None,
+        return_numpy: bool = False,
+        **kwargs,
+    ) -> Union[np.ndarray, List[torch.Tensor]]:
+        dataloader = DataLoader(
+            dataset=data,
+            batch_size=kwargs.get("batch_size", 32),
+            shuffle=False,
+            collate_fn=processor_fn,
+        )
+        vector_type = vector_type or "single_vector"
+        results = []
+        self.eval()
+        for batch in tqdm(dataloader, desc=desc):
+            with torch.no_grad():
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                with torch.autocast(device_type=torch.device(self.device).type):
+                    embeddings = self(**batch)
+                    if vector_type == "single_vector":
+                        embeddings = embeddings.single_vec_emb
+                    else:
+                        embeddings = embeddings.multi_vec_emb
+                    results.append(embeddings.cpu() if return_numpy else list(torch.unbind(embeddings)))
+        if return_numpy:
+            return np.concatenate([result.numpy() for result in results], axis=0)
+        return [item for sublist in results for item in sublist]
+
+    def encode_texts(
+        self,
+        queries: List[str],
+        max_length: int = 8192,
+        batch_size: int = 8,
+        vector_type: Optional[str] = None,
+        desc: Optional[str] = None,
+        **kwargs,
+    ) -> List[torch.Tensor]:
+        processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix="Query")
+        return self._process_batches(
+            data=queries,
+            processor_fn=processor_fn,
+            desc=desc or "Encode queries...",
+            vector_type=vector_type,
+            batch_size=batch_size,
+            **kwargs,
+        )
+
+    def encode_images(
+        self,
+        documents: List[Image.Image],
+        batch_size: int = 8,
+        vector_type: Optional[str] = None,
+        desc: Optional[str] = None,
+        **kwargs,
+    ) -> List[torch.Tensor]:
+        return self._process_batches(
+            data=documents,
+            processor_fn=self.processor.process_images,
+            desc=desc or "Encode documents...",
+            vector_type=vector_type,
+            batch_size=batch_size,
+            **kwargs,
+        )
+


 class JinaEmbeddingsV4Model:
     """
-    Wrapper class for ColQwen25Duo that handles the loading of models and adapters.
+    Wrapper class for QwenVL25Embeddings that handles the loading of models and adapters.
     """

     def __init__(self, model, adapter_dir):
@@ -664,7 +348,7 @@ class JinaEmbeddingsV4Model:

         task = kwargs.pop('task', 'retrieval')

-        model = ColQwen25Duo.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        model = QwenVL25Embeddings.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

         if os.path.isdir(model.name_or_path):
             adapter_dir = os.path.join(model.name_or_path, 'adapters')
@@ -705,13 +389,4 @@ class JinaEmbeddingsV4Model:
         Forward the call to the underlying model's forward method.
         """
         return self.model(*args, **kwargs)
-
-
-class ColQwen25Duo(ColQwenDuoBase, Qwen2_5_VLForConditionalGeneration):
-    config_class = ColQwen25DuoConfig
-    def __init__(self, config: ColQwen25DuoConfig):
-        Qwen2_5_VLForConditionalGeneration.__init__(self, config)
-        self._init_projection_layers(config)
-        self.post_init()
-        self.processor = ColQwen25DuoProcessor.from_pretrained(self.name_or_path, trust_remote_code=True)

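The renamed modeling file keeps the public entry points small: the `AutoModel` mapping resolves to the `JinaEmbeddingsV4Model` wrapper, whose `from_pretrained` builds a `QwenVL25Embeddings` instance exposing `encode_texts` and `encode_images`. Below is a hedged usage sketch based only on the signatures visible in this diff; the dtype choice, the example file name, and the assumption that the wrapper forwards these methods to the wrapped model are not confirmed by the commit.

```python
# Hedged usage sketch; signatures taken from this diff, everything else assumed.
import torch
from PIL import Image
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # assumption: any dtype supported by the base model
)

# Single-vector text embeddings (vector_type defaults to "single_vector").
# Assumes the wrapper forwards attribute access to the wrapped QwenVL25Embeddings;
# if it does not, call the same methods on the underlying model instead.
query_embs = model.encode_texts(["what is colbert-style late interaction?"], batch_size=8)

# Multi-vector embeddings for a document page image ("page.png" is hypothetical).
page = Image.open("page.png")
doc_embs = model.encode_images([page], vector_type="multi_vector")

print(query_embs[0].shape, doc_embs[0].shape)
```
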
preprocessor_config.json CHANGED
@@ -18,7 +18,7 @@
   "merge_size": 2,
   "min_pixels": 3136,
   "patch_size": 14,
-  "processor_class": "ColQwen25DuoProcessor",
+  "processor_class": "JinaEmbeddingsV4Processor",
   "resample": 3,
   "rescale_factor": 0.00392156862745098,
   "size": {
@@ -27,6 +27,6 @@
   },
   "temporal_patch_size": 2,
   "auto_map": {
-    "AutoProcessor": "modeling_colqwen_duo.ColQwen25DuoProcessor"
+    "AutoProcessor": "modeling_jina_embeddings_v4.JinaEmbeddingsV4Processor"
   }
 }
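
The processor follows the same pattern through the `AutoProcessor` mapping above. A minimal, hedged loading sketch follows; the "Query" prefix mirrors `encode_texts` in the modeling file, and everything beyond the `trust_remote_code` flag is an assumption rather than part of this commit.

```python
# Hedged sketch: load the renamed processor via the AutoProcessor mapping above
# and prepare a small text batch with process_texts (signature from this diff).
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)
batch = processor.process_texts(["late interaction retrieval"], prefix="Query")
print(list(batch.keys()))  # e.g. input_ids, attention_mask
```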