jupyterjazz committed
Commit baee517 · verified · 1 Parent(s): e9be62d

refine-the-codebase (#5)


- feat: encode, prefixes, matryoshka, etc (df46e74a27635c8ed0e13a53e06588c2d1f933ff)

Files changed (3)
  1. README.md +68 -8
  2. config.json +2 -1
  3. modeling_jina_embeddings_v4.py +116 -54
README.md CHANGED
@@ -1,24 +1,84 @@
 # Jina Embeddings V4
 
-Load the model:
+
+## Examples
+
+Encode functions:
 
 ```python
+import torch
 from transformers import AutoModel
+from PIL import Image
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load model
 model = AutoModel.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
-```
+model = model.to(device)
 
-Encode Text:
+# Sample data
+texts = ["Here is some sample code", "This is a matching text"]
+image_paths = ['/<path_to_image>']
+images = [Image.open(path) for path in image_paths]
+
+# Example 1: Text matching task with single vector embeddings
+model.set_task(task='text-matching')
+
+# Generate embeddings with dimension truncation (256)
+img_embeddings = model.encode_images(images=images, truncate_dim=256)
+text_embeddings = model.encode_texts(texts=texts, truncate_dim=256, max_length=512)
+
+# Example 2: Retrieval task with multi-vector embeddings
+model.set_task(task='retrieval')
+
+# Generate multi-vector embeddings
+img_embeddings = model.encode_images(images=images, vector_type='multi_vector')
+text_embeddings = model.encode_texts(texts=texts, vector_type='multi_vector', text_type='passage')
+
+# Example 3: Code task with single vector embeddings
+model.set_task(task='code')
+
+code = ["def hello_world():\n    print('Hello, World!')"]
+code_embeddings = model.encode_texts(texts=code)
 
-```python
-text_embedding = model.encode_texts(['test'])
 ```
 
-Encode Image (very slow on CPU):
+Using the model forward:
+
 ```python
+import torch
+from transformers import AutoModel, AutoProcessor
 from PIL import Image
 
-img = Image.open('path/to/your/image.png')
-image_embedding = m.encode_images([img])
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load model and processor
+model = AutoModel.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
+model = model.to(device)
+processor = AutoProcessor.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
+
+
+# Sample data
+texts = ["Here is some sample code", "This is a matching text"]
+image_paths = ['/<path_to_image>']
+
+# Process text and images
+text_batch = processor.process_texts(texts=texts, prefix="Query", max_length=512)
+images = [Image.open(path) for path in image_paths]
+image_batch = processor.process_images(images=images)
+
+# Forward pass
+model.eval()
+with torch.no_grad():
+    text_batch = {k: v.to(device) for k, v in text_batch.items()}
+    image_batch = {k: v.to(device) for k, v in image_batch.items()}
+
+    with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
+        # Get embeddings
+        text_embeddings = model.model(**text_batch).single_vec_emb
+        img_embeddings = model.model(**image_batch).single_vec_emb
+
+
 ```
 
 
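The README examples stop once embeddings are produced. The sketch below is an illustration added for this write-up, not part of the commit: it assumes the single-vector outputs are rows of a `torch.Tensor` (stack list outputs with `torch.stack` first) and shows pairwise cosine scoring, plus a ColBERT-style MaxSim score for the multi-vector case.

```python
import torch
import torch.nn.functional as F

def cosine_scores(text_embs: torch.Tensor, img_embs: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between two batches of single-vector embeddings."""
    t = F.normalize(torch.atleast_2d(text_embs), dim=-1)
    i = F.normalize(torch.atleast_2d(img_embs), dim=-1)
    return t @ i.T  # shape: (num_texts, num_images)

def maxsim_score(query_tokens: torch.Tensor, doc_tokens: torch.Tensor) -> torch.Tensor:
    """Late-interaction (MaxSim) score for multi-vector embeddings:
    each query token is matched to its best document token, then summed."""
    sim = query_tokens @ doc_tokens.T  # (num_query_tokens, num_doc_tokens)
    return sim.max(dim=-1).values.sum()
```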
 
config.json CHANGED
@@ -53,5 +53,6 @@
   "vision_end_token_id": 151653,
   "vision_start_token_id": 151652,
   "vision_token_id": 151654,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "truncate_dim": null
 }
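The new `truncate_dim` field is read by the model code as a default for Matryoshka-style truncation: `_validate_encoding_params` falls back to `self.config.truncate_dim` when no per-call value is given. A minimal sketch of that fallback behavior, with `SimpleNamespace` standing in for the real config object (illustrative only):

```python
from types import SimpleNamespace
from typing import Optional

TRUNCATE_DIMS = [128, 256, 512, 1024]  # allowed truncation dimensions (from the model code)

def resolve_truncate_dim(config, truncate_dim: Optional[int] = None) -> Optional[int]:
    # Per-call argument wins; otherwise fall back to the config default (null/None = no truncation).
    truncate_dim = truncate_dim or config.truncate_dim
    if truncate_dim is not None and truncate_dim not in TRUNCATE_DIMS:
        raise ValueError(f"Invalid truncate_dim: {truncate_dim}. Must be one of {TRUNCATE_DIMS}.")
    return truncate_dim

config = SimpleNamespace(truncate_dim=None)      # mirrors "truncate_dim": null in config.json
assert resolve_truncate_dim(config) is None      # no default, no truncation
assert resolve_truncate_dim(config, 256) == 256  # explicit per-call override
```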
modeling_jina_embeddings_v4.py CHANGED
@@ -1,4 +1,6 @@
-import math
+# Jina Embeddings V4 Model implementation was inspired by the ColPali codebase:
+# https://github.com/illuin-tech/colpali
+
 import os
 from dataclasses import dataclass
 from enum import Enum
@@ -15,7 +17,6 @@ from torch import nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import BatchFeature
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
                                             Qwen2_5_VLProcessor)
 
@@ -33,27 +34,17 @@ class TaskType(str, Enum):
     text_matching = "text-matching"
 
 
+PREFIX_DICT = {"query": "Query", "passage": "Passage"}
+TRUNCATE_DIMS = [128, 256, 512, 1024]
+VECTOR_TYPES = ["single_vector", "multi_vector"]
+
+
 class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
     def __init__(self, *args, **kwargs) -> None:
         Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
         self.assistant_prefix_len = 58
         self.text_max_length = 8192
 
-    @staticmethod
-    def round_by_factor(number: float, factor: int) -> int:
-        """Returns the closest integer to 'number' that is divisible by 'factor'."""
-        return round(number / factor) * factor
-
-    @staticmethod
-    def ceil_by_factor(number: float, factor: int) -> int:
-        """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-        return math.ceil(number / factor) * factor
-
-    @staticmethod
-    def floor_by_factor(number: float, factor: int) -> int:
-        """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-        return math.floor(number / factor) * factor
-
     def process_images(
         self,
         images: Union[List[Image.Image], List[List[Image.Image]]],
@@ -175,7 +166,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             [pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0
         )
 
-        position_ids, rope_deltas = super().get_rope_index(  # type: ignore
+        position_ids, rope_deltas = super().get_rope_index(
            input_ids=input_ids,
            image_grid_thw=kwargs.get("image_grid_thw", None),
            attention_mask=attention_mask,
@@ -267,10 +258,10 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         **kwargs,
     ) -> JinaEmbeddingsV4ModelOutput:
         """
-        Forward pass through QwenVL25Embeddings. Returns both single-vector and multi-vector embeddings.
+        Forward pass through the model. Returns both single-vector and multi-vector embeddings.
         Args:
-            input_ids (torch.LongTensor): The input tokens tensor.
-            attention_mask (torch.LongTensor): The attention mask tensor.
+            input_ids (torch.Tensor): The input tokens tensor.
+            attention_mask (torch.Tensor): The attention mask tensor.
         Returns:
             JinaEmbeddingsV4ModelOutput:
                 single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
@@ -302,17 +293,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         data: List[Union[str, Image.Image]],
         processor_fn: Callable,
         desc: str,
-        vector_type: Optional[str] = None,
+        vector_type: str = "single_vector",
         return_numpy: bool = False,
-        **kwargs,
+        batch_size: int = 32,
+        truncate_dim: Optional[int] = None,
     ) -> Union[np.ndarray, List[torch.Tensor]]:
         dataloader = DataLoader(
             dataset=data,
-            batch_size=kwargs.get("batch_size", 32),
+            batch_size=batch_size,
             shuffle=False,
             collate_fn=processor_fn,
         )
-        vector_type = vector_type or "single_vector"
         results = []
         self.eval()
         for batch in tqdm(dataloader, desc=desc):
@@ -322,8 +313,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 embeddings = self(**batch)
                 if vector_type == "single_vector":
                     embeddings = embeddings.single_vec_emb
+                    if truncate_dim is not None:
+                        embeddings = embeddings[:, :truncate_dim]
                 else:
                     embeddings = embeddings.multi_vec_emb
+
                 results.append(
                     embeddings.cpu()
                     if return_numpy
@@ -333,44 +327,98 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             return np.concatenate([result.numpy() for result in results], axis=0)
         return [item for sublist in results for item in sublist]
 
+    def _validate_encoding_params(
+        self,
+        vector_type: Optional[str] = None,
+        truncate_dim: Optional[int] = None,
+        text_type: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        encode_kwargs = {}
+        if text_type is not None:
+            if text_type not in PREFIX_DICT:
+                raise ValueError(
+                    f"Invalid text_type: {text_type}. Must be one of {list(PREFIX_DICT.keys())}."
+                )
+            else:
+                encode_kwargs["prefix"] = (
+                    PREFIX_DICT[text_type]
+                    if self.task != TaskType.text_matching
+                    else PREFIX_DICT["query"]
+                )
+
+        vector_type = vector_type or "single_vector"
+        if vector_type not in VECTOR_TYPES:
+            raise ValueError(
+                f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
+            )
+        else:
+            encode_kwargs["vector_type"] = vector_type
+
+        truncate_dim = truncate_dim or self.config.truncate_dim
+        if truncate_dim is not None and truncate_dim not in TRUNCATE_DIMS:
+            raise ValueError(
+                f"Invalid truncate_dim: {truncate_dim}. Must be one of {TRUNCATE_DIMS}."
+            )
+        else:
+            encode_kwargs["truncate_dim"] = truncate_dim
+
+        return encode_kwargs
+
     def encode_texts(
         self,
-        queries: List[str],
+        texts: List[str],
         max_length: int = 8192,
         batch_size: int = 8,
         vector_type: Optional[str] = None,
-        desc: Optional[str] = None,
-        **kwargs,
+        return_numpy: bool = False,
+        truncate_dim: Optional[int] = None,
+        text_type: Optional[str] = None,
     ) -> List[torch.Tensor]:
+        text_type = text_type or "query"
+        encode_kwargs = self._validate_encoding_params(
+            vector_type, truncate_dim, text_type
+        )
+
         processor_fn = partial(
-            self.processor.process_texts, max_length=max_length, prefix="Query"
+            self.processor.process_texts,
+            max_length=max_length,
+            prefix=encode_kwargs.pop("prefix"),
        )
-        return self._process_batches(
-            data=queries,
+
+        is_single = len(texts) == 1
+        embeddings = self._process_batches(
+            data=texts,
             processor_fn=processor_fn,
-            desc=desc or "Encode queries...",
-            vector_type=vector_type,
+            desc="Encoding texts...",
+            return_numpy=return_numpy,
             batch_size=batch_size,
-            **kwargs,
+            **encode_kwargs,
         )
 
+        return embeddings[0] if is_single else embeddings
+
     def encode_images(
         self,
-        documents: List[Image.Image],
+        images: List[Image.Image],
         batch_size: int = 8,
         vector_type: Optional[str] = None,
-        desc: Optional[str] = None,
-        **kwargs,
+        return_numpy: bool = False,
+        truncate_dim: Optional[int] = None,
     ) -> List[torch.Tensor]:
-        return self._process_batches(
-            data=documents,
+        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
+
+        is_single = len(images) == 1
+        embeddings = self._process_batches(
+            data=images,
             processor_fn=self.processor.process_images,
-            desc=desc or "Encode documents...",
-            vector_type=vector_type,
+            desc="Encoding images...",
             batch_size=batch_size,
+            return_numpy=return_numpy,
+            **encode_kwargs,
         )
 
+        return embeddings[0] if is_single else embeddings
+
     @classmethod
     def from_pretrained(
         cls,
@@ -381,9 +429,15 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         if "torch_dtype" not in kwargs:
             kwargs["torch_dtype"] = "auto"
 
-        task = kwargs.pop("task", TaskType.retrieval)
+        task_value = kwargs.pop("task", "retrieval")
+        try:
+            task = TaskType(task_value)
+        except ValueError:
+            valid_tasks = [t.value for t in TaskType]
+            raise ValueError(
+                f"Invalid task: {task_value}. Must be one of {valid_tasks}."
+            )
 
-        # Get the base model first
         base_model = super().from_pretrained(
             pretrained_model_name_or_path, *args, **kwargs
         )
@@ -397,36 +451,44 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         adapter_dir = os.path.join(adapter_cache_path, "adapters")
 
-        # Store adapter directory for later use with set_task
         base_model.adapter_dir = adapter_dir
+        base_model.task = task
 
         # Create the PEFT model with the requested task adapter
         peft_model = PeftModel.from_pretrained(
-            base_model, os.path.join(adapter_dir, task)
+            base_model, os.path.join(adapter_dir, task.value)
         )
 
         # Add set_task method to the PEFT model instance
-        def set_task_method(self, task_name: Union[str, TaskType]):
+        def set_task_method(self, task: Union[str, TaskType]):
            """
            Set the task adapter for the model.
 
            Args:
-                task_name (Union[str, TaskType]): The task name. Must be one of TaskType values or
+                task (Union[str, TaskType]): The task name. Must be one of TaskType values or
                    one of ['retrieval', 'text-matching', 'code']
            """
-            if isinstance(task_name, str):
+            if isinstance(task, str):
                try:
-                    task_name = TaskType(task_name)
+                    task = TaskType(task)
                except ValueError:
                    valid_tasks = [t.value for t in TaskType]
                    raise ValueError(
-                        f"Invalid task: {task_name}. Must be one of {valid_tasks}"
+                        f"Invalid task: {task}. Must be one of {valid_tasks}"
                    )
+            if self.model.task != task:
+                adapter_path = os.path.join(self.adapter_dir, task.value)
+                hotswap_adapter(self, adapter_path, adapter_name="default")
+                self.model.task = task
 
-            adapter_path = os.path.join(self.adapter_dir, task_name.value)
-            hotswap_adapter(self, adapter_path, adapter_name="default")
+        def get_task_method(self):
+            """
+            Get the task adapter for the model.
+            """
+            return self.model.task.value
 
-        # Bind the method to the instance
+        # Bind the methods to the instance
         peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))
+        peft_model.get_task = get_task_method.__get__(peft_model, type(peft_model))
 
         return peft_model
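Taken together, the diff gives the returned PEFT model a small task-management API (`set_task`/`get_task`) on top of adapter hotswapping. A hedged usage sketch, assuming the checkpoint and its adapters download as in `from_pretrained` above:

```python
from transformers import AutoModel

# Load with an explicit task; the string is validated against TaskType at load time.
model = AutoModel.from_pretrained(
    'jinaai/jina-embeddings-v4',
    trust_remote_code=True,
    task='retrieval',
)

print(model.get_task())          # 'retrieval'

model.set_task('text-matching')  # hotswaps the PEFT adapter in place
model.set_task('text-matching')  # same task again: no hotswap is performed

# Unknown task names raise a ValueError listing the valid options.
try:
    model.set_task('classification')
except ValueError as err:
    print(err)
```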