jupyterjazz committed
Commit e9be62d · verified · 1 Parent(s): 998398d

refactor-model-loading (#4)


- feat: loading through jev4 class and stylistic changes (f7cb47c6b07716483dbf5fd311928026ce7cd27a)
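In practice, loading changes as follows: `from_pretrained` now lives on the `JinaEmbeddingsV4Model` class itself, fetches the task adapters (via `snapshot_download` when the checkpoint is not a local directory), and returns a `PeftModel` with the requested adapter attached. A minimal sketch of the new loading path, assuming the module is importable locally; the checkpoint path is a placeholder, not something named in this diff:

```python
# Sketch only: the checkpoint path below is a placeholder.
from modeling_jina_embeddings_v4 import JinaEmbeddingsV4Model

# The optional `task` kwarg selects which adapter is attached at load time;
# it defaults to TaskType.retrieval in the new from_pretrained.
model = JinaEmbeddingsV4Model.from_pretrained(
    "path/to/jina-embeddings-v4-checkpoint",  # placeholder
    task="retrieval",
)
```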

Files changed (1)
  1. modeling_jina_embeddings_v4.py +127 -87
modeling_jina_embeddings_v4.py CHANGED
@@ -1,25 +1,23 @@
-import os
 import math
-import numpy as np
-
+import os
 from dataclasses import dataclass
+from enum import Enum
+from functools import partial
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
-from peft import PeftModel
+
+import numpy as np
 import torch
+from huggingface_hub import snapshot_download
+from peft import PeftModel
+from peft.utils.hotswap import hotswap_adapter
+from PIL import Image
 from torch import nn
 from torch.utils.data import DataLoader
-
-from functools import partial
-from PIL import Image
 from tqdm import tqdm
-from enum import Enum
-from peft.utils.hotswap import hotswap_adapter
-
 from transformers import BatchFeature
-
-from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
-
-from huggingface_hub import snapshot_download
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
+                                            Qwen2_5_VLProcessor)

 from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config

@@ -28,6 +26,13 @@ class PromptType(str, Enum):
     query = "query"
     passage = "passage"

+
+class TaskType(str, Enum):
+    retrieval = "retrieval"
+    code = "code"
+    text_matching = "text-matching"
+
+
 class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
     def __init__(self, *args, **kwargs) -> None:
         Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
@@ -58,8 +63,12 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
             images = cast(List[List[Image.Image]], images)
             text_doc = []
             for i in range(len(images)):
-                conversation = [{"role": "user", "content": [{"type": "image"}] * len(images[i])}]
-                template = self.apply_chat_template(conversation, add_generation_prompt=False)
+                conversation = [
+                    {"role": "user", "content": [{"type": "image"}] * len(images[i])}
+                ]
+                template = self.apply_chat_template(
+                    conversation, add_generation_prompt=False
+                )
                 text_doc.append(template[self.assistant_prefix_len :])

         else:
@@ -78,7 +87,16 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
         max_length = max([len(pv) for pv in pixel_values])

         pixel_values = [
-            torch.cat([pv, torch.zeros((max_length - len(pv), pv.shape[1]), dtype=pv.dtype, device=pv.device)])
+            torch.cat(
+                [
+                    pv,
+                    torch.zeros(
+                        (max_length - len(pv), pv.shape[1]),
+                        dtype=pv.dtype,
+                        device=pv.device,
+                    ),
+                ]
+            )
             for pv in pixel_values
         ]

@@ -93,7 +111,11 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
         padding: Optional[str] = None,
     ) -> BatchFeature:

-        max_length = self.text_max_length if max_length is None else min(max_length, self.text_max_length)
+        max_length = (
+            self.text_max_length
+            if max_length is None
+            else min(max_length, self.text_max_length)
+        )
         padded_texts: List[str] = []

         for text in texts:
@@ -127,7 +149,7 @@ class JinaEmbeddingsV4ModelOutput:
     multi_vec_emb: Optional[torch.Tensor] = None


-class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
+class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
     config_class = JinaEmbeddingsV4Config
     main_input_name: ClassVar[str] = "doc_input_ids"

@@ -135,7 +157,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         Qwen2_5_VLForConditionalGeneration.__init__(self, config)
         self._init_projection_layers(config)
         self.post_init()
-        self.processor = JinaEmbeddingsV4Processor.from_pretrained(self.name_or_path, trust_remote_code=True)
+        self.processor = JinaEmbeddingsV4Processor.from_pretrained(
+            self.name_or_path, trust_remote_code=True
+        )
         self.single_vector_projector_dim = config.single_vector_projector_dim
         self.multi_vector_projector_dim = config.multi_vector_projector_dim

@@ -147,7 +171,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
     ) -> torch.Tensor:
         if "pixel_values" in kwargs:
             offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
-            kwargs["pixel_values"] = torch.cat([pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0)
+            kwargs["pixel_values"] = torch.cat(
+                [pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0
+            )

         position_ids, rope_deltas = super().get_rope_index( # type: ignore
             input_ids=input_ids,
@@ -155,7 +181,7 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
             attention_mask=attention_mask,
         )

-        kwargs['output_hidden_states'] = True
+        kwargs["output_hidden_states"] = True

         outputs = super().forward(
             input_ids,
@@ -199,14 +225,22 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         Project the hidden states to single-vector embeddings.
         """
         if self._input_has_image(input_ids[0]): # got document image
-            img_start_pos = torch.where(input_ids[0] == self.config.vision_start_token_id)[0][0]
-            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[0][0]
-            pooled_output = hidden_states[0][img_start_pos:img_end_pos + 1].mean(dim=0).unsqueeze(0)
+            img_start_pos = torch.where(
+                input_ids[0] == self.config.vision_start_token_id
+            )[0][0]
+            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[
+                0
+            ][0]
+            pooled_output = (
+                hidden_states[0][img_start_pos : img_end_pos + 1]
+                .mean(dim=0)
+                .unsqueeze(0)
+            )

         else: # got query text
-            pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
-                attention_mask, dim=1, keepdim=True
-            )
+            pooled_output = torch.sum(
+                hidden_states * attention_mask.unsqueeze(-1), dim=1
+            ) / torch.sum(attention_mask, dim=1, keepdim=True)
         single_vec_emb = self.single_vector_projector(pooled_output)
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)

@@ -248,15 +282,21 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         ) # (batch_size, seq_length, hidden_size)

         # Compute the embeddings
-        single_vec_emb = self.project_to_single_vector_embeddings(hidden_states, attention_mask, input_ids=input_ids)
-        multi_vec_emb = self.project_to_multi_vector_embeddings(hidden_states, attention_mask)
+        single_vec_emb = self.project_to_single_vector_embeddings(
+            hidden_states, attention_mask, input_ids=input_ids
+        )
+        multi_vec_emb = self.project_to_multi_vector_embeddings(
+            hidden_states, attention_mask
+        )

         return JinaEmbeddingsV4ModelOutput(
-            vlm_last_hidden_states=hidden_states if output_vlm_last_hidden_states else None,
+            vlm_last_hidden_states=(
+                hidden_states if output_vlm_last_hidden_states else None
+            ),
             single_vec_emb=single_vec_emb,
             multi_vec_emb=multi_vec_emb,
         )
-
+
     def _process_batches(
         self,
         data: List[Union[str, Image.Image]],
@@ -284,7 +324,11 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
                         embeddings = embeddings.single_vec_emb
                     else:
                         embeddings = embeddings.multi_vec_emb
-                    results.append(embeddings.cpu() if return_numpy else list(torch.unbind(embeddings)))
+                    results.append(
+                        embeddings.cpu()
+                        if return_numpy
+                        else list(torch.unbind(embeddings))
+                    )
         if return_numpy:
             return np.concatenate([result.numpy() for result in results], axis=0)
         return [item for sublist in results for item in sublist]
@@ -298,7 +342,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         desc: Optional[str] = None,
         **kwargs,
     ) -> List[torch.Tensor]:
-        processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix="Query")
+        processor_fn = partial(
+            self.processor.process_texts, max_length=max_length, prefix="Query"
+        )
         return self._process_batches(
             data=queries,
             processor_fn=processor_fn,
@@ -325,17 +371,6 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
             **kwargs,
         )

-
-
-class JinaEmbeddingsV4Model:
-    """
-    Wrapper class for QwenVL25Embeddings that handles the loading of models and adapters.
-    """
-
-    def __init__(self, model, adapter_dir):
-        self.model = model
-        self.adapter_dir = adapter_dir
-
     @classmethod
     def from_pretrained(
         cls,
@@ -345,48 +380,53 @@ class JinaEmbeddingsV4Model:
     ):
         if "torch_dtype" not in kwargs:
             kwargs["torch_dtype"] = "auto"
-
-        task = kwargs.pop('task', 'retrieval')
-
-        model = QwenVL25Embeddings.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-
-        if os.path.isdir(model.name_or_path):
-            adapter_dir = os.path.join(model.name_or_path, 'adapters')
+
+        task = kwargs.pop("task", TaskType.retrieval)
+
+        # Get the base model first
+        base_model = super().from_pretrained(
+            pretrained_model_name_or_path, *args, **kwargs
+        )
+
+        # Configure adapter directory
+        if os.path.isdir(base_model.name_or_path):
+            adapter_dir = os.path.join(base_model.name_or_path, "adapters")
         else:
             adapter_cache_path = snapshot_download(
-                repo_id=model.name_or_path,
-                allow_patterns=['adapters/*']
+                repo_id=base_model.name_or_path, allow_patterns=["adapters/*"]
             )
-            adapter_dir = os.path.join(adapter_cache_path, 'adapters')
-        model = PeftModel.from_pretrained(model, os.path.join(adapter_dir, task))
-        je_v4_model = cls(model, adapter_dir)
-
-        return je_v4_model
-
-    def set_task(self, task: str):
-        """
-        Set the task adapter for the model.
-
-        Args:
-            task (str): The task name. Must be one of ['retrieval', 'text-matching', 'code']
-        """
-        if task not in ['retrieval', 'text-matching', 'code']:
-            raise ValueError(f"Invalid task: {task}. Must be one of ['retrieval', 'text-matching', 'code']")
-
-        adapter_path = os.path.join(self.adapter_dir, task)
-        hotswap_adapter(self.model, adapter_path, adapter_name='default')
-
-    def __getattr__(self, name):
-        """
-        Delegate attribute access to the underlying model.
-        """
-        if hasattr(self.model, name):
-            return getattr(self.model, name)
-        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
-
-    def __call__(self, *args, **kwargs):
-        """
-        Forward the call to the underlying model's forward method.
-        """
-        return self.model(*args, **kwargs)
-
+            adapter_dir = os.path.join(adapter_cache_path, "adapters")
+
+        # Store adapter directory for later use with set_task
+        base_model.adapter_dir = adapter_dir
+
+        # Create the PEFT model with the requested task adapter
+        peft_model = PeftModel.from_pretrained(
+            base_model, os.path.join(adapter_dir, task)
+        )
+
+        # Add set_task method to the PEFT model instance
+        def set_task_method(self, task_name: Union[str, TaskType]):
+            """
+            Set the task adapter for the model.
+
+            Args:
+                task_name (Union[str, TaskType]): The task name. Must be one of TaskType values or
+                    one of ['retrieval', 'text-matching', 'code']
+            """
+            if isinstance(task_name, str):
+                try:
+                    task_name = TaskType(task_name)
+                except ValueError:
+                    valid_tasks = [t.value for t in TaskType]
+                    raise ValueError(
+                        f"Invalid task: {task_name}. Must be one of {valid_tasks}"
+                    )

+            adapter_path = os.path.join(self.adapter_dir, task_name.value)
+            hotswap_adapter(self, adapter_path, adapter_name="default")
+
+        # Bind the method to the instance
+        peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))
+
+        return peft_model
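
Once loaded, the adapter can be swapped in place through the `set_task` method that `from_pretrained` binds onto the returned `PeftModel`: it validates the name against the new `TaskType` enum and hot-swaps the LoRA weights from the stored `adapter_dir`. A minimal sketch, reusing the `model` object from the loading example above:

```python
# Switch from the default retrieval adapter to the text-matching adapter in place.
model.set_task("text-matching")  # hot-swaps <adapter_dir>/text-matching via hotswap_adapter

# Unknown task names raise a ValueError listing the valid TaskType values.
try:
    model.set_task("classification")
except ValueError as err:
    print(err)  # Invalid task: classification. Must be one of ['retrieval', 'code', 'text-matching']
```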