refine-the-codebase (#5)
feat: encode, prefixes, matryoshka, etc. (df46e74a27635c8ed0e13a53e06588c2d1f933ff)
- README.md +68 -8
- config.json +2 -1
- modeling_jina_embeddings_v4.py +116 -54
README.md
CHANGED
@@ -1,24 +1,84 @@
 # Jina Embeddings V4

+## Examples
+
+Encode functions:

 ```python
+import torch
 from transformers import AutoModel
+from PIL import Image
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load model
 model = AutoModel.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
-```python
-text_embedding = model.encode_texts(['test'])
+model = model.to(device)

+# Sample data
+texts = ["Here is some sample code", "This is a matching text"]
+image_paths = ['/<path_to_image>']
+images = [Image.open(path) for path in image_paths]
+
+# Example 1: Text matching task with single vector embeddings
+model.set_task(task='text-matching')
+
+# Generate embeddings with dimension truncation (256)
+img_embeddings = model.encode_images(images=images, truncate_dim=256)
+text_embeddings = model.encode_texts(texts=texts, truncate_dim=256, max_length=512)
+
+# Example 2: Retrieval task with multi-vector embeddings
+model.set_task(task='retrieval')
+
+# Generate multi-vector embeddings
+img_embeddings = model.encode_images(images=images, vector_type='multi_vector')
+text_embeddings = model.encode_texts(texts=texts, vector_type='multi_vector', text_type='passage')
+
+# Example 3: Code task with single vector embeddings
+model.set_task(task='code')
+
+code = ["def hello_world():\n print('Hello, World!')"]
+code_embeddings = model.encode_texts(texts=code)
 ```

+Using the model forward:
+
 ```python
+import torch
+from transformers import AutoModel, AutoProcessor
 from PIL import Image

+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load model and processor
+model = AutoModel.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
+model = model.to(device)
+processor = AutoProcessor.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)
+
+# Sample data
+texts = ["Here is some sample code", "This is a matching text"]
+image_paths = ['/<path_to_image>']
+
+# Process text and images
+text_batch = processor.process_texts(texts=texts, prefix="Query", max_length=512)
+images = [Image.open(path) for path in image_paths]
+image_batch = processor.process_images(images=images)
+
+# Forward pass
+model.eval()
+with torch.no_grad():
+    text_batch = {k: v.to(device) for k, v in text_batch.items()}
+    image_batch = {k: v.to(device) for k, v in image_batch.items()}
+
+    with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
+        # Get embeddings
+        text_embeddings = model.model(**text_batch).single_vec_emb
+        img_embeddings = model.model(**image_batch).single_vec_emb
 ```
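The updated README stops at producing embeddings. A minimal scoring sketch (not part of the commit), continuing Example 1 above and assuming `encode_texts` returned a list of per-text tensors while `encode_images` with a single image returned one tensor:

```python
import torch
import torch.nn.functional as F

# Continues Example 1: rank the two texts against the single image using the
# 256-dim truncated single-vector embeddings (assumed shapes shown in comments).
text_matrix = F.normalize(torch.stack(text_embeddings), dim=-1)  # (2, 256)
image_vector = F.normalize(img_embeddings, dim=-1)               # (256,)

scores = text_matrix @ image_vector  # cosine similarities, shape (2,)
print(scores.tolist())
```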
config.json
CHANGED
@@ -53,5 +53,6 @@
  "vision_end_token_id": 151653,
  "vision_start_token_id": 151652,
  "vision_token_id": 151654,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "truncate_dim": null
}
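The new `truncate_dim` entry becomes the model-wide default for Matryoshka-style truncation: in the model code below, `_validate_encoding_params` falls back to `self.config.truncate_dim` whenever an `encode_*` call omits `truncate_dim`. A brief sketch of both paths (not part of the commit; call-site values follow the README and must be one of 128, 256, 512, or 1024):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-embeddings-v4', trust_remote_code=True)

# With "truncate_dim": null in config.json, omitting truncate_dim keeps the full dimension.
full_vec = model.encode_texts(texts=["hello world"])

# A call-site truncate_dim applies Matryoshka truncation to the single-vector output;
# values outside [128, 256, 512, 1024] raise a ValueError.
short_vec = model.encode_texts(texts=["hello world"], truncate_dim=128)
```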
modeling_jina_embeddings_v4.py
CHANGED
@@ -1,4 +1,6 @@
+# Jina Embeddings V4 Model implementation was inspired by the ColPali codebase:
+# https://github.com/illuin-tech/colpali
+
 import os
 from dataclasses import dataclass
 from enum import Enum
@@ -15,7 +17,6 @@ from torch import nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import BatchFeature
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
                                             Qwen2_5_VLProcessor)

@@ -33,27 +34,17 @@ class TaskType(str, Enum):
    text_matching = "text-matching"


+PREFIX_DICT = {"query": "Query", "passage": "Passage"}
+TRUNCATE_DIMS = [128, 256, 512, 1024]
+VECTOR_TYPES = ["single_vector", "multi_vector"]
+
+
class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
    def __init__(self, *args, **kwargs) -> None:
        Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
        self.assistant_prefix_len = 58
        self.text_max_length = 8192

-    @staticmethod
-    def round_by_factor(number: float, factor: int) -> int:
-        """Returns the closest integer to 'number' that is divisible by 'factor'."""
-        return round(number / factor) * factor
-
-    @staticmethod
-    def ceil_by_factor(number: float, factor: int) -> int:
-        """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-        return math.ceil(number / factor) * factor
-
-    @staticmethod
-    def floor_by_factor(number: float, factor: int) -> int:
-        """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-        return math.floor(number / factor) * factor
-
    def process_images(
        self,
        images: Union[List[Image.Image], List[List[Image.Image]]],
@@ -175,7 +166,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                [pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0
            )

+            position_ids, rope_deltas = super().get_rope_index(
                input_ids=input_ids,
                image_grid_thw=kwargs.get("image_grid_thw", None),
                attention_mask=attention_mask,
@@ -267,10 +258,10 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        **kwargs,
    ) -> JinaEmbeddingsV4ModelOutput:
        """
+        Forward pass through the model. Returns both single-vector and multi-vector embeddings.
        Args:
+            input_ids (torch.Tensor): The input tokens tensor.
+            attention_mask (torch.Tensor): The attention mask tensor.
        Returns:
            JinaEmbeddingsV4ModelOutput:
                single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
@@ -302,17 +293,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        data: List[Union[str, Image.Image]],
        processor_fn: Callable,
        desc: str,
+        vector_type: str = "single_vector",
        return_numpy: bool = False,
+        batch_size: int = 32,
+        truncate_dim: Optional[int] = None,
    ) -> Union[np.ndarray, List[torch.Tensor]]:
        dataloader = DataLoader(
            dataset=data,
+            batch_size=batch_size,
            shuffle=False,
            collate_fn=processor_fn,
        )
-        vector_type = vector_type or "single_vector"
        results = []
        self.eval()
        for batch in tqdm(dataloader, desc=desc):
@@ -322,8 +313,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                embeddings = self(**batch)
                if vector_type == "single_vector":
                    embeddings = embeddings.single_vec_emb
+                    if truncate_dim is not None:
+                        embeddings = embeddings[:, :truncate_dim]
                else:
                    embeddings = embeddings.multi_vec_emb
+
                results.append(
                    embeddings.cpu()
                    if return_numpy
@@ -333,44 +327,98 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            return np.concatenate([result.numpy() for result in results], axis=0)
        return [item for sublist in results for item in sublist]

+    def _validate_encoding_params(
+        self,
+        vector_type: Optional[str] = None,
+        truncate_dim: Optional[int] = None,
+        text_type: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        encode_kwargs = {}
+        if text_type is not None:
+            if text_type not in PREFIX_DICT:
+                raise ValueError(
+                    f"Invalid text_type: {text_type}. Must be one of {list(PREFIX_DICT.keys())}."
+                )
+            else:
+                encode_kwargs["prefix"] = (
+                    PREFIX_DICT[text_type]
+                    if self.task != TaskType.text_matching
+                    else PREFIX_DICT["query"]
+                )
+
+        vector_type = vector_type or "single_vector"
+        if vector_type not in VECTOR_TYPES:
+            raise ValueError(
+                f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
+            )
+        else:
+            encode_kwargs["vector_type"] = vector_type
+
+        truncate_dim = truncate_dim or self.config.truncate_dim
+        if truncate_dim is not None and truncate_dim not in TRUNCATE_DIMS:
+            raise ValueError(
+                f"Invalid truncate_dim: {truncate_dim}. Must be one of {TRUNCATE_DIMS}."
+            )
+        else:
+            encode_kwargs["truncate_dim"] = truncate_dim
+
+        return encode_kwargs
+
    def encode_texts(
        self,
+        texts: List[str],
        max_length: int = 8192,
        batch_size: int = 8,
        vector_type: Optional[str] = None,
+        return_numpy: bool = False,
+        truncate_dim: Optional[int] = None,
+        text_type: Optional[str] = None,
    ) -> List[torch.Tensor]:
+        text_type = text_type or "query"
+        encode_kwargs = self._validate_encoding_params(
+            vector_type, truncate_dim, text_type
+        )
+
        processor_fn = partial(
+            self.processor.process_texts,
+            max_length=max_length,
+            prefix=encode_kwargs.pop("prefix"),
        )
+
+        is_single = len(texts) == 1
+        embeddings = self._process_batches(
+            data=texts,
            processor_fn=processor_fn,
+            desc="Encoding texts...",
+            return_numpy=return_numpy,
            batch_size=batch_size,
+            **encode_kwargs,
        )

+        return embeddings[0] if is_single else embeddings
+
    def encode_images(
        self,
+        images: List[Image.Image],
        batch_size: int = 8,
        vector_type: Optional[str] = None,
+        return_numpy: bool = False,
+        truncate_dim: Optional[int] = None,
    ) -> List[torch.Tensor]:
+        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
+
+        is_single = len(images) == 1
+        embeddings = self._process_batches(
+            data=images,
            processor_fn=self.processor.process_images,
-            vector_type=vector_type,
+            desc="Encoding images...",
            batch_size=batch_size,
+            return_numpy=return_numpy,
+            **encode_kwargs,
        )

+        return embeddings[0] if is_single else embeddings
+
    @classmethod
    def from_pretrained(
        cls,
@@ -381,9 +429,15 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"

+        task_value = kwargs.pop("task", "retrieval")
+        try:
+            task = TaskType(task_value)
+        except ValueError:
+            valid_tasks = [t.value for t in TaskType]
+            raise ValueError(
+                f"Invalid task: {task_value}. Must be one of {valid_tasks}."
+            )

-        # Get the base model first
        base_model = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )
@@ -397,36 +451,44 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        )
        adapter_dir = os.path.join(adapter_cache_path, "adapters")

-        # Store adapter directory for later use with set_task
        base_model.adapter_dir = adapter_dir
+        base_model.task = task

        # Create the PEFT model with the requested task adapter
        peft_model = PeftModel.from_pretrained(
-            base_model, os.path.join(adapter_dir, task)
+            base_model, os.path.join(adapter_dir, task.value)
        )

        # Add set_task method to the PEFT model instance
+        def set_task_method(self, task: Union[str, TaskType]):
            """
            Set the task adapter for the model.

            Args:
+                task (Union[str, TaskType]): The task name. Must be one of TaskType values or
                    one of ['retrieval', 'text-matching', 'code']
            """
+            if isinstance(task, str):
                try:
+                    task = TaskType(task)
                except ValueError:
                    valid_tasks = [t.value for t in TaskType]
                    raise ValueError(
+                        f"Invalid task: {task}. Must be one of {valid_tasks}"
                    )
+            if self.model.task != task:
+                adapter_path = os.path.join(self.adapter_dir, task.value)
+                hotswap_adapter(self, adapter_path, adapter_name="default")
+                self.model.task = task

+        def get_task_method(self):
+            """
+            Get the task adapter for the model.
+            """
+            return self.model.task.value

+        # Bind the methods to the instance
        peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))
+        peft_model.get_task = get_task_method.__get__(peft_model, type(peft_model))

        return peft_model
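Taken together, the `task` kwarg in `from_pretrained`, `set_task`, and the new `get_task` give the adapter workflow sketched below. This is illustrative only and assumes `AutoModel.from_pretrained` forwards the extra `task` kwarg to the overridden `from_pretrained`, which pops it (defaulting to `'retrieval'`):

```python
from transformers import AutoModel

# Load with a specific task adapter; an unknown task name raises ValueError.
model = AutoModel.from_pretrained(
    'jinaai/jina-embeddings-v4',
    trust_remote_code=True,
    task='code',
)
print(model.get_task())  # 'code'

# Hot-swap to another adapter in place; a no-op when the task is unchanged.
model.set_task('text-matching')
print(model.get_task())  # 'text-matching'
```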