fix-dtype-casting (#13)
Browse files- fix-dtype-casting (2011fa2b59111302c5deb3f0791ad31908120430)
modeling_jina_embeddings_v4.py
CHANGED
@@ -345,7 +345,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
345 |
for batch in tqdm(dataloader, desc=desc):
|
346 |
with torch.no_grad():
|
347 |
batch = {k: v.to(self.device) for k, v in batch.items()}
|
348 |
-
with torch.autocast(device_type=torch.device(self.device).type):
|
349 |
embeddings = self(**batch, task_label=task_label)
|
350 |
if vector_type == "single_vector":
|
351 |
embeddings = embeddings.single_vec_emb
|
|
|
345 |
for batch in tqdm(dataloader, desc=desc):
|
346 |
with torch.no_grad():
|
347 |
batch = {k: v.to(self.device) for k, v in batch.items()}
|
348 |
+
with torch.autocast(device_type=torch.device(self.device).type, dtype=torch.bfloat16):
|
349 |
embeddings = self(**batch, task_label=task_label)
|
350 |
if vector_type == "single_vector":
|
351 |
embeddings = embeddings.single_vec_emb
|