jupyterjazz commited on
Commit
927f9b1
·
verified ·
1 Parent(s): e7f92e1

fix-dtype-casting (#13)

Browse files

- fix-dtype-casting (2011fa2b59111302c5deb3f0791ad31908120430)

Files changed (1) hide show
  1. modeling_jina_embeddings_v4.py +1 -1
modeling_jina_embeddings_v4.py CHANGED
@@ -345,7 +345,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
345
  for batch in tqdm(dataloader, desc=desc):
346
  with torch.no_grad():
347
  batch = {k: v.to(self.device) for k, v in batch.items()}
348
- with torch.autocast(device_type=torch.device(self.device).type):
349
  embeddings = self(**batch, task_label=task_label)
350
  if vector_type == "single_vector":
351
  embeddings = embeddings.single_vec_emb
 
345
  for batch in tqdm(dataloader, desc=desc):
346
  with torch.no_grad():
347
  batch = {k: v.to(self.device) for k, v in batch.items()}
348
+ with torch.autocast(device_type=torch.device(self.device).type, dtype=torch.bfloat16):
349
  embeddings = self(**batch, task_label=task_label)
350
  if vector_type == "single_vector":
351
  embeddings = embeddings.single_vec_emb