fix-task-setting-and-st-load (#46)
Browse files- fix: load in new st, task setting (5a13d0f29dd4f13b4a0d82f530acd43d189d44fc)
- README.md +1 -1
- custom_st.py +18 -5
- modeling_jina_embeddings_v4.py +23 -20
README.md
CHANGED
@@ -155,7 +155,7 @@ from transformers import AutoModel
|
|
155 |
import torch
|
156 |
|
157 |
# Initialize the model
|
158 |
-
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)
|
159 |
|
160 |
model.to("cuda")
|
161 |
|
|
|
155 |
import torch
|
156 |
|
157 |
# Initialize the model
|
158 |
+
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True, torch_dtype=torch.float16)
|
159 |
|
160 |
model.to("cuda")
|
161 |
|
custom_st.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from io import BytesIO
|
2 |
from pathlib import Path
|
3 |
from typing import Any, Dict, List, Literal, Optional, Union
|
@@ -104,7 +106,10 @@ class Transformer(nn.Module):
|
|
104 |
return encoding
|
105 |
|
106 |
def forward(
|
107 |
-
self,
|
|
|
|
|
|
|
108 |
) -> Dict[str, torch.Tensor]:
|
109 |
self.model.eval()
|
110 |
|
@@ -138,8 +143,10 @@ class Transformer(nn.Module):
|
|
138 |
**text_batch, task_label=task
|
139 |
).single_vec_emb
|
140 |
if truncate_dim:
|
141 |
-
text_embeddings = text_embeddings[:, :
|
142 |
-
text_embeddings = torch.nn.functional.normalize(
|
|
|
|
|
143 |
for i, embedding in enumerate(text_embeddings):
|
144 |
all_embeddings.append((text_indices[i], embedding))
|
145 |
|
@@ -156,8 +163,10 @@ class Transformer(nn.Module):
|
|
156 |
**image_batch, task_label=task
|
157 |
).single_vec_emb
|
158 |
if truncate_dim:
|
159 |
-
img_embeddings = img_embeddings[:, :
|
160 |
-
img_embeddings = torch.nn.functional.normalize(
|
|
|
|
|
161 |
|
162 |
for i, embedding in enumerate(img_embeddings):
|
163 |
all_embeddings.append((image_indices[i], embedding))
|
@@ -170,3 +179,7 @@ class Transformer(nn.Module):
|
|
170 |
features["sentence_embedding"] = combined_embeddings
|
171 |
|
172 |
return features
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
from io import BytesIO
|
4 |
from pathlib import Path
|
5 |
from typing import Any, Dict, List, Literal, Optional, Union
|
|
|
106 |
return encoding
|
107 |
|
108 |
def forward(
|
109 |
+
self,
|
110 |
+
features: Dict[str, torch.Tensor],
|
111 |
+
task: Optional[str] = None,
|
112 |
+
truncate_dim: Optional[int] = None,
|
113 |
) -> Dict[str, torch.Tensor]:
|
114 |
self.model.eval()
|
115 |
|
|
|
143 |
**text_batch, task_label=task
|
144 |
).single_vec_emb
|
145 |
if truncate_dim:
|
146 |
+
text_embeddings = text_embeddings[:, :truncate_dim]
|
147 |
+
text_embeddings = torch.nn.functional.normalize(
|
148 |
+
text_embeddings, p=2, dim=-1
|
149 |
+
)
|
150 |
for i, embedding in enumerate(text_embeddings):
|
151 |
all_embeddings.append((text_indices[i], embedding))
|
152 |
|
|
|
163 |
**image_batch, task_label=task
|
164 |
).single_vec_emb
|
165 |
if truncate_dim:
|
166 |
+
img_embeddings = img_embeddings[:, :truncate_dim]
|
167 |
+
img_embeddings = torch.nn.functional.normalize(
|
168 |
+
img_embeddings, p=2, dim=-1
|
169 |
+
)
|
170 |
|
171 |
for i, embedding in enumerate(img_embeddings):
|
172 |
all_embeddings.append((image_indices[i], embedding))
|
|
|
179 |
features["sentence_embedding"] = combined_embeddings
|
180 |
|
181 |
return features
|
182 |
+
|
183 |
+
@classmethod
|
184 |
+
def load(cls, input_path: str) -> "Transformer":
|
185 |
+
return cls(model_name_or_path=input_path)
|
modeling_jina_embeddings_v4.py
CHANGED
@@ -242,7 +242,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
242 |
pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(
|
243 |
dim=1, keepdim=True
|
244 |
)
|
245 |
-
|
246 |
else: # got query text
|
247 |
pooled_output = torch.sum(
|
248 |
hidden_states * attention_mask.unsqueeze(-1), dim=1
|
@@ -332,7 +331,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
332 |
collate_fn=processor_fn,
|
333 |
)
|
334 |
if return_multivector and len(data) > 1:
|
335 |
-
assert
|
|
|
|
|
336 |
results = []
|
337 |
self.eval()
|
338 |
for batch in tqdm(dataloader, desc=desc):
|
@@ -346,10 +347,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
346 |
embeddings = embeddings.single_vec_emb
|
347 |
if truncate_dim is not None:
|
348 |
embeddings = embeddings[:, :truncate_dim]
|
349 |
-
embeddings = torch.nn.functional.normalize(
|
|
|
|
|
350 |
else:
|
351 |
embeddings = embeddings.multi_vec_emb
|
352 |
-
|
353 |
if return_multivector and not return_numpy:
|
354 |
valid_tokens = batch["attention_mask"].bool()
|
355 |
embeddings = [
|
@@ -436,7 +439,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
436 |
List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
|
437 |
"""
|
438 |
prompt_name = prompt_name or "query"
|
439 |
-
encode_kwargs = self._validate_encoding_params(
|
|
|
|
|
440 |
|
441 |
task = self._validate_task(task)
|
442 |
|
@@ -451,9 +456,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
451 |
# If return_multivector is True and encoding multiple texts, ignore return_numpy
|
452 |
if return_multivector and return_list and len(texts) > 1:
|
453 |
if return_numpy:
|
454 |
-
print(
|
|
|
|
|
455 |
return_numpy = False
|
456 |
-
|
457 |
if isinstance(texts, str):
|
458 |
texts = [texts]
|
459 |
|
@@ -468,7 +475,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
468 |
**encode_kwargs,
|
469 |
)
|
470 |
|
471 |
-
return embeddings if return_list else embeddings[0]
|
472 |
|
473 |
def _load_images_if_needed(
|
474 |
self, images: List[Union[str, Image.Image]]
|
@@ -515,19 +522,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
515 |
)
|
516 |
encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
|
517 |
task = self._validate_task(task)
|
518 |
-
|
519 |
return_list = isinstance(images, list)
|
520 |
|
521 |
# If return_multivector is True and encoding multiple images, ignore return_numpy
|
522 |
if return_multivector and return_list and len(images) > 1:
|
523 |
if return_numpy:
|
524 |
-
print(
|
|
|
|
|
525 |
return_numpy = False
|
526 |
|
527 |
# Convert single image to list
|
528 |
if isinstance(images, (str, Image.Image)):
|
529 |
images = [images]
|
530 |
-
|
531 |
images = self._load_images_if_needed(images)
|
532 |
embeddings = self._process_batches(
|
533 |
data=images,
|
@@ -588,18 +597,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
588 |
config=lora_config,
|
589 |
)
|
590 |
|
591 |
-
|
592 |
-
def task(self):
|
593 |
return self.model.task
|
594 |
|
595 |
-
|
596 |
-
def task(self, value):
|
597 |
self.model.task = value
|
598 |
|
599 |
-
peft_model.task = property(
|
600 |
-
peft_model.__class__.task = property(
|
601 |
-
lambda self: self.model.task,
|
602 |
-
lambda self, value: setattr(self.model, "task", value),
|
603 |
-
)
|
604 |
|
605 |
return peft_model
|
|
|
242 |
pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(
|
243 |
dim=1, keepdim=True
|
244 |
)
|
|
|
245 |
else: # got query text
|
246 |
pooled_output = torch.sum(
|
247 |
hidden_states * attention_mask.unsqueeze(-1), dim=1
|
|
|
331 |
collate_fn=processor_fn,
|
332 |
)
|
333 |
if return_multivector and len(data) > 1:
|
334 |
+
assert (
|
335 |
+
not return_numpy
|
336 |
+
), "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
|
337 |
results = []
|
338 |
self.eval()
|
339 |
for batch in tqdm(dataloader, desc=desc):
|
|
|
347 |
embeddings = embeddings.single_vec_emb
|
348 |
if truncate_dim is not None:
|
349 |
embeddings = embeddings[:, :truncate_dim]
|
350 |
+
embeddings = torch.nn.functional.normalize(
|
351 |
+
embeddings, p=2, dim=-1
|
352 |
+
)
|
353 |
else:
|
354 |
embeddings = embeddings.multi_vec_emb
|
355 |
+
|
356 |
if return_multivector and not return_numpy:
|
357 |
valid_tokens = batch["attention_mask"].bool()
|
358 |
embeddings = [
|
|
|
439 |
List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
|
440 |
"""
|
441 |
prompt_name = prompt_name or "query"
|
442 |
+
encode_kwargs = self._validate_encoding_params(
|
443 |
+
truncate_dim=truncate_dim, prompt_name=prompt_name
|
444 |
+
)
|
445 |
|
446 |
task = self._validate_task(task)
|
447 |
|
|
|
456 |
# If return_multivector is True and encoding multiple texts, ignore return_numpy
|
457 |
if return_multivector and return_list and len(texts) > 1:
|
458 |
if return_numpy:
|
459 |
+
print(
|
460 |
+
"Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`"
|
461 |
+
)
|
462 |
return_numpy = False
|
463 |
+
|
464 |
if isinstance(texts, str):
|
465 |
texts = [texts]
|
466 |
|
|
|
475 |
**encode_kwargs,
|
476 |
)
|
477 |
|
478 |
+
return embeddings if return_list else embeddings[0]
|
479 |
|
480 |
def _load_images_if_needed(
|
481 |
self, images: List[Union[str, Image.Image]]
|
|
|
522 |
)
|
523 |
encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
|
524 |
task = self._validate_task(task)
|
525 |
+
|
526 |
return_list = isinstance(images, list)
|
527 |
|
528 |
# If return_multivector is True and encoding multiple images, ignore return_numpy
|
529 |
if return_multivector and return_list and len(images) > 1:
|
530 |
if return_numpy:
|
531 |
+
print(
|
532 |
+
"Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`"
|
533 |
+
)
|
534 |
return_numpy = False
|
535 |
|
536 |
# Convert single image to list
|
537 |
if isinstance(images, (str, Image.Image)):
|
538 |
images = [images]
|
539 |
+
|
540 |
images = self._load_images_if_needed(images)
|
541 |
embeddings = self._process_batches(
|
542 |
data=images,
|
|
|
597 |
config=lora_config,
|
598 |
)
|
599 |
|
600 |
+
def task_getter(self):
|
|
|
601 |
return self.model.task
|
602 |
|
603 |
+
def task_setter(self, value):
|
|
|
604 |
self.model.task = value
|
605 |
|
606 |
+
peft_model.__class__.task = property(task_getter, task_setter)
|
|
|
|
|
|
|
|
|
607 |
|
608 |
return peft_model
|