refine-code (#3)
refactor: processor, config, model (0820da1aa617e7b945fe9a49a94a0b2f1d34dea0)
config.json
CHANGED
@@ -1,11 +1,11 @@
 {
   "_name_or_path": "jinaai/jina-embeddings-v4",
   "architectures": [
-    "
+    "JinaEmbeddingsV4Model"
   ],
   "auto_map": {
-    "AutoConfig": "
-    "AutoModel": "
+    "AutoConfig": "configuration_jina_embeddings_v4.JinaEmbeddingsV4Config",
+    "AutoModel": "modeling_jina_embeddings_v4.JinaEmbeddingsV4Model"
   },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
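With the auto_map entries above pointing at the bundled configuration_jina_embeddings_v4 and modeling_jina_embeddings_v4 modules, the custom classes should resolve through the transformers Auto factories when remote code is trusted. A minimal loading sketch (not part of this commit; the task keyword is the one popped by the wrapper's from_pretrained further down in this diff):

from transformers import AutoConfig, AutoModel

# trust_remote_code=True is required so transformers imports the classes listed in auto_map.
config = AutoConfig.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
    task="retrieval",  # forwarded to the custom from_pretrained, which defaults to 'retrieval'
)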
configuration_colqwen_duo.py → configuration_jina_embeddings_v4.py
RENAMED
@@ -2,9 +2,9 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig
 
 from typing import Optional
 
-class
+class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
     """
-    Configuration for the
+    Configuration for the JinaEmbeddingsV4 model.
     """
 
     def __init__(
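The __init__ body of the config class is not shown in this hunk. As a rough sketch of the shape such a subclass takes, assuming only the two projector dimensions that the modeling code below reads from the config (field names match the model code; default values here are placeholders, not taken from the repository):

from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig

class JinaEmbeddingsV4ConfigSketch(Qwen2_5_VLConfig):
    """Illustrative subclass only; not the repository's actual implementation."""

    def __init__(
        self,
        single_vector_projector_dim: int = 128,  # placeholder default; read via config.single_vector_projector_dim
        multi_vector_projector_dim: int = 128,   # placeholder default; read via config.multi_vector_projector_dim
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.single_vector_projector_dim = single_vector_projector_dim
        self.multi_vector_projector_dim = multi_vector_projector_dim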
modeling_colqwen_duo.py → modeling_jina_embeddings_v4.py
RENAMED
--- modeling_colqwen_duo.py
@@ -2,11 +2,9 @@ import os
 import math
 import numpy as np
 
-from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
-from
-from peft import LoraConfig, PeftModel
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
@@ -17,170 +15,24 @@ from tqdm import tqdm
 from enum import Enum
 from peft.utils.hotswap import hotswap_adapter
 
-from transformers import
-
-from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
 
 from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
 
-from transformers.processing_utils import (
-    AllKwargsForChatTemplate,
-    ImageInput,
-    PreTokenizedInput,
-    TextInput,
-    VideoInput,
-)
-
 from huggingface_hub import snapshot_download
 
-from .
-
-
-def get_torch_device() -> str:
-    """
-    Returns the device (string) to be used by PyTorch.
-
-    `device` arg defaults to "auto" which will use:
-    - "cuda:0" if available
-    - else "mps" if available
-    - else "cpu".
-    """
-
-    if torch.cuda.is_available():
-        device = "cuda:0"
-    elif torch.backends.mps.is_available():  # for Apple Silicon
-        device = "mps"
-    else:
-        device = "cpu"
-
-    return device
 
 
 class PromptType(str, Enum):
     query = "query"
     passage = "passage"
 
-
-
-
-
-
-
-    @abstractmethod
-    def process_images(
-        self,
-        images: List[Image.Image],
-    ) -> Union[BatchFeature, BatchEncoding]:
-        pass
-
-    @abstractmethod
-    def process_texts(
-        self,
-        texts: List[str],
-        max_length: int = 50,
-        suffix: Optional[str] = None,
-        prefix: Optional[str] = None,
-    ) -> Union[BatchFeature, BatchEncoding]:
-        pass
-
-    @abstractmethod
-    def score(
-        self,
-        qs: List[torch.Tensor],
-        ps: List[torch.Tensor],
-        device: Optional[Union[str, torch.device]] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        pass
-
-    @staticmethod
-    def score_single_vector(
-        qs: List[torch.Tensor],
-        ps: List[torch.Tensor],
-        device: Optional[Union[str, torch.device]] = None,
-    ) -> torch.Tensor:
-        """
-        Compute the dot product score for the given single-vector query and passage embeddings.
-        """
-        device = device or get_torch_device()
-
-        if len(qs) == 0:
-            raise ValueError("No queries provided")
-        if len(ps) == 0:
-            raise ValueError("No passages provided")
-
-        qs_stacked = torch.stack(qs).to(device)
-        ps_stacked = torch.stack(ps).to(device)
-
-        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
-        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
-
-        scores = scores.to(torch.float32)
-        return scores
-
-    @staticmethod
-    def score_multi_vector(
-        qs: List[torch.Tensor],
-        ps: List[torch.Tensor],
-        batch_size: int = 128,
-        device: Optional[Union[str, torch.device]] = None,
-    ) -> torch.Tensor:
-        """
-        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
-        """
-        device = device or get_torch_device()
-
-        if len(qs) == 0:
-            raise ValueError("No queries provided")
-        if len(ps) == 0:
-            raise ValueError("No passages provided")
-
-        scores_list: List[torch.Tensor] = []
-
-        for i in range(0, len(qs), batch_size):
-            scores_batch = []
-            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
-                device
-            )
-            for j in range(0, len(ps), batch_size):
-                ps_batch = torch.nn.utils.rnn.pad_sequence(
-                    ps[j : j + batch_size], batch_first=True, padding_value=0
-                ).to(device)
-                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
-            scores_batch = torch.cat(scores_batch, dim=1).cpu()
-            scores_list.append(scores_batch)
-
-        scores = torch.cat(scores_list, dim=0)
-        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
-
-        scores = scores.to(torch.float32)
-        return scores
-
-
-class QwenVLProcessor(ABC):
-
-    def __call__(
-        self,
-        images: Optional[ImageInput] = None,
-        text: Optional[Union[TextInput, PreTokenizedInput, List[PreTokenizedInput]]] = None,
-        videos: Optional[VideoInput] = None,
-        **kwargs,
-    ) -> BatchFeature:
-        return super().__call__(images=images, text=text, videos=videos, **kwargs)  # type: ignore
-
-    def apply_chat_template(
-        self,
-        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
-        chat_template: Optional[str] = None,
-        **kwargs: Unpack[AllKwargsForChatTemplate],
-    ) -> str:
-        return super().apply_chat_template(conversation=conversation, chat_template=chat_template, **kwargs)  # type: ignore
-
-
-class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor):
-
-    assistant_prefix_len: int = 58  # length of prefix created by
-    # super().apply_chat_template(conversation=conversation, chat_template=chat_template, **kwargs)
 
     @staticmethod
     def round_by_factor(number: float, factor: int) -> int:
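The hunk above also drops the shared scoring helpers: score_single_vector is a plain dot product between pooled embeddings, and score_multi_vector is the ColBERT-style MaxSim, where every query token is matched against its best passage token and the maxima are summed over query tokens. A self-contained check of that einsum, independent of this repository:

import torch

# Two queries with 3 tokens, two passages with 4 tokens, embedding dim 8 (random data).
qs = torch.randn(2, 3, 8)
ps = torch.randn(2, 4, 8)

# "bnd,csd->bcns" gives all query-token/passage-token dot products;
# max over passage tokens (dim 3), then sum over query tokens (dim 2).
scores = torch.einsum("bnd,csd->bcns", qs, ps).max(dim=3)[0].sum(dim=2)
print(scores.shape)  # torch.Size([2, 2]) -> one MaxSim score per (query, passage) pair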
@@ -236,12 +88,12 @@ class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor
     def process_texts(
         self,
         texts: List[str],
-        max_length: int =
-        suffix: Optional[str] = None,
         prefix: Optional[str] = None,
         padding: Optional[str] = None,
     ) -> BatchFeature:
 
         padded_texts: List[str] = []
 
         for text in texts:
@@ -260,42 +112,8 @@ class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor
         return text_batch
 
 
-class ColQwenDuoProcessorBase(QwenVLEmbeddingProcessorBase):
-    """
-    Processor for ColQwenDuo. Mirrors the `ColQwen2Processor` class.
-    """
-
-    def score(
-        self,
-        qs: List[torch.Tensor],
-        ps: List[torch.Tensor],
-        vector_type: str,
-        device: Optional[Union[str, torch.device]] = None,
-        truncate: Optional[int] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        """
-        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
-        """
-        if truncate:
-            qs = [q[..., :truncate] for q in qs]
-            ps = [p[..., :truncate] for p in ps]
-
-        if vector_type == "single_vector":
-            return self.score_single_vector(qs, ps, device=device)
-        elif vector_type == "multi_vector":
-            return self.score_multi_vector(qs, ps, device=device, **kwargs)
-        else:
-            raise ValueError('vector_type must be one of the following: [`single_vector`, `multi_vector`]')
-
-
-class ColQwen25DuoProcessor(ColQwenDuoProcessorBase, Qwen2_5_VLProcessor):
-    def __init__(self, *args, **kwargs) -> None:
-        Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
-
-
 @dataclass
-class
     """
     Base class for the Hybrid Model outputs.
     Args:
@@ -308,149 +126,20 @@ class HybridModelOutput:
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None
 
-class EncodeMixin:
-    """
-    Interface to encode data for MTEB and ViDoRe evaluations.
-    """
-
-    def _process_batches(
-        self,
-        data: List[Union[str, Image.Image]],
-        processor_fn: Callable,
-        desc: str,
-        vector_type: Optional[str] = None,
-        return_numpy: bool = False,
-        **kwargs,
-    ) -> Union[np.ndarray, List[torch.Tensor]]:
-        dataloader = DataLoader(
-            dataset=data,
-            batch_size=kwargs.get("batch_size", 32),
-            shuffle=False,
-            collate_fn=processor_fn,
-        )
-        results = []
-        self.eval()
-        for batch in tqdm(dataloader, desc=desc):
-            with torch.no_grad():
-                batch = {k: v.to(self.device) for k, v in batch.items()}
-                with torch.autocast(device_type=torch.device(self.device).type):
-                    embeddings = self(**batch)
-                    if isinstance(embeddings, HybridModelOutput) and (vector_type == "single_vector"):
-                        embeddings = embeddings.single_vec_emb
-                    elif isinstance(embeddings, HybridModelOutput) and (vector_type == "multi_vector"):
-                        embeddings = embeddings.multi_vec_emb
-                    elif not vector_type and isinstance(embeddings, HybridModelOutput):
-                        embeddings = embeddings.single_vec_emb  # get single-vectors for text2text tasks by default
-                    results.append(embeddings.cpu() if return_numpy else list(torch.unbind(embeddings)))
-        if return_numpy:
-            return np.concatenate([result.numpy() for result in results], axis=0)
-        return [item for sublist in results for item in sublist]
-
-    def encode(
-        self,
-        sentences: List[str],
-        max_length: int = 8192,
-        batch_size: int = 8,
-        prefixes: Optional[List[str]] = None,
-        desc: Optional[str] = None,
-        vector_type: Optional[str] = None,
-        padding: Optional[str] = None,
-        prompt_type: Optional[PromptType] = None,
-        **kwargs,
-    ) -> np.ndarray:
-        prefix = None
-        if isinstance(prefixes, list) and len(prefixes) > 0:
-            if prompt_type:
-                desc = f"MTEB: Encode {prompt_type.value}..."
-                prefix = prefixes[0] if prompt_type.value == "query" else prefixes[1]
-            else:
-                prefix = prefixes[0]
-        processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix=prefix, padding=padding)
-        desc = desc or "MTEB: Encode texts..."
-        return self._process_batches(
-            data=sentences,
-            processor_fn=processor_fn,
-            desc=desc,
-            vector_type=vector_type,
-            batch_size=batch_size,
-            **kwargs,
-        )
-
-    def encode_texts(
-        self,
-        queries: List[str],
-        max_length: int = 8192,
-        batch_size: int = 8,
-        vector_type: Optional[str] = None,
-        desc: Optional[str] = None,
-        **kwargs,
-    ) -> List[torch.Tensor]:
-        processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix="Query")
-        return self._process_batches(
-            data=queries,
-            processor_fn=processor_fn,
-            desc=desc or "Encode queries...",
-            vector_type=vector_type,
-            batch_size=batch_size,
-            **kwargs,
-        )
-
-    def encode_images(
-        self,
-        documents: List[Image.Image],
-        batch_size: int = 8,
-        vector_type: Optional[str] = None,
-        desc: Optional[str] = None,
-        **kwargs,
-    ) -> List[torch.Tensor]:
-        return self._process_batches(
-            data=documents,
-            processor_fn=self.processor.process_images,
-            desc=desc or "Encode documents...",
-            vector_type=vector_type,
-            batch_size=batch_size,
-            **kwargs,
-        )
-
-class QwenVLModel(ABC):
 
-
-
-        input_ids: torch.LongTensor,
-        image_grid_thw: Union[torch.LongTensor, None],
-        attention_mask: torch.Tensor,
-    ) -> tuple[torch.LongTensor, torch.Tensor]:
-        return super().get_rope_index(  # type: ignore
-            input_ids=input_ids,
-            image_grid_thw=image_grid_thw,
-            attention_mask=attention_mask,
-        )
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.Tensor,
-        position_ids: torch.LongTensor,
-        rope_deltas: torch.Tensor,
-        output_hidden_states: bool,
-        use_cache: bool,
-        **kwargs,
-    ) -> Qwen2VLCausalLMOutputWithPast:
-        return super().forward(  # type: ignore
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            rope_deltas=rope_deltas,
-            output_hidden_states=output_hidden_states,
-            use_cache=use_cache,
-            **kwargs,
-        )
-
-
-class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel):
     main_input_name: ClassVar[str] = "doc_input_ids"
 
-    def
         self,
         input_ids: torch.LongTensor,
         attention_mask: torch.Tensor,
@@ -460,19 +149,20 @@ class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel):
         offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
         kwargs["pixel_values"] = torch.cat([pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0)
 
-        position_ids, rope_deltas =
             input_ids=input_ids,
             image_grid_thw=kwargs.get("image_grid_thw", None),
             attention_mask=attention_mask,
         )
 
         outputs = super().forward(
             input_ids,
             attention_mask,
             **kwargs,
             position_ids=position_ids,
             rope_deltas=rope_deltas,
-            output_hidden_states=True,
             use_cache=False,
         )
 
@@ -482,35 +172,6 @@ class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel):
 
         return hidden_states[-1]
 
-
-class AbstractHybridModel(ABC):
-    """
-    Abstract class for a hybrid model (single-vector and multi-vector embeddings).
-    """
-
-    @property
-    def single_vector_projector_dim(self) -> int:
-        return self.config.single_vector_projector_dim
-
-    @property
-    def multi_vector_projector_dim(self) -> int:
-        return self.config.multi_vector_projector_dim
-
-    @abstractmethod
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.Tensor,
-        output_vlm_last_hidden_states: bool = False,
-        *args,
-        **kwargs,
-    ) -> HybridModelOutput:
-        """
-        Forward pass through the model. Returns both single-vector and multi-vector embeddings.
-        Must be implemented by subclasses.
-        """
-        pass
-
     def _init_projection_layers(self, config) -> None:
         """
         Initializes projection layers.
@@ -528,14 +189,6 @@ class AbstractHybridModel(ABC):
             out_features=self.config.multi_vector_projector_dim,
         )
 
-    @staticmethod
-    def _delete_redundant_forward_kwargs(kwargs: Dict[str, Any]) -> None:
-        """
-        Delete redundant kwargs before passing them to the forward method. In-place operation.
-        """
-        for key in ["input_ids", "attention_mask", "output_hidden_states"]:
-            kwargs.pop(key, None)
-
     def project_to_single_vector_embeddings(
         self,
         hidden_states: torch.Tensor,
@@ -545,48 +198,15 @@ class AbstractHybridModel(ABC):
         """
         Project the hidden states to single-vector embeddings.
         """
 
-
-
-        if pooling_method == "mean" and input_ids is None:
-            print("Warning: `input_ids` is None. Using `legacy-mean` pooling strategy instead.")
-            pooling_method = "legacy-mean"
-
-        if pooling_method == "last-token":
-            pooled_output = hidden_states[:, -1, :]
-        elif pooling_method == "mean":
-            if self._input_has_image(input_ids[0]):  # got document image(s)
-                # getting start and end positions of image tokens; torch.where returns
-                # (1) a list of indices of input sequences
-                # (shape corresponds to the total number of images in the batch)
-                # (2) a list of positions of image tokens in the input sequence
-                # (shape corresponds to the total number of images in the batch)
-                input_seq_idx, img_start_pos = torch.where(
-                    input_ids == self.config.vision_start_token_id
-                )  # (total number of images), (total number of images)
-                _, img_end_pos = torch.where(
-                    input_ids == self.config.vision_end_token_id
-                )  # (total number of images), (total number of images)
-                means = []
-                for i in range(input_seq_idx.shape[0]):
-                    vector_pos = input_seq_idx[i]
-                    start = img_start_pos[i]
-                    end = img_end_pos[i]
-                    mean_value = hidden_states[vector_pos][start : end + 1].mean(dim=0)
-                    means.append(mean_value)
-                pooled_output = torch.stack(means)
-
-            else:  # got query text
-                pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
-                    attention_mask, dim=1, keepdim=True
-                )
-
-        elif pooling_method == "legacy-mean":
             pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
                 attention_mask, dim=1, keepdim=True
             )
-        else:
-            raise ValueError(f"Invalid pooling strategy: {pooling_method}")
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)
 
@@ -605,30 +225,25 @@ class AbstractHybridModel(ABC):
     def _input_has_image(self, input_ids):
         return self.config.vision_start_token_id in input_ids
 
-class ColQwenDuoBase(AbstractHybridModel, QwenVLEmbeddingBase):
-
     def forward(
         self,
         input_ids: torch.LongTensor,
         attention_mask: torch.Tensor,
         output_vlm_last_hidden_states: bool = False,
         **kwargs,
-    ) ->
         """
-        Forward pass through
         Args:
             input_ids (torch.LongTensor): The input tokens tensor.
             attention_mask (torch.LongTensor): The attention mask tensor.
         Returns:
-
             single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
             multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
         """
-        # Delete redundant kwargs
-        self._delete_redundant_forward_kwargs(kwargs)
-
         # Forward pass through the VLM
-        hidden_states = self.
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )  # (batch_size, seq_length, hidden_size)
 
@@ -636,16 +251,85 @@ class ColQwenDuoBase(AbstractHybridModel, QwenVLEmbeddingBase):
         single_vec_emb = self.project_to_single_vector_embeddings(hidden_states, attention_mask, input_ids=input_ids)
         multi_vec_emb = self.project_to_multi_vector_embeddings(hidden_states, attention_mask)
 
-        return
             vlm_last_hidden_states=hidden_states if output_vlm_last_hidden_states else None,
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
         )
 
 
 class JinaEmbeddingsV4Model:
     """
-    Wrapper class for
     """
 
     def __init__(self, model, adapter_dir):
@@ -664,7 +348,7 @@ class JinaEmbeddingsV4Model:
 
         task = kwargs.pop('task', 'retrieval')
 
-        model =
 
         if os.path.isdir(model.name_or_path):
             adapter_dir = os.path.join(model.name_or_path, 'adapters')
@@ -705,13 +389,4 @@ class JinaEmbeddingsV4Model:
         Forward the call to the underlying model's forward method.
         """
         return self.model(*args, **kwargs)
-
-
-class ColQwen25Duo(ColQwenDuoBase, Qwen2_5_VLForConditionalGeneration):
-    config_class = ColQwen25DuoConfig
-    def __init__(self, config: ColQwen25DuoConfig):
-        Qwen2_5_VLForConditionalGeneration.__init__(self, config)
-        self._init_projection_layers(config)
-        self.post_init()
-        self.processor = ColQwen25DuoProcessor.from_pretrained(self.name_or_path, trust_remote_code=True)
 
+++ modeling_jina_embeddings_v4.py
@@ -2,11 +2,9 @@ import os
 import math
 import numpy as np
 
 from dataclasses import dataclass
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
+from peft import PeftModel
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
@@ -17,170 +15,24 @@ from tqdm import tqdm
 from enum import Enum
 from peft.utils.hotswap import hotswap_adapter
 
+from transformers import BatchFeature
 
 from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
 
 from huggingface_hub import snapshot_download
 
+from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
 
 
 class PromptType(str, Enum):
     query = "query"
     passage = "passage"
 
+class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
+        self.assistant_prefix_len = 58
+        self.text_max_length = 8192
 
     @staticmethod
     def round_by_factor(number: float, factor: int) -> int:
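process_texts (next hunk) now clamps max_length to the processor's text_max_length and, as in the encoding helpers added later in this file, is meant to be used as a DataLoader collate function. A small sketch, assuming JinaEmbeddingsV4Processor has been imported from the checkpoint's modeling file and that from_pretrained is inherited unchanged from Qwen2_5_VLProcessor:

from functools import partial
from torch.utils.data import DataLoader

# Assumption: the class is importable and the checkpoint ships the processor files.
processor = JinaEmbeddingsV4Processor.from_pretrained("jinaai/jina-embeddings-v4")

collate = partial(processor.process_texts, max_length=512, prefix="Query")
loader = DataLoader(["first query", "second query"], batch_size=2, collate_fn=collate)
batch = next(iter(loader))  # BatchFeature with input_ids / attention_mask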
@@ -236,12 +88,12 @@ class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor
     def process_texts(
         self,
         texts: List[str],
+        max_length: Optional[int] = None,
         prefix: Optional[str] = None,
         padding: Optional[str] = None,
     ) -> BatchFeature:
 
+        max_length = self.text_max_length if max_length is None else min(max_length, self.text_max_length)
         padded_texts: List[str] = []
 
         for text in texts:
@@ -260,42 +112,8 @@ class QwenVLEmbeddingProcessorBase(BaseVisualRetrieverProcessor, QwenVLProcessor
         return text_batch
 
 
 @dataclass
+class JinaEmbeddingsV4ModelOutput:
     """
     Base class for the Hybrid Model outputs.
     Args:
@@ -308,149 +126,20 @@ class HybridModelOutput:
     single_vec_emb: Optional[torch.Tensor] = None
     multi_vec_emb: Optional[torch.Tensor] = None
 
 
+class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
+    config_class = JinaEmbeddingsV4Config
     main_input_name: ClassVar[str] = "doc_input_ids"
 
+    def __init__(self, config: JinaEmbeddingsV4Config):
+        Qwen2_5_VLForConditionalGeneration.__init__(self, config)
+        self._init_projection_layers(config)
+        self.post_init()
+        self.processor = JinaEmbeddingsV4Processor.from_pretrained(self.name_or_path, trust_remote_code=True)
+        self.single_vector_projector_dim = config.single_vector_projector_dim
+        self.multi_vector_projector_dim = config.multi_vector_projector_dim
+
+    def get_last_hidden_states(
         self,
         input_ids: torch.LongTensor,
         attention_mask: torch.Tensor,
@@ -460,19 +149,20 @@ class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel):
         offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
         kwargs["pixel_values"] = torch.cat([pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0)
 
+        position_ids, rope_deltas = super().get_rope_index(  # type: ignore
             input_ids=input_ids,
             image_grid_thw=kwargs.get("image_grid_thw", None),
             attention_mask=attention_mask,
         )
 
+        kwargs['output_hidden_states'] = True
+
         outputs = super().forward(
             input_ids,
             attention_mask,
             **kwargs,
             position_ids=position_ids,
             rope_deltas=rope_deltas,
             use_cache=False,
         )
 
@@ -482,35 +172,6 @@ class QwenVLEmbeddingBase(EncodeMixin, QwenVLModel):
 
         return hidden_states[-1]
 
     def _init_projection_layers(self, config) -> None:
         """
         Initializes projection layers.
@@ -528,14 +189,6 @@ class AbstractHybridModel(ABC):
             out_features=self.config.multi_vector_projector_dim,
         )
 
     def project_to_single_vector_embeddings(
         self,
         hidden_states: torch.Tensor,
@@ -545,48 +198,15 @@ class AbstractHybridModel(ABC):
         """
         Project the hidden states to single-vector embeddings.
         """
+        if self._input_has_image(input_ids[0]):  # got document image
+            img_start_pos = torch.where(input_ids[0] == self.config.vision_start_token_id)[0][0]
+            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[0][0]
+            pooled_output = hidden_states[0][img_start_pos:img_end_pos + 1].mean(dim=0).unsqueeze(0)
 
+        else:  # got query text
            pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
                attention_mask, dim=1, keepdim=True
            )
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)
 
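The rewritten single-vector pooling averages the hidden states over the image-token span for documents, and takes a masked mean over text tokens for queries. A self-contained illustration of the masked-mean branch, independent of the model:

import torch

hidden_states = torch.randn(2, 5, 4)  # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Padded positions are zeroed out of the sum and excluded from the token count.
pooled = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
    attention_mask, dim=1, keepdim=True
)
print(pooled.shape)  # torch.Size([2, 4])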
@@ -605,30 +225,25 @@ class AbstractHybridModel(ABC):
     def _input_has_image(self, input_ids):
         return self.config.vision_start_token_id in input_ids
 
     def forward(
         self,
         input_ids: torch.LongTensor,
         attention_mask: torch.Tensor,
         output_vlm_last_hidden_states: bool = False,
         **kwargs,
+    ) -> JinaEmbeddingsV4ModelOutput:
         """
+        Forward pass through QwenVL25Embeddings. Returns both single-vector and multi-vector embeddings.
         Args:
             input_ids (torch.LongTensor): The input tokens tensor.
             attention_mask (torch.LongTensor): The attention mask tensor.
         Returns:
+            JinaEmbeddingsV4ModelOutput:
             single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
             multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
         """
         # Forward pass through the VLM
+        hidden_states = self.get_last_hidden_states(
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )  # (batch_size, seq_length, hidden_size)
 
@@ -636,16 +251,85 @@ class ColQwenDuoBase(AbstractHybridModel, QwenVLEmbeddingBase):
         single_vec_emb = self.project_to_single_vector_embeddings(hidden_states, attention_mask, input_ids=input_ids)
         multi_vec_emb = self.project_to_multi_vector_embeddings(hidden_states, attention_mask)
 
+        return JinaEmbeddingsV4ModelOutput(
             vlm_last_hidden_states=hidden_states if output_vlm_last_hidden_states else None,
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
         )
+
+    def _process_batches(
+        self,
+        data: List[Union[str, Image.Image]],
+        processor_fn: Callable,
+        desc: str,
+        vector_type: Optional[str] = None,
+        return_numpy: bool = False,
+        **kwargs,
+    ) -> Union[np.ndarray, List[torch.Tensor]]:
+        dataloader = DataLoader(
+            dataset=data,
+            batch_size=kwargs.get("batch_size", 32),
+            shuffle=False,
+            collate_fn=processor_fn,
+        )
+        vector_type = vector_type or "single_vector"
+        results = []
+        self.eval()
+        for batch in tqdm(dataloader, desc=desc):
+            with torch.no_grad():
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                with torch.autocast(device_type=torch.device(self.device).type):
+                    embeddings = self(**batch)
+                    if vector_type == "single_vector":
+                        embeddings = embeddings.single_vec_emb
+                    else:
+                        embeddings = embeddings.multi_vec_emb
+                    results.append(embeddings.cpu() if return_numpy else list(torch.unbind(embeddings)))
+        if return_numpy:
+            return np.concatenate([result.numpy() for result in results], axis=0)
+        return [item for sublist in results for item in sublist]
+
+    def encode_texts(
+        self,
+        queries: List[str],
+        max_length: int = 8192,
+        batch_size: int = 8,
+        vector_type: Optional[str] = None,
+        desc: Optional[str] = None,
+        **kwargs,
+    ) -> List[torch.Tensor]:
+        processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix="Query")
+        return self._process_batches(
+            data=queries,
+            processor_fn=processor_fn,
+            desc=desc or "Encode queries...",
+            vector_type=vector_type,
+            batch_size=batch_size,
+            **kwargs,
+        )
+
+    def encode_images(
+        self,
+        documents: List[Image.Image],
+        batch_size: int = 8,
+        vector_type: Optional[str] = None,
+        desc: Optional[str] = None,
+        **kwargs,
+    ) -> List[torch.Tensor]:
+        return self._process_batches(
+            data=documents,
+            processor_fn=self.processor.process_images,
+            desc=desc or "Encode documents...",
+            vector_type=vector_type,
+            batch_size=batch_size,
+            **kwargs,
+        )
+
 
 
 class JinaEmbeddingsV4Model:
     """
+    Wrapper class for QwenVL25Embeddings that handles the loading of models and adapters.
     """
 
     def __init__(self, model, adapter_dir):
@@ -664,7 +348,7 @@ class JinaEmbeddingsV4Model:
 
         task = kwargs.pop('task', 'retrieval')
 
+        model = QwenVL25Embeddings.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
         if os.path.isdir(model.name_or_path):
             adapter_dir = os.path.join(model.name_or_path, 'adapters')
@@ -705,13 +389,4 @@ class JinaEmbeddingsV4Model:
         Forward the call to the underlying model's forward method.
         """
         return self.model(*args, **kwargs)
 
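After this refactor, the encoding helpers live directly on QwenVL25Embeddings and the wrapper loads the base model plus its task adapters. A usage sketch, assuming the wrapper returned by from_pretrained exposes the underlying model's encode_texts and encode_images (only forward is shown explicitly in this diff) and that a document image exists at the given path:

from PIL import Image
from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

query_vecs = model.encode_texts(["What is late interaction?"], vector_type="single_vector")
doc_vecs = model.encode_images([Image.open("page.png")], vector_type="single_vector")  # hypothetical file

score = query_vecs[0] @ doc_vecs[0]  # both are L2-normalized, so this is a cosine similarity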
preprocessor_config.json
CHANGED
@@ -18,7 +18,7 @@
   "merge_size": 2,
   "min_pixels": 3136,
   "patch_size": 14,
-  "processor_class": "
+  "processor_class": "JinaEmbeddingsV4Processor",
   "resample": 3,
   "rescale_factor": 0.00392156862745098,
   "size": {
@@ -27,6 +27,6 @@
   },
   "temporal_patch_size": 2,
   "auto_map": {
-    "AutoProcessor": "
+    "AutoProcessor": "modeling_jina_embeddings_v4.JinaEmbeddingsV4Processor"
   }
 }
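With processor_class and the AutoProcessor entry in auto_map updated, the processor resolves to the new class the same way the model does; a minimal sketch:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)
print(type(processor).__name__)  # expected: JinaEmbeddingsV4Processor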