jupyterjazz committed
Commit 9ad94d9 · verified · 1 Parent(s): d51390d

fix-task-setting-and-st-load (#46)

- fix: load in new st, task setting (5a13d0f29dd4f13b4a0d82f530acd43d189d44fc)

Files changed (3):
  1. README.md +1 -1
  2. custom_st.py +18 -5
  3. modeling_jina_embeddings_v4.py +23 -20
README.md CHANGED
@@ -155,7 +155,7 @@ from transformers import AutoModel
 import torch
 
 # Initialize the model
-model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)
+model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True, torch_dtype=torch.float16)
 
 model.to("cuda")
 
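The one-line README change loads the checkpoint in float16 instead of full precision, roughly halving GPU memory for inference. For context, a minimal sketch of the updated snippet plus a hedged usage call; the `encode_text` method and the "retrieval" task name are assumptions inferred from the parameter names (`task`, `prompt_name`, `truncate_dim`) visible in the modeling diff below, not something this commit defines:

import torch
from transformers import AutoModel

# Load in float16, as the updated README line does, then move to GPU.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model.to("cuda")

# Assumed usage: these keyword arguments appear in the modeling code below.
embeddings = model.encode_text(
    texts=["What is Matryoshka embedding truncation?"],
    task="retrieval",
    prompt_name="query",
    truncate_dim=128,
)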
custom_st.py CHANGED
@@ -1,3 +1,5 @@
+import json
+import os
 from io import BytesIO
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union
@@ -104,7 +106,10 @@ class Transformer(nn.Module):
         return encoding
 
     def forward(
-        self, features: Dict[str, torch.Tensor], task: Optional[str] = None, truncate_dim: Optional[int] = None
+        self,
+        features: Dict[str, torch.Tensor],
+        task: Optional[str] = None,
+        truncate_dim: Optional[int] = None,
     ) -> Dict[str, torch.Tensor]:
         self.model.eval()
 
@@ -138,8 +143,10 @@ class Transformer(nn.Module):
                 **text_batch, task_label=task
             ).single_vec_emb
            if truncate_dim:
-                text_embeddings = text_embeddings[:, : truncate_dim]
-                text_embeddings = torch.nn.functional.normalize(text_embeddings, p=2, dim=-1)
+                text_embeddings = text_embeddings[:, :truncate_dim]
+                text_embeddings = torch.nn.functional.normalize(
+                    text_embeddings, p=2, dim=-1
+                )
            for i, embedding in enumerate(text_embeddings):
                all_embeddings.append((text_indices[i], embedding))
 
@@ -156,8 +163,10 @@ class Transformer(nn.Module):
                 **image_batch, task_label=task
             ).single_vec_emb
            if truncate_dim:
-                img_embeddings = img_embeddings[:, : truncate_dim]
-                img_embeddings = torch.nn.functional.normalize(img_embeddings, p=2, dim=-1)
+                img_embeddings = img_embeddings[:, :truncate_dim]
+                img_embeddings = torch.nn.functional.normalize(
+                    img_embeddings, p=2, dim=-1
+                )
 
            for i, embedding in enumerate(img_embeddings):
                all_embeddings.append((image_indices[i], embedding))
@@ -170,3 +179,7 @@ class Transformer(nn.Module):
         features["sentence_embedding"] = combined_embeddings
 
         return features
+
+    @classmethod
+    def load(cls, input_path: str) -> "Transformer":
+        return cls(model_name_or_path=input_path)
modeling_jina_embeddings_v4.py CHANGED
@@ -242,7 +242,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(
                 dim=1, keepdim=True
             )
-
         else: # got query text
             pooled_output = torch.sum(
                 hidden_states * attention_mask.unsqueeze(-1), dim=1
@@ -332,7 +331,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             collate_fn=processor_fn,
         )
         if return_multivector and len(data) > 1:
-            assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
+            assert (
+                not return_numpy
+            ), "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
         results = []
         self.eval()
         for batch in tqdm(dataloader, desc=desc):
@@ -346,10 +347,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 embeddings = embeddings.single_vec_emb
                 if truncate_dim is not None:
                     embeddings = embeddings[:, :truncate_dim]
-                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
+                    embeddings = torch.nn.functional.normalize(
+                        embeddings, p=2, dim=-1
+                    )
             else:
                 embeddings = embeddings.multi_vec_emb
-
+
             if return_multivector and not return_numpy:
                 valid_tokens = batch["attention_mask"].bool()
                 embeddings = [
@@ -436,7 +439,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
         """
         prompt_name = prompt_name or "query"
-        encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim, prompt_name=prompt_name)
+        encode_kwargs = self._validate_encoding_params(
+            truncate_dim=truncate_dim, prompt_name=prompt_name
+        )
 
         task = self._validate_task(task)
 
@@ -451,9 +456,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         # If return_multivector is True and encoding multiple texts, ignore return_numpy
         if return_multivector and return_list and len(texts) > 1:
             if return_numpy:
-                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
+                print(
+                    "Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`"
+                )
             return_numpy = False
-
+
         if isinstance(texts, str):
             texts = [texts]
 
@@ -468,7 +475,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             **encode_kwargs,
         )
 
-        return embeddings if return_list else embeddings[0]
+        return embeddings if return_list else embeddings[0]
 
     def _load_images_if_needed(
         self, images: List[Union[str, Image.Image]]
@@ -515,19 +522,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
         task = self._validate_task(task)
-
+
         return_list = isinstance(images, list)
 
         # If return_multivector is True and encoding multiple images, ignore return_numpy
         if return_multivector and return_list and len(images) > 1:
             if return_numpy:
-                print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`")
+                print(
+                    "Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`"
+                )
             return_numpy = False
 
         # Convert single image to list
         if isinstance(images, (str, Image.Image)):
             images = [images]
-
+
         images = self._load_images_if_needed(images)
         embeddings = self._process_batches(
             data=images,
@@ -588,18 +597,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             config=lora_config,
         )
 
-        @property
-        def task(self):
+        def task_getter(self):
             return self.model.task
 
-        @task.setter
-        def task(self, value):
+        def task_setter(self, value):
             self.model.task = value
 
-        peft_model.task = property(task.fget, task.fset)
-        peft_model.__class__.task = property(
-            lambda self: self.model.task,
-            lambda self, value: setattr(self.model, "task", value),
-        )
+        peft_model.__class__.task = property(task_getter, task_setter)
 
         return peft_model
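The task property rewrite is the "task setting" half of the fix: a property is a descriptor, and the descriptor protocol only runs for attributes defined on the type, so the old `peft_model.task = property(...)` merely stored a property object in the instance dict (reads returned the property object itself, and writes never reached the wrapped model). Defining plain getter/setter functions and attaching a single property to `peft_model.__class__` makes `peft_model.task` proxy `self.model.task`. A standalone illustration with hypothetical Inner/Wrapper classes, not code from this repo:

class Inner:
    def __init__(self):
        self.task = None

class Wrapper:
    def __init__(self):
        self.model = Inner()

w = Wrapper()

# Old approach: assigning a property to the *instance* just stores the
# property object in w.__dict__; the descriptor never fires.
w.task = property(lambda self: self.model.task)
print(type(w.task))  # <class 'property'>

# Fixed approach: attach the property to the *class* (as the commit does),
# so attribute access proxies the inner model.
def task_getter(self):
    return self.model.task

def task_setter(self, value):
    self.model.task = value

Wrapper.task = property(task_getter, task_setter)
w.task = "retrieval"  # routed through task_setter
print(w.model.task)   # retrieval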