refactor-model-loading (#4)

- feat: loading through jev4 class and stylistic changes (f7cb47c6b07716483dbf5fd311928026ce7cd27a)

modeling_jina_embeddings_v4.py  +127 -87

modeling_jina_embeddings_v4.py  CHANGED
@@ -1,25 +1,23 @@
-import os
 import math
-import numpy as np
-
+import os
 from dataclasses import dataclass
+from enum import Enum
+from functools import partial
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
-
+
+import numpy as np
 import torch
+from huggingface_hub import snapshot_download
+from peft import PeftModel
+from peft.utils.hotswap import hotswap_adapter
+from PIL import Image
 from torch import nn
 from torch.utils.data import DataLoader
-
-from functools import partial
-from PIL import Image
 from tqdm import tqdm
-from enum import Enum
-from peft.utils.hotswap import hotswap_adapter
-
 from transformers import BatchFeature
-
-from transformers.models.qwen2_5_vl import
-
-from huggingface_hub import snapshot_download
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
+                                             Qwen2_5_VLProcessor)
 
 from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
 
@@ -28,6 +26,13 @@ class PromptType(str, Enum):
     query = "query"
     passage = "passage"
 
+
+class TaskType(str, Enum):
+    retrieval = "retrieval"
+    code = "code"
+    text_matching = "text-matching"
+
+
 class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
     def __init__(self, *args, **kwargs) -> None:
         Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
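
Side note: because TaskType subclasses str, its members behave like plain strings, which is what lets the loading code later in this diff accept either a string or an enum member. A standalone sketch of that behaviour (the enum body is copied from the hunk above; the asserts and the path join are illustrative only):

import os
from enum import Enum

class TaskType(str, Enum):
    retrieval = "retrieval"
    code = "code"
    text_matching = "text-matching"

assert TaskType("text-matching") is TaskType.text_matching  # lookup by value
assert TaskType.code == "code"                              # str subclass: equal to its value
print(os.path.join("adapters", TaskType.retrieval))         # members also work in path joins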
@@ -58,8 +63,12 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
             images = cast(List[List[Image.Image]], images)
             text_doc = []
             for i in range(len(images)):
+                conversation = [
+                    {"role": "user", "content": [{"type": "image"}] * len(images[i])}
+                ]
+                template = self.apply_chat_template(
+                    conversation, add_generation_prompt=False
+                )
                 text_doc.append(template[self.assistant_prefix_len :])
 
         else:
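
For orientation, the conversation built per document above is just a chat-template payload with one image placeholder per page; a toy illustration of its shape (the two-page document is made up, no processor involved):

num_pages = 2  # a made-up two-page document
conversation = [
    {"role": "user", "content": [{"type": "image"}] * num_pages}
]
print(conversation)
# [{'role': 'user', 'content': [{'type': 'image'}, {'type': 'image'}]}]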
@@ -78,7 +87,16 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
         max_length = max([len(pv) for pv in pixel_values])
 
         pixel_values = [
+            torch.cat(
+                [
+                    pv,
+                    torch.zeros(
+                        (max_length - len(pv), pv.shape[1]),
+                        dtype=pv.dtype,
+                        device=pv.device,
+                    ),
+                ]
+            )
             for pv in pixel_values
         ]
 
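
The comprehension above right-pads each image's patch sequence with zero rows so every entry reaches the same length. A self-contained sketch of the same pattern with toy shapes (5 and 3 patches of dimension 4, not real Qwen2.5-VL sizes):

import torch

pixel_values = [torch.randn(5, 4), torch.randn(3, 4)]  # toy patch counts
max_length = max(len(pv) for pv in pixel_values)
padded = [
    torch.cat([pv, torch.zeros(max_length - len(pv), pv.shape[1], dtype=pv.dtype)])
    for pv in pixel_values
]
print(torch.stack(padded).shape)  # torch.Size([2, 5, 4]) -- now stackable into one batch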
@@ -93,7 +111,11 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
         padding: Optional[str] = None,
     ) -> BatchFeature:
 
+        max_length = (
+            self.text_max_length
+            if max_length is None
+            else min(max_length, self.text_max_length)
+        )
         padded_texts: List[str] = []
 
         for text in texts:
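
The new expression caps a caller-supplied max_length at the processor's own text_max_length instead of letting it exceed the limit; in isolation (8192 is a stand-in value, not taken from the config):

text_max_length = 8192  # stand-in for self.text_max_length
for requested in (None, 512, 100_000):
    effective = text_max_length if requested is None else min(requested, text_max_length)
    print(requested, "->", effective)  # None -> 8192, 512 -> 512, 100000 -> 8192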
@@ -127,7 +149,7 @@ class JinaEmbeddingsV4ModelOutput:
     multi_vec_emb: Optional[torch.Tensor] = None
 
 
-class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
+class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
     config_class = JinaEmbeddingsV4Config
     main_input_name: ClassVar[str] = "doc_input_ids"
 
@@ -135,7 +157,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         Qwen2_5_VLForConditionalGeneration.__init__(self, config)
         self._init_projection_layers(config)
         self.post_init()
+        self.processor = JinaEmbeddingsV4Processor.from_pretrained(
+            self.name_or_path, trust_remote_code=True
+        )
         self.single_vector_projector_dim = config.single_vector_projector_dim
         self.multi_vector_projector_dim = config.multi_vector_projector_dim
 
@@ -147,7 +171,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
     ) -> torch.Tensor:
         if "pixel_values" in kwargs:
             offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
+            kwargs["pixel_values"] = torch.cat(
+                [pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0
+            )
 
         position_ids, rope_deltas = super().get_rope_index(  # type: ignore
             input_ids=input_ids,
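
This is the inverse of the padding done in the processor: image_grid_thw stores the (t, h, w) patch grid per image, so h * w rows are real (t is 1 for still images) and the rest are zero padding. A toy version (grids and feature dim are made up):

import torch

image_grid_thw = torch.tensor([[1, 2, 2], [1, 1, 3]])   # made-up grids: 4 and 3 real patches
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]    # tensor([4, 3])
padded = [torch.randn(5, 8), torch.randn(5, 8)]           # both padded to 5 rows by the processor
pixel_values = torch.cat([pv[:o] for pv, o in zip(padded, offsets)], dim=0)
print(pixel_values.shape)  # torch.Size([7, 8]) -- padding rows dropped before the vision tower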
@@ -155,7 +181,7 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
             attention_mask=attention_mask,
         )
 
+        kwargs["output_hidden_states"] = True
 
         outputs = super().forward(
             input_ids,
@@ -199,14 +225,22 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         Project the hidden states to single-vector embeddings.
         """
         if self._input_has_image(input_ids[0]):  # got document image
+            img_start_pos = torch.where(
+                input_ids[0] == self.config.vision_start_token_id
+            )[0][0]
+            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[
+                0
+            ][0]
+            pooled_output = (
+                hidden_states[0][img_start_pos : img_end_pos + 1]
+                .mean(dim=0)
+                .unsqueeze(0)
+            )
 
         else:  # got query text
+            pooled_output = torch.sum(
+                hidden_states * attention_mask.unsqueeze(-1), dim=1
+            ) / torch.sum(attention_mask, dim=1, keepdim=True)
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)
 
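
The query branch above is standard mask-weighted mean pooling; with toy numbers it is easy to see that padded positions drop out:

import torch

hidden_states = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])  # (batch=1, seq=3, dim=2)
attention_mask = torch.tensor([[1, 1, 0]])                             # third token is padding
pooled = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
    attention_mask, dim=1, keepdim=True
)
print(pooled)  # tensor([[2., 2.]]) -- mean of the two real tokens; the 9s are ignored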
@@ -248,15 +282,21 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         )  # (batch_size, seq_length, hidden_size)
 
         # Compute the embeddings
+        single_vec_emb = self.project_to_single_vector_embeddings(
+            hidden_states, attention_mask, input_ids=input_ids
+        )
+        multi_vec_emb = self.project_to_multi_vector_embeddings(
+            hidden_states, attention_mask
+        )
 
         return JinaEmbeddingsV4ModelOutput(
+            vlm_last_hidden_states=(
+                hidden_states if output_vlm_last_hidden_states else None
+            ),
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
         )
+
     def _process_batches(
         self,
         data: List[Union[str, Image.Image]],
@@ -284,7 +324,11 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
                     embeddings = embeddings.single_vec_emb
                 else:
                     embeddings = embeddings.multi_vec_emb
+                results.append(
+                    embeddings.cpu()
+                    if return_numpy
+                    else list(torch.unbind(embeddings))
+                )
         if return_numpy:
             return np.concatenate([result.numpy() for result in results], axis=0)
         return [item for sublist in results for item in sublist]
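
The new append keeps one (batch, dim) tensor per batch on the NumPy path and one tensor per input row otherwise; a toy batch shows the difference:

import torch

embeddings = torch.arange(6.0).reshape(3, 2)   # toy (batch=3, dim=2) model output
numpy_chunk = embeddings.cpu()                  # concatenated later via np.concatenate
per_item = list(torch.unbind(embeddings))       # [tensor([0., 1.]), tensor([2., 3.]), tensor([4., 5.])]
print(numpy_chunk.shape, len(per_item), per_item[0])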
@@ -298,7 +342,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         desc: Optional[str] = None,
         **kwargs,
     ) -> List[torch.Tensor]:
+        processor_fn = partial(
+            self.processor.process_texts, max_length=max_length, prefix="Query"
+        )
         return self._process_batches(
             data=queries,
             processor_fn=processor_fn,
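
partial simply pre-binds the prompt prefix and length cap so _process_batches can call the processor with the raw batch only; sketched with a dummy stand-in for process_texts (not the real processor method):

from functools import partial

def process_texts(texts, max_length=None, prefix=None):  # dummy stand-in
    return [f"{prefix}: {t}"[:max_length] for t in texts]

processor_fn = partial(process_texts, max_length=64, prefix="Query")
print(processor_fn(["what is late interaction?"]))  # ['Query: what is late interaction?']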
@@ -325,17 +371,6 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
             **kwargs,
         )
 
-
-
-class JinaEmbeddingsV4Model:
-    """
-    Wrapper class for QwenVL25Embeddings that handles the loading of models and adapters.
-    """
-
-    def __init__(self, model, adapter_dir):
-        self.model = model
-        self.adapter_dir = adapter_dir
-
     @classmethod
     def from_pretrained(
         cls,
@@ -345,48 +380,53 @@ class JinaEmbeddingsV4Model:
     ):
         if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"
+
+        task = kwargs.pop("task", TaskType.retrieval)
+
+        # Get the base model first
+        base_model = super().from_pretrained(
+            pretrained_model_name_or_path, *args, **kwargs
+        )
+
+        # Configure adapter directory
+        if os.path.isdir(base_model.name_or_path):
+            adapter_dir = os.path.join(base_model.name_or_path, "adapters")
         else:
             adapter_cache_path = snapshot_download(
+                repo_id=base_model.name_or_path, allow_patterns=["adapters/*"]
             )
+            adapter_dir = os.path.join(adapter_cache_path, "adapters")
+
+        # Store adapter directory for later use with set_task
+        base_model.adapter_dir = adapter_dir
+
+        # Create the PEFT model with the requested task adapter
+        peft_model = PeftModel.from_pretrained(
+            base_model, os.path.join(adapter_dir, task)
+        )
+
+        # Add set_task method to the PEFT model instance
+        def set_task_method(self, task_name: Union[str, TaskType]):
+            """
+            Set the task adapter for the model.
+
+            Args:
+                task_name (Union[str, TaskType]): The task name. Must be one of TaskType values or
+                    one of ['retrieval', 'text-matching', 'code']
+            """
+            if isinstance(task_name, str):
+                try:
+                    task_name = TaskType(task_name)
+                except ValueError:
+                    valid_tasks = [t.value for t in TaskType]
+                    raise ValueError(
+                        f"Invalid task: {task_name}. Must be one of {valid_tasks}"
+                    )
+
+            adapter_path = os.path.join(self.adapter_dir, task_name.value)
+            hotswap_adapter(self, adapter_path, adapter_name="default")
+
+        # Bind the method to the instance
+        peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))
+
+        return peft_model
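
Taken together, from_pretrained now returns a PeftModel wrapping the base weights with the requested task adapter, and set_task hot-swaps adapters in place. A hypothetical usage sketch (the repo id and the AutoModel/trust_remote_code mapping are assumptions for illustration, not shown in this diff):

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",   # assumed repo id for illustration
    trust_remote_code=True,
    task="retrieval",              # forwarded to the custom from_pretrained above
)
model.set_task("code")             # hot-swaps the LoRA adapter without reloading base weights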