jupyterjazz committed on
Commit bc22368 · verified · 1 Parent(s): 927f9b1

support-sentence-transformers (#12)


- feat: sentencetransformer wrapper (13dd61de9e40f2dbdf9a4be82afc311d7e4042eb)

Files changed (4)
  1. README.md +12 -0
  2. config_sentence_transformers.json +13 -0
  3. custom_st.py +134 -0
  4. modules.json +9 -0
README.md CHANGED
@@ -78,3 +78,15 @@ with torch.no_grad():
 ```
 
 
+Inference via the `sentence-transformers` library:
+
+```python
+from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer(
+    'jinaai/jina-embeddings-v4', trust_remote_code=True
+)
+
+emb = model.encode(['Khinkali is the best'], task='retrieval', prompt_name='query')
+
+```
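For retrieval, queries and passages are encoded with different prompts (defined in `config_sentence_transformers.json` below). The following is a minimal sketch of a typical query/passage pairing; the sample strings are illustrative and not part of this commit:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('jinaai/jina-embeddings-v4', trust_remote_code=True)

# prompt_name selects the "Query: " / "Passage: " prompts from the config.
query_emb = model.encode(['Where do khinkali come from?'], task='retrieval', prompt_name='query')
doc_emb = model.encode(['Khinkali are dumplings from Georgia.'], task='retrieval', prompt_name='passage')

# similarity_fn_name is "cosine", so this prints cosine similarities.
print(model.similarity(query_emb, doc_emb))
```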
config_sentence_transformers.json ADDED
@@ -0,0 +1,13 @@
+{
+  "__version__": {
+    "sentence_transformers": "4.1.0",
+    "transformers": "4.50.0",
+    "pytorch": "2.6.0"
+  },
+  "prompts": {
+    "query": "Query: ",
+    "passage": "Passage: "
+  },
+  "default_prompt_name": null,
+  "similarity_fn_name": "cosine"
+}
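The `prompts` map is what `prompt_name` resolves against: sentence-transformers prepends the selected prompt string to each input text before it reaches the module's `tokenize`. A rough sketch of the effect, assuming the standard prompt handling of recent sentence-transformers releases:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('jinaai/jina-embeddings-v4', trust_remote_code=True)

# With the prompts above, these two calls feed the model the same text,
# since prompt_name='query' prepends "Query: " to the input.
emb_a = model.encode(['Khinkali is the best'], task='retrieval', prompt_name='query')
emb_b = model.encode(['Query: Khinkali is the best'], task='retrieval')
```

Because `default_prompt_name` is null, no prompt is applied when `prompt_name` is omitted.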
custom_st.py ADDED
@@ -0,0 +1,134 @@
+from typing import Any, Dict, List, Literal, Optional, Union
+
+import torch
+from PIL import Image
+from torch import nn
+from transformers import AutoConfig, AutoModel, AutoProcessor
+
+
+class Transformer(nn.Module):
+    """Custom sentence-transformers module wrapping jina-embeddings-v4."""
+
+    save_in_root: bool = True
+
+    def __init__(
+        self,
+        model_name_or_path: str = 'jinaai/jina-embeddings-v4',
+        max_seq_length: Optional[int] = None,
+        config_args: Optional[Dict[str, Any]] = None,
+        model_args: Optional[Dict[str, Any]] = None,
+        tokenizer_args: Optional[Dict[str, Any]] = None,
+        cache_dir: Optional[str] = None,
+        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        if backend != 'torch':
+            raise ValueError(
+                f"Backend '{backend}' is not supported, please use 'torch' instead"
+            )
+
+        config_kwargs = config_args or {}
+        model_kwargs = model_args or {}
+        tokenizer_kwargs = tokenizer_args or {}
+
+        self.config = AutoConfig.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **config_kwargs
+        )
+        # 'default_task' is a wrapper-specific option, so pop it before the
+        # remaining kwargs are forwarded to AutoModel.from_pretrained.
+        self.default_task = model_kwargs.pop('default_task', None)
+        if self.default_task and self.default_task not in self.config.task_names:
+            raise ValueError(
+                f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
+            )
+
+        self.model = AutoModel.from_pretrained(
+            model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
+        )
+
+        self.processor = AutoProcessor.from_pretrained(
+            model_name_or_path,
+            cache_dir=cache_dir,
+            **tokenizer_kwargs,
+        )
+        self.max_seq_length = max_seq_length or 8192
+
+    def tokenize(
+        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
+    ) -> Dict[str, torch.Tensor]:
+        """Split a mixed batch into text and image features, remembering the original order."""
+        encoding = {}
+        text_indices = []
+        image_indices = []
+
+        for i, text in enumerate(texts):
+            if isinstance(text, str):
+                text_indices.append(i)
+            elif isinstance(text, Image.Image):
+                image_indices.append(i)
+            else:
+                raise ValueError(f'Invalid input type: {type(text)}')
+
+        if text_indices:
+            _texts = [texts[i] for i in text_indices]
+            text_features = self.processor.process_texts(_texts, max_length=self.max_seq_length)
+            for key, value in text_features.items():
+                encoding[f'text_{key}'] = value
+            encoding['text_indices'] = text_indices
+
+        if image_indices:
+            _images = [texts[i] for i in image_indices]
+            img_features = self.processor.process_images(_images)
+            for key, value in img_features.items():
+                encoding[f'image_{key}'] = value
+            encoding['image_indices'] = image_indices
+
+        return encoding
+
+    def forward(self, features: Dict[str, torch.Tensor], task: Optional[str] = None) -> Dict[str, torch.Tensor]:
+        self.model.eval()
+
+        if task is None:
+            if self.default_task is None:
+                raise ValueError(
+                    "Task must be specified before encoding data. You can set it either when "
+                    "loading the model (e.g., model_kwargs={'default_task': 'retrieval'}) or "
+                    "pass it as an argument to the encode method (e.g., model.encode(texts, task='retrieval'))."
+                )
+            task = self.default_task
+        elif task not in self.config.task_names:
+            raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+
+        device = self.model.device.type
+        all_embeddings = []
+
+        with torch.no_grad():
+            if any(k.startswith('text_') for k in features):
+                text_batch = {
+                    k[len('text_'):]: v.to(device)
+                    for k, v in features.items()
+                    if k.startswith('text_') and k != 'text_indices'
+                }
+                text_indices = features.get('text_indices', [])
+
+                with torch.autocast(device_type=device):
+                    text_embeddings = self.model(**text_batch, task_label=task).single_vec_emb
+                    if self.config.truncate_dim:
+                        text_embeddings = text_embeddings[:, :self.config.truncate_dim]
+
+                for i, embedding in enumerate(text_embeddings):
+                    all_embeddings.append((text_indices[i], embedding))
+
+            if any(k.startswith('image_') for k in features):
+                image_batch = {
+                    k[len('image_'):]: v.to(device)
+                    for k, v in features.items()
+                    if k.startswith('image_') and k != 'image_indices'
+                }
+                image_indices = features.get('image_indices', [])
+
+                with torch.autocast(device_type=device):
+                    img_embeddings = self.model(**image_batch, task_label=task).single_vec_emb
+                    if self.config.truncate_dim:
+                        img_embeddings = img_embeddings[:, :self.config.truncate_dim]
+
+                for i, embedding in enumerate(img_embeddings):
+                    all_embeddings.append((image_indices[i], embedding))
+
+        if not all_embeddings:
+            raise RuntimeError('No embeddings were generated')
+
+        # Restore the original input order before stacking.
+        all_embeddings.sort(key=lambda x: x[0])
+        combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
+        features['sentence_embedding'] = combined_embeddings
+
+        return features
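As the error message in `forward` indicates, the task can also be pinned once at load time instead of being passed to every `encode` call. A minimal sketch, assuming `retrieval` is one of the tasks listed in `config.task_names`:

```python
from sentence_transformers import SentenceTransformer

# custom_st.Transformer pops 'default_task' out of model_kwargs and validates it
# against config.task_names before loading the weights.
model = SentenceTransformer(
    'jinaai/jina-embeddings-v4',
    trust_remote_code=True,
    model_kwargs={'default_task': 'retrieval'},
)

emb = model.encode(['Khinkali is the best'], prompt_name='query')  # no task= needed
```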
modules.json ADDED
@@ -0,0 +1,9 @@
+[
+  {
+    "idx": 0,
+    "name": "transformer",
+    "path": "",
+    "type": "custom_st.Transformer",
+    "kwargs": ["task"]
+  }
+]
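Listing `"task"` under `kwargs` is what lets `task=...` flow from `encode` into `Transformer.forward`, where it is validated and passed to the model as `task_label`. Since `tokenize` accepts both strings and `PIL.Image.Image` inputs, a mixed batch should also work; the sketch below is illustrative, and `photo.jpg` is a hypothetical local file, not part of this commit:

```python
from PIL import Image
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('jinaai/jina-embeddings-v4', trust_remote_code=True)

# 'task' is forwarded from encode() to forward() because it is listed in modules.json.
embs = model.encode(
    ['a text query', Image.open('photo.jpg')],
    task='retrieval',
)
# tokenize() routes texts and images to the matching processor methods, and
# forward() re-sorts the embeddings back into input order before stacking them.
```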