Very bad performance (not GPU time, but retrieval score)
Hello,
We evaluate embedding models on a needle-in-the-haystack challenge, which works roughly like this:
- You take a long text.
- You divide it into chunks of X characters (512 in the script below).
- You take a question-answer pair and hide the answer in one of the chunks (the needle), then use the question (the needle magnet) to rank the chunk embeddings by similarity.
- We expect the chunk containing the needle to be among the top-ranked results.
Using this kind of search we can evaluate the embedding model.
Gemma got terrible scores: on average the needle chunk is ranked in the middle of the haystack, which is essentially chance level, so about as bad as it gets for a retrieval model.
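In essence, the scoring step looks like this (a minimal sketch, assuming a recent sentence-transformers release with the encode_query / encode_document helpers discussed later in this thread; the text, question and needle here are toy placeholders):

from sentence_transformers import SentenceTransformer, util

# Toy haystack; the real benchmark uses text extracted from PDFs (see script below).
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 200
question = "What is the secret code?"
needle = " The secret code is 7421. "

chunks = [text[i:i + 500] for i in range(0, len(text), 500)]  # character-based chunks
needle_idx = len(chunks) // 2
chunks[needle_idx] += needle  # hide the needle in one chunk

model = SentenceTransformer("google/embeddinggemma-300m")  # or any embedding model
doc_emb = model.encode_document(chunks, convert_to_tensor=True)
query_emb = model.encode_query(question, convert_to_tensor=True)

ranking = util.cos_sim(query_emb, doc_emb)[0].argsort(descending=True).tolist()
print(ranking.index(needle_idx))  # 0 = best; ~len(chunks)/2 = chance level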
Here is my full benchmarking code:
import os
import json
import gc
import random
import time
import copy
from typing import List, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed
import torch
import tqdm
import openpyxl
import fitz
from sentence_transformers import SentenceTransformer, util
from utils.hf_env import setup_hf_env
setup_hf_env()
# Global variables
embeddings_models = [
"BAAI/bge-m3",
"google/embeddinggemma-300m",
]
DEBUG = False
chunks_path = "./chunks/"
pdfs_path = "./pdfs/"
pdfs_to_process = [
"2407.12327v1.pdf",
"2409.05591v2.pdf",
"IDEXTEND_guide.pdf",
"Inner_Demons_by_E_Thornewell91-lwqpu3es.pdf",
"md2pdf - Markdown to PDF.pdf",
"ORANO-MAG-2021_205x275_FR_MEL.pdf",
"politique-hse-2024-2026-vf.pdf",
"SCP-3194_by_SCP-3194-rMxNQ8yJ.pdf",
"Walking_the_Wire_by_theMarvelousTolkienJob-p3fqb8tb.pdf",
]
os.makedirs(pdfs_path, exist_ok=True)
os.makedirs(chunks_path, exist_ok=True)
chunks_size = [512]
# Read PDFs and build text
text = ""
for pdf in pdfs_to_process:
    full_path = os.path.join(pdfs_path, pdf)
    with fitz.open(full_path) as doc:
        for page in doc:
            text += page.get_text()
dataset_crosslingual_easy = "./dataset_crosslingual_easy.json"
dataset_crosslingual_subtle = "./dataset_crosslingual_subtle.json"
dataset_multilingual_easy = "./dataset_multilingual_easy.json"
dataset_multilingual_subtle = "./dataset_multilingual_subtle.json"
needle_question_answers_pairs_crosslingual_easy = json.load(open(dataset_crosslingual_easy, "r", encoding="utf-8-sig"))
needle_question_answers_pairs_crosslingual_subtle = json.load(open(dataset_crosslingual_subtle, "r", encoding="utf-8-sig"))
needle_question_answers_pairs_multilingual_easy = json.load(open(dataset_multilingual_easy, "r", encoding="utf-8-sig"))
needle_question_answers_pairs_multilingual_subtle = json.load(open(dataset_multilingual_subtle, "r", encoding="utf-8-sig"))
if DEBUG:
    needle_question_answers_pairs_crosslingual_easy = needle_question_answers_pairs_crosslingual_easy[:1]
    needle_question_answers_pairs_crosslingual_subtle = needle_question_answers_pairs_crosslingual_subtle[:1]
    needle_question_answers_pairs_multilingual_easy = needle_question_answers_pairs_multilingual_easy[:1]
    needle_question_answers_pairs_multilingual_subtle = needle_question_answers_pairs_multilingual_subtle[:1]
langs = [
"crosslingual_easy",
"crosslingual_subtle",
"multilingual_easy",
"multilingual_subtle"]
NUM_REPEATS = 15
if DEBUG:
    NUM_REPEATS = 1
def interpolate_color(color1, color2, value, max_value):
    ratio = value / max_value
    r = int(color1[0] + (color2[0] - color1[0]) * ratio)
    g = int(color1[1] + (color2[1] - color1[1]) * ratio)
    b = int(color1[2] + (color2[2] - color1[2]) * ratio)
    return f"FF{r:02X}{g:02X}{b:02X}"
def get_color(value, max_value):
    green = (0, 255, 0)
    yellow = (255, 255, 0)
    orange = (255, 165, 0)
    red = (255, 0, 0)
    if value <= max_value / 4:
        return interpolate_color(green, yellow, value, max_value / 4)
    elif value <= max_value / 2:
        return interpolate_color(yellow, orange, value - max_value / 4, max_value / 4)
    elif value <= 3 * max_value / 4:
        return interpolate_color(orange, red, value - max_value / 2, max_value / 4)
    else:
        return interpolate_color(red, (0, 0, 0), value - 3 * max_value / 4, max_value / 4)
def inject_needle_in_text(needle_text: str, text: str) -> str:
    random_index = random.randint(0, len(text))
    if random_index == 0:
        return needle_text + text
    if random_index == len(text):
        return text + needle_text
    return text[:random_index] + needle_text + text[random_index:]
def inject_needle_in_chunks(needle_question_answer, chunks) -> Tuple[List[str], List[int]]:
    cpy_chunks = copy.deepcopy(chunks)
    n_chunks = len(cpy_chunks)
    random_indexes = random.sample(range(n_chunks), len(needle_question_answer["needles"]))
    needle_question_answer["expected_matches"] = random_indexes
    for needle_index, needle_text in enumerate(needle_question_answer["needles"]):
        cpy_chunks[random_indexes[needle_index]] = inject_needle_in_text(needle_text, cpy_chunks[random_indexes[needle_index]])
    return cpy_chunks, random_indexes
def get_chunk_embeddings(chunks, model: SentenceTransformer, device: str):
    return model.encode(chunks, convert_to_tensor=True, device=device, batch_size=2)
def update_chunk_embeddings(chunks, modified_indexes, model: SentenceTransformer, device: str):
    modified_embeddings = model.encode([chunks[i] for i in modified_indexes], convert_to_tensor=True, device=device)
    return modified_embeddings
def process_embedding_model(embedding_model: str, device: str):
    rows = []
    gc.collect()
    torch.cuda.empty_cache()
    try:
        try:
            model = SentenceTransformer(
                embedding_model,
                trust_remote_code=True,
                device=device,
            )
        except Exception:
            device = "cpu"
            model = SentenceTransformer(
                embedding_model,
                trust_remote_code=True,
                device=device,
            )
        for chunk_size in tqdm.tqdm(chunks_size, desc=f"Chunk sizes for {embedding_model}", leave=False):
            chunks = [text[i: i + chunk_size] for i in range(0, len(text), chunk_size)]
            original_chunks = copy.deepcopy(chunks)
            original_chunk_embeddings = get_chunk_embeddings(original_chunks, model, device)
            _time = time.time()
            languages_averages = {}
            for idx_pair_set, needle_pairs in enumerate([needle_question_answers_pairs_crosslingual_easy,
                                                         needle_question_answers_pairs_crosslingual_subtle,
                                                         needle_question_answers_pairs_multilingual_easy,
                                                         needle_question_answers_pairs_multilingual_subtle]):
                stored_scores = []
                for _ in range(NUM_REPEATS):
                    for pair in needle_pairs:
                        untouched_chunks = copy.deepcopy(original_chunks)
                        modified_chunks, ground_truth_indexes = inject_needle_in_chunks(pair, untouched_chunks)
                        modified_embeddings = update_chunk_embeddings(modified_chunks, ground_truth_indexes, model, device)
                        updated_chunk_embeddings = torch.Tensor(original_chunk_embeddings).clone()
                        for i, idx in enumerate(ground_truth_indexes):
                            updated_chunk_embeddings[idx] = modified_embeddings[i]
                        question_embedding = model.encode(pair["question"], convert_to_tensor=True, device=device)
                        sim = util.pytorch_cos_sim(question_embedding, updated_chunk_embeddings)
                        sorted_similarities, sorted_indexes = torch.sort(sim, descending=True)
                        predicted_index_sorted = sorted_indexes.tolist()[0]
                        for index in ground_truth_indexes:
                            stored_scores.append(predicted_index_sorted.index(index))
                        del updated_chunk_embeddings
                        del question_embedding
                        gc.collect()
                average_index = sum(stored_scores) / len(stored_scores)
                languages_averages[langs[idx_pair_set]] = average_index
            time_spent = time.time() - _time
            vram_used = torch.cuda.memory_allocated() / 1024 / 1024
            row = [
                embedding_model,
                round(languages_averages["crosslingual_easy"], 1),
                round(languages_averages["crosslingual_subtle"], 1),
                round(languages_averages["multilingual_easy"], 1),
                round(languages_averages["multilingual_subtle"], 1),
                round((languages_averages["crosslingual_easy"] + languages_averages["crosslingual_subtle"] + languages_averages["multilingual_easy"] + languages_averages["multilingual_subtle"]) / 4, 1),
                len(original_chunks),
                time_spent,
                vram_used,
                chunk_size,
            ]
            rows.append(row)
        del model
    except Exception as e:
        # print the stack trace
        import traceback
        traceback.print_exc()
        print(f"Error with model {embedding_model} on device {device}: {e}")
        for chunk_size in chunks_size:
            row = [embedding_model, "Error", "Error", "Error", "Error", 0, len(original_chunks) if 'original_chunks' in locals() else 0, 0, 0, chunk_size]
            rows.append(row)
    gc.collect()
    torch.cuda.empty_cache()
    return rows
if __name__ == '__main__':
    import multiprocessing
    multiprocessing.freeze_support()
    multiprocessing.set_start_method('spawn', force=True)
    num_gpus = torch.cuda.device_count()
    devices = [f"cuda:{i}" for i in range(num_gpus)]
    print(f"Found {num_gpus} GPUs: {devices}")
    all_rows = []
    with ProcessPoolExecutor(max_workers=num_gpus) as executor:
        futures = {}
        for i, embedding_model in enumerate(embeddings_models):
            dev = devices[i % num_gpus]
            futures[executor.submit(process_embedding_model, embedding_model, dev)] = embedding_model
        for fut in tqdm.tqdm(as_completed(futures), total=len(futures), desc="Models"):
            all_rows.extend(fut.result())
    xlsx_name = f"results_{time.time()}.xlsx"
    wb = openpyxl.Workbook()
    worksheet = wb.active
    cols = ["Model",
            "crosslingual_easy average index",
            "crosslingual_subtle average index",
            "multilingual_easy average index",
            "multilingual_subtle average index",
            "Average index",
            "Total chunks",
            "Time spent(s)",
            "VRAM used (MB)",
            "Chunk size"]
    worksheet.append(cols)
    for row in all_rows:
        worksheet.append(row)
    wb.save(xlsx_name)
    for row in range(2, 2 + len(all_rows)):
        max_value = worksheet.cell(row=row, column=7).value
        for col in range(2, 7):
            cell = worksheet.cell(row=row, column=col)
            try:
                value = float(cell.value if cell.value != "Error" else max_value)
            except Exception:
                value = 0
            color = get_color(value, max_value if isinstance(max_value, (int, float)) and max_value else 1)
            cell.fill = openpyxl.styles.PatternFill(start_color=color, end_color=color, fill_type="solid")
        worksheet.cell(row=row, column=8).font = openpyxl.styles.Font(bold=True)
    wb.save(xlsx_name)
    print(f"Results saved to {xlsx_name}")
Hi, I'm having the same issue.
I'm testing this model on Korean retrieval benchmarks from MTEB (Ko-StrategyQA, AutoRAGRetrieval, PublicHealthQA, BelebeleRetrieval, MultiLongDocRetrieval, MIRACLRetrieval, MrTidyRetrieval) and this model shows very bad scores.
Also, when evaluating, especially on the MultiLongDocRetrieval, MIRACLRetrieval, and MrTidyRetrieval tasks (which contain long contexts), I see this message:
WARNING:mteb.evaluation.evaluators.RetrievalEvaluator:Found 33800 NaN values in the similarity scores. Replacing NaN values with -1.
I'm pretty sure that I'm using the proper precision and model prompts. This is my code for model initialization:
model_prompts = {
PromptType.query.value: "task: search result | query: ",
PromptType.document.value: "title: none | text: ",
}
model = mteb.get_model(model_name, model_prompts=model_prompts, trust_remote_code=True, device=device)
And these logs are shown in the encoding step:
INFO:mteb.models.sentence_transformer_wrapper:Using prompt_name='query' for task=MultiLongDocRetrieval prompt_type=<PromptType.query: 'query'> with prompt='task: search result | query: '
INFO:mteb.models.sentence_transformer_wrapper:Using prompt_name='document' for task=MultiLongDocRetrieval prompt_type=<PromptType.document: 'document'> with prompt='title: none | text: '
Are there any problems here?
- Youngjoon Jang
@LPN64 maybe your issue is related to using float16, because this is stated in the README: "NOTE: EmbeddingGemma activations do not support float16. Please use float32 or bfloat16 as appropriate for your hardware."
Well shit, I actually saw this sentence but I mixed up the var types, will update my post with new scores
I updated the code to:
try:
    model = SentenceTransformer(
        embedding_model,
        trust_remote_code=True,
        device=device,
        model_kwargs={"dtype": torch.bfloat16},
    )
except Exception:
    device = "cpu"
    model = SentenceTransformer(
        embedding_model,
        trust_remote_code=True,
        device=device,
        model_kwargs={"dtype": torch.bfloat16},
    )
but it doesn't help much
- Did you remove model.half()? Because even if you initialize the model with bf16, .half() would cast it back to fp16. Reference:
>>> model = SentenceTransformer(model_name, model_kwargs={"torch_dtype": torch.bfloat16})
>>> model.dtype
torch.bfloat16
>>> model.half()
SentenceTransformer(
(0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
(1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
(2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
(3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
(4): Normalize()
)
>>> model.dtype
torch.float16
>>>
- Are you using the prompts for retrieval? (see the sketch after this list)
- "task: search result | query: " should be used for encoding queries
- "title: none | text: " should be used for encoding documents
Do you have any ideas, @tomaarsen?
This is my test code using mteb:
"""Benchmarking all datasets constituting the MTEB Korean leaderboard & average scores"""
from __future__ import annotations
import os
import logging
from multiprocessing import Process, current_process
import torch
import hashlib
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
import mteb
from mteb import MTEB, get_tasks
from mteb.encoder_interface import PromptType
from mteb.models.sentence_transformer_wrapper import SentenceTransformerWrapper
from mteb.models.instruct_wrapper import instruct_wrapper
import argparse
from dotenv import load_dotenv
from setproctitle import setproctitle
import traceback
import logging
load_dotenv() # for OPENAI
parser = argparse.ArgumentParser(description="Extract contexts")
parser.add_argument('--quantize', default=False, type=bool, help='quantize embeddings')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")
TASK_LIST_RETRIEVAL_GPU_MAPPING = {
0: [
"Ko-StrategyQA",
"AutoRAGRetrieval",
"PublicHealthQA",
"BelebeleRetrieval",
"MultiLongDocRetrieval",
],
1: ["MIRACLRetrieval"],
2: ["MrTidyRetrieval"],
}
model_names = [
"google/embeddinggemma-300m"
]
save_path = "./RESULTS_DEV"
def evaluate_model(model_name, gpu_id, tasks):
    import torch
    try:
        device = torch.device(f"cuda:{str(gpu_id)}")
        torch.cuda.set_device(device)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        model = None
        if not os.path.exists(model_name):
            if "m2v" in model_name:
                static_embedding = StaticEmbedding.from_model2vec(model_name)
                model = SentenceTransformer(modules=[static_embedding], model_kwargs={"attn_implementation": "sdpa"}, device=device)
            else:
                if model_name == "nlpai-lab/KoE5" or model_name == "KU-HIAI-ONTHEIT/ontheit-large-v1_1" or "KUKE" in model_name:
                    model_prompts = {
                        PromptType.query.value: "query: ",
                        PromptType.passage.value: "passage: ",
                    }
                    model = SentenceTransformerWrapper(model=model_name, model_prompts=model_prompts, model_kwargs={"attn_implementation": "sdpa"}, device=device)
                elif "snowflake" in model_name.lower():
                    model_prompts = {
                        PromptType.query.value: "query: ",
                    }
                    model = SentenceTransformerWrapper(model=model_name, model_prompts=model_prompts, device=device)
                elif "Qwen3" in model_name:
                    model = mteb.get_model(model_name, model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": torch.bfloat16}, device=device, tokenizer_kwargs={"padding_side": "left"},)
                elif "frony" in model_name:
                    model_prompts = {
                        # PromptType.query.value: "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery:",
                        PromptType.query.value: "<Q>",
                        PromptType.passage.value: "<P>",
                    }
                    model = mteb.get_model(model_name, model_prompts=model_prompts, model_kwargs={"attn_implementation": "sdpa"}, device=device)
                elif "gte-multilingual" in model_name or "nomic-embed" in model_name:
                    model = mteb.get_model(model_name, trust_remote_code=True, device=device)
                elif "gemma" in model_name:
                    model_prompts = {
                        PromptType.query.value: "task: search result | query: ",
                        PromptType.document.value: "title: none | text: ",
                    }
                    model = mteb.get_model(model_name, model_prompts=model_prompts, trust_remote_code=True, device=device)
                else:
                    model = mteb.get_model(model_name, device=device)
        else:  # for locally trained models
            file_name = os.path.join(model_name, "model.safetensors")
            if os.path.exists(file_name):
                if "m2v" in model_name:  # for model2vec: the model name must contain "m2v" to be recognized as a model2vec model
                    static_embedding = StaticEmbedding.from_model2vec(model_name)
                    model = SentenceTransformer(modules=[static_embedding], model_kwargs={"attn_implementation": "sdpa"}, device=device)
                else:
                    model = mteb.get_model(model_name, model_kwargs={"attn_implementation": "sdpa"}, device=device)
        if model:
            output_folder_name = os.path.basename(model_name)
            if os.path.isdir(model_name) and len(output_folder_name) > 100:
                model_hash = hashlib.md5(model_name.encode()).hexdigest()[:6]
                output_folder_name = f"{output_folder_name[:93]}_{model_hash}"
            if os.path.isdir(model_name):
                try:
                    model.model_meta.name = output_folder_name
                except AttributeError:
                    logger.warning("Could not override model_meta.name. Path might still be too long.")
            setproctitle(f"{output_folder_name}-{gpu_id}")
            print(f"Running tasks: {tasks} / {model_name} on GPU {gpu_id} in process {current_process().name}")
            evaluation = MTEB(
                tasks=get_tasks(tasks=tasks, languages=["kor-Kore", "kor-Hang", "kor_Hang"])
            )
            if "multilingual-e5" in model_name or "KoE5" in model_name or "ontheit" in model_name or "KUKE" in model_name:
                batch_size = 512
            elif "jina" in model_name:
                batch_size = 8
            elif "bge-m3" in model_name or "Snowflake" in model_name:
                batch_size = 64
            elif "gemma2" in model_name:
                batch_size = 256
            elif "Salesforce" in model_name:
                batch_size = 8
            else:
                batch_size = 256
            if args.quantize:  # for quantized models
                evaluation.run(
                    model,
                    output_folder=f"{save_path}/{output_folder_name}-quantized",
                    encode_kwargs={"batch_size": batch_size, "precision": "binary"},
                )
            else:
                evaluation.run(
                    model,
                    output_folder=f"{save_path}/{output_folder_name}",
                    encode_kwargs={"batch_size": batch_size},
                )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
if __name__ == "__main__":
    torch.multiprocessing.set_start_method('spawn')
    for model_name in model_names:
        print(f"Starting evaluation for model: {model_name}")
        processes = []
        for gpu_id, tasks in TASK_LIST_RETRIEVAL_GPU_MAPPING.items():
            p = Process(target=evaluate_model, args=(model_name, gpu_id, tasks))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
        print(f"Completed evaluation for model: {model_name}")
@yjoonjang you should be able to simply load the model for MTEB evaluation with model = mteb.get_model("google/embeddinggemma-300m"); it's defined here: https://github.com/embeddings-benchmark/mteb/blob/729f20adbbcda328ac38528d27c7136d79f946da/mteb/models/google_models.py#L241-L259, and the model itself already stores the prompts for retrieval: https://huggingface.co/google/embeddinggemma-300m/blob/main/config_sentence_transformers.json#L19-L20
@LPN64 The encode_query and encode_document methods that you started using are indeed recommended; those take care of the prompts for you. I still think it's rather weird that this model ranks behind a literal static embedding model, though. In my tests on e.g. MIRIAD, the non-finetuned EmbeddingGemma was #2 among <500M-parameter models, behind only snowflake-arctic-embed-m-v2.0: https://huggingface.co/blog/embeddinggemma#finetuned-evaluation
I also just ran some tests with https://huggingface.co/datasets/dwzhu/LongEmbed (i.e. needle tests) via MTEB, and results look strong there: #4 and #5 on LEMBSummScreenFDRetrieval and LEMBQMSumRetrieval out of all <500M models.
We did run into some potential issues in the implementation, so some questions that might help narrow it down:
- Are you using FA2? i.e. is flash_attn installed? (check with pip show flash_attn)
- Are the chunk sizes indeed 512, never larger?
- Tom Aarsen
Thanks for your reply, @tomaarsen .
I've changed my code to mteb.get_model(model_name, device=device, model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": torch.bfloat16}), and I confirmed that FA2 is installed.
One piece of good news: I no longer get any warnings about NaN values. However, I still have retrieval performance problems.
To keep it short, I'll report the nDCG@10 results for the Korean split of BelebeleRetrieval (a multilingual, human-translated benchmark):
Model | Score |
---|---|
snowflake-arctic-embed-l-v2.0 | 0.9271 |
bge-m3 | 0.92577 |
KURE-v1 (Korean fine-tuned version for bge-m3) | 0.95019 |
embeddinggemma-300m (using mteb.get_model, FA2, bf16) | 0.71893 |
embeddinggemma-300m (using mteb.get_model, fp32) | 0.72056 |
Other benchmark results also show similar patterns (low retrieval scores for embeddinggemma on Korean retrieval)
- Youngjoon Jang
Sorry for the delayed answer, I wanted to come back with more results (see the bottom of this post).
Are the chunk sizes indeed 512, never larger?
Technically it's 512 + the needle size (50 characters at most), but yes.
Chunk sizes are in characters, not tokens; that is partly laziness and partly to keep things "fair" across different models' tokenizers.
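As an illustration (a rough sketch, not part of the benchmark), the same 512-character span produces different token counts under different tokenizers, which is why the budget is expressed in characters:

from transformers import AutoTokenizer

# Same 512-character span, different token counts per tokenizer,
# hence a character budget keeps the haystack identical across models.
chunk = ("Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 10)[:512]
for name in ["BAAI/bge-m3", "google/embeddinggemma-300m"]:
    tok = AutoTokenizer.from_pretrained(name)
    print(name, len(tok(chunk)["input_ids"]))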
flash attention is indeed installed.
My benchmark, as described in my first post, doesn't really seem .... hard; it's a simple needle-in-the-haystack challenge.
But it includes multiple challenges described here (+sample).
The "subtle" dataset may seem really hard, so hard it may even sound stupid, but if any of the embedding models manages to correctly rank the needle, it means all the other models are wrong when they fail to do so.
From memory, bge-m3 is one of the very rare models with very good "resilience" to the size difference between the query and the document chunk.
For the following experiment I ran it with all of these chunk sizes: chunks_size = [512, 1024, 2048, 4096]
Full table:
Model | crosslingual_easy average index | crosslingual_subtle average index | multilingual_easy average index | multilingual_subtle average index | Average index | Relative average index | Total chunks | Time spent(s) | Chunk size |
---|---|---|---|---|---|---|---|---|---|
HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v2 | 64 | 340,2 | 8,6 | 113,2 | 131,5 | 0,125477099 | 1048 | 47,03376174 | 512 |
HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5 | 38,7 | 339,5 | 10 | 169,4 | 139,4 | 0,133015267 | 1048 | 48,50138378 | 512 |
nomic-ai/nomic-embed-text-v2-moe | 38,5 | 310,8 | 7,2 | 267,3 | 155,9 | 0,148759542 | 1048 | 61,92095351 | 512 |
BAAI/bge-m3 | 21 | 357,3 | 1,7 | 365,6 | 186,4 | 0,177862595 | 1048 | 28,41596007 | 512 |
HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5 | 45,9 | 202,8 | 13,8 | 114,7 | 94,3 | 0,179961832 | 524 | 68,54436135 | 1024 |
HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v2 | 59,1 | 211,5 | 16,4 | 104,9 | 98 | 0,187022901 | 524 | 57,68682265 | 1024 |
Qwen/Qwen3-Embedding-0.6B | 52,3 | 392,8 | 30 | 311,3 | 196,6 | 0,18759542 | 1048 | 51,87343717 | 512 |
nomic-ai/nomic-embed-text-v2-moe | 44,4 | 176 | 12,1 | 172,7 | 101,3 | 0,193320611 | 524 | 66,20272422 | 1024 |
nomic-ai/nomic-embed-text-v1.5 | 261,1 | 401,2 | 61,4 | 122,6 | 211,6 | 0,201908397 | 1048 | 19,81428552 | 512 |
Alibaba-NLP/gte-multilingual-base | 44,7 | 390,6 | 31,3 | 389,9 | 214,1 | 0,204293893 | 1048 | 16,94352269 | 512 |
mixedbread-ai/mxbai-embed-large-v1 | 249,5 | 385,8 | 80,8 | 155,2 | 217,8 | 0,207824427 | 1048 | 29,9487319 | 512 |
BAAI/bge-m3 | 19,9 | 206,6 | 5,5 | 206 | 109,5 | 0,208969466 | 524 | 36,66027379 | 1024 |
nomic-ai/modernbert-embed-base | 264,5 | 381,8 | 88,3 | 165,4 | 225 | 0,214694656 | 1048 | 52,11879373 | 512 |
HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v2 | 42 | 114,9 | 14,7 | 58,8 | 57,6 | 0,219847328 | 262 | 84,55695772 | 2048 |
HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5 | 39,9 | 107,8 | 17,3 | 78,8 | 61 | 0,232824427 | 262 | 111,6841481 | 2048 |
Alibaba-NLP/gte-multilingual-base | 48,7 | 212,4 | 39,1 | 205,8 | 126,5 | 0,241412214 | 524 | 19,15321803 | 1024 |
Qwen/Qwen3-Embedding-0.6B | 52,3 | 227,7 | 38,7 | 191,3 | 127,5 | 0,243320611 | 524 | 66,59457469 | 1024 |
BAAI/bge-m3 | 25,3 | 106,7 | 12,2 | 113 | 64,3 | 0,245419847 | 262 | 49,99724388 | 2048 |
Alibaba-NLP/gte-modernbert-base | 278,1 | 434 | 121,6 | 208,4 | 260,5 | 0,248568702 | 1048 | 46,6702168 | 512 |
nomic-ai/nomic-embed-text-v2-moe | 43,8 | 100,6 | 28,8 | 91,1 | 66,1 | 0,252290076 | 262 | 74,96642709 | 2048 |
nomic-ai/nomic-embed-text-v1.5 | 157,6 | 226 | 62,8 | 88,4 | 133,7 | 0,255152672 | 524 | 22,17950249 | 1024 |
HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v2 | 28,3 | 59,1 | 13 | 34,7 | 33,8 | 0,258015267 | 131 | 142,5885003 | 4096 |
nomic-ai/modernbert-embed-base | 153,8 | 200,5 | 80 | 118,2 | 138,1 | 0,263549618 | 524 | 36,04254103 | 1024 |
multi-qa-mpnet-base-dot-v1 | 306 | 423,4 | 124,6 | 277,3 | 282,8 | 0,269847328 | 1048 | 81,86924815 | 512 |
HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5 | 30 | 56,9 | 15,3 | 42 | 36,1 | 0,275572519 | 131 | 154,172142 | 4096 |
sentence-transformers/static-retrieval-mrl-en-v1 | 379,8 | 515,3 | 24,5 | 262,2 | 295,5 | 0,281965649 | 1048 | 2,598180771 | 512 |
mixedbread-ai/mxbai-embed-large-v1 | 168 | 220,9 | 84,8 | 126,1 | 149,9 | 0,286068702 | 524 | 37,92282224 | 1024 |
BAAI/bge-m3 | 25,1 | 58,9 | 15,8 | 60,4 | 40,1 | 0,30610687 | 131 | 87,76587462 | 4096 |
Alibaba-NLP/gte-multilingual-base | 51,2 | 117 | 40,2 | 116,6 | 81,2 | 0,309923664 | 262 | 26,10993862 | 2048 |
Qwen/Qwen3-Embedding-0.6B | 51,9 | 122,2 | 46,3 | 105,5 | 81,5 | 0,311068702 | 262 | 105,2516272 | 2048 |
sentence-transformers/static-retrieval-mrl-en-v1 | 204,9 | 258,6 | 42,9 | 156,1 | 165,6 | 0,316030534 | 524 | 2,509407997 | 1024 |
multi-qa-mpnet-base-dot-v1 | 169,3 | 225,8 | 100,7 | 167,6 | 165,9 | 0,316603053 | 524 | 232,0425959 | 1024 |
google/embeddinggemma-300m | 238,1 | 475,6 | 178,3 | 480,8 | 343,2 | 0,327480916 | 1048 | 57,35304642 | 512 |
nomic-ai/nomic-embed-text-v1.5 | 98,9 | 121,4 | 57,9 | 65,8 | 86 | 0,328244275 | 262 | 29,56943297 | 2048 |
sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 | 229,2 | 426 | 286,7 | 436 | 344,5 | 0,328721374 | 1048 | 14,14036393 | 512 |
tasksource/ModernBERT-base-embed | 460,6 | 485,5 | 189,2 | 245,5 | 345,2 | 0,329389313 | 1048 | 36,26382637 | 512 |
nomic-ai/modernbert-embed-base | 93,7 | 113,1 | 61,5 | 80 | 87,1 | 0,332442748 | 262 | 38,77631068 | 2048 |
Alibaba-NLP/gte-modernbert-base | 176,2 | 235,8 | 119,1 | 168,5 | 174,9 | 0,333778626 | 524 | 53,01872921 | 1024 |
paraphrase-multilingual-mpnet-base-v2 | 227,8 | 427,8 | 285,6 | 482,9 | 356,1 | 0,339790076 | 1048 | 23,47053266 | 512 |
mixedbread-ai/mxbai-embed-large-v1 | 101,7 | 121,6 | 62,1 | 72 | 89,4 | 0,341221374 | 262 | 44,94309378 | 2048 |
dangvantuan/sentence-camembert-base | 453,9 | 508,6 | 237,1 | 250,8 | 362,6 | 0,345992366 | 1048 | 14,90612245 | 512 |
Parallia/Fairly-Multilingual-ModernBERT-Embed-BE | 363,8 | 489,7 | 216,9 | 395,6 | 366,5 | 0,34971374 | 1048 | 34,70186949 | 512 |
sentence-transformers/static-retrieval-mrl-en-v1 | 112,1 | 128,3 | 38 | 93,3 | 92,9 | 0,354580153 | 262 | 3,002109051 | 2048 |
Alibaba-NLP/gte-multilingual-base | 37,3 | 60,9 | 33,9 | 58,1 | 47,5 | 0,36259542 | 131 | 44,10674095 | 4096 |
multi-qa-mpnet-base-dot-v1 | 95,2 | 117,2 | 70,4 | 99,8 | 95,7 | 0,365267176 | 262 | 270,1356633 | 2048 |
google/embeddinggemma-300m | 159 | 250,4 | 123,2 | 251,3 | 196 | 0,374045802 | 524 | 60,75665236 | 1024 |
nomic-ai/nomic-embed-text-v2-moe | 44,6 | 59,3 | 40,7 | 55,3 | 50 | 0,381679389 | 131 | 83,04587102 | 4096 |
nomic-ai/modernbert-embed-base | 51,7 | 61,1 | 40,7 | 47,5 | 50,3 | 0,383969466 | 131 | 47,37638664 | 4096 |
tasksource/ModernBERT-base-embed | 239,4 | 245,1 | 166,2 | 158,9 | 202,4 | 0,386259542 | 524 | 38,21582079 | 1024 |
nomic-ai/nomic-embed-text-v1.5 | 56 | 62,2 | 43,9 | 46,4 | 52,1 | 0,397709924 | 131 | 47,30568147 | 4096 |
Parallia/Fairly-Multilingual-ModernBERT-Embed-BE | 202,9 | 242,4 | 160,2 | 236,8 | 210,6 | 0,401908397 | 524 | 36,61860323 | 1024 |
google/embeddinggemma-300m | 90,1 | 127 | 82,6 | 125,4 | 106,3 | 0,405725191 | 262 | 73,31685472 | 2048 |
Alibaba-NLP/gte-modernbert-base | 106,8 | 124,3 | 90 | 106,2 | 106,8 | 0,407633588 | 262 | 68,65535808 | 2048 |
sentence-transformers/static-retrieval-mrl-en-v1 | 59 | 65,9 | 33,7 | 55,8 | 53,6 | 0,409160305 | 131 | 4,143317461 | 4096 |
mixedbread-ai/mxbai-embed-large-v1 | 56,1 | 62,2 | 47,9 | 48,9 | 53,8 | 0,410687023 | 131 | 45,95757341 | 4096 |
sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 | 185,6 | 237 | 209,9 | 239,8 | 218,1 | 0,416221374 | 524 | 14,50295591 | 1024 |
paraphrase-multilingual-mpnet-base-v2 | 188,6 | 239 | 200,1 | 254,1 | 220,5 | 0,420801527 | 524 | 21,40634727 | 1024 |
dangvantuan/sentence-camembert-base | 237,4 | 254,2 | 184,3 | 207,9 | 220,9 | 0,421564885 | 524 | 15,13097763 | 1024 |
google/embeddinggemma-300m | 52,1 | 64,4 | 47 | 60,9 | 56,1 | 0,428244275 | 131 | 94,32970762 | 4096 |
multi-qa-mpnet-base-dot-v1 | 56,3 | 61,1 | 51,1 | 57,2 | 56,4 | 0,430534351 | 131 | 147,4030335 | 4096 |
Parallia/Fairly-Multilingual-ModernBERT-Embed-BE | 113,1 | 125,9 | 96,1 | 122 | 114,3 | 0,436259542 | 262 | 44,65966773 | 2048 |
tasksource/ModernBERT-base-embed | 127 | 128,5 | 104,7 | 100,2 | 115,1 | 0,439312977 | 262 | 47,17969394 | 2048 |
Alibaba-NLP/gte-modernbert-base | 60,4 | 61,7 | 53,7 | 60,2 | 59 | 0,450381679 | 131 | 84,02446485 | 4096 |
sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 | 111,1 | 124 | 115 | 124,7 | 118,7 | 0,453053435 | 262 | 15,50453568 | 2048 |
dangvantuan/sentence-camembert-base | 123,6 | 130,9 | 109,1 | 112,5 | 119 | 0,454198473 | 262 | 16,11687374 | 2048 |
paraphrase-multilingual-mpnet-base-v2 | 114,6 | 123,1 | 118,6 | 120,3 | 119,1 | 0,454580153 | 262 | 27,78390026 | 2048 |
Parallia/Fairly-Multilingual-ModernBERT-Embed-BE | 58,9 | 63,9 | 53,8 | 63,6 | 60,1 | 0,458778626 | 131 | 62,07665706 | 4096 |
tasksource/ModernBERT-base-embed | 63,1 | 64,2 | 53,8 | 61 | 60,5 | 0,461832061 | 131 | 60,17870545 | 4096 |
dangvantuan/sentence-camembert-base | 64,7 | 66,5 | 59,2 | 59,5 | 62,5 | 0,477099237 | 131 | 17,60097528 | 4096 |
sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 | 60,8 | 62,5 | 63,3 | 63,6 | 62,6 | 0,477862595 | 131 | 17,70088696 | 4096 |
paraphrase-multilingual-mpnet-base-v2 | 60,1 | 65,3 | 61,2 | 65,4 | 63 | 0,480916031 | 131 | 37,53072286 | 4096 |
Thank you for the excellent details! Given @yjoonjang's results without FA2, I think it might not matter, but can you try without flash attention 2, @LPN64? We're investigating a bug here: https://github.com/huggingface/transformers/pull/40700
@yjoonjang , Google themselves have also evaluated on Belebele, you can find all results in: https://github.com/embeddings-benchmark/results/tree/main/results/google__embeddinggemma-300m/64614b0b8b64f0c6c1e52b07e4e9a4e8fe4d2da2
Here's Belebele: https://raw.githubusercontent.com/RyanMullins/mteb-results/c7359540dab61156ed875185c99df95a29e42e3f/results/google__embeddinggemma-300m/64614b0b8b64f0c6c1e52b07e4e9a4e8fe4d2da2/BelebeleRetrieval.json
Their reported results:
- kor_Hang-kor_Hang: 0.94142
- kor_Hang-eng_Latn: 0.95134
- eng_Latn-kor_Hang: 0.92925
If your score is the kor_Hang-kor_Hang one only, then Google's reported results have EmbeddingGemma outperforming bge-m3 and snowflake-arctic-embed-l-v2.0 despite being much smaller. A finetuned EmbeddingGemma might then also outperform KURE-v1.
Somewhere, something is a bit off, it seems.
- Tom Aarsen
Oh, are you two on the correct transformers version? See https://github.com/huggingface/transformers/releases/tag/v4.56.0-Embedding-Gemma-preview
pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
If you're on an older version of transformers, the bidirectional configuration options in this model will be ignored, and it will use causal attention instead of the bidirectional attention it was trained with. Using the latest transformers from GitHub (pip install -U git+https://github.com/huggingface/transformers/), I get:
- BelebeleRetrieval, kor_Hang-kor_Hang: 0.94242
- BelebeleRetrieval, kor_Hang-eng_Latn: 0.95205
- BelebeleRetrieval, eng_Latn-kor_Hang: 0.92919
And with the potentially patched version (pip install -U git+https://github.com/vasqu/transformers@fix-gemma-embedding-fa), I get:
- BelebeleRetrieval, kor_Hang-kor_Hang: 0.94228
- BelebeleRetrieval, kor_Hang-eng_Latn: 0.95205
- BelebeleRetrieval, eng_Latn-kor_Hang: 0.92919
Using this CLI command:
mteb run -m "google/embeddinggemma-300m" -t BelebeleRetrieval --verbosity 3 --batch_size 8 --languages kor
So those already match Google's results very closely. I think you two just need to upgrade transformers to the latest version (or to the upcoming patch; note that the patch fixes some issues with FA2 and with the sliding window on all attention implementations for inputs of 256+ tokens).
- Tom Aarsen
The transformers version was indeed the issue. Now I'm able to reproduce the results successfully.
Thank you so much @tomaarsen!
It would be helpful if the Google authors could specify the recommended transformers and torch versions in their documentation.
Great to see embeddinggemma performing so well !
Glad it worked! The https://huggingface.co/blog/embeddinggemma blogpost is a bit clearer in this regard: https://huggingface.co/blog/embeddinggemma#sentence-transformers
For future readers: weird performance? Try installing transformers from source (pip install git+https://github.com/huggingface/transformers); you might be on an incompatible version.
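A quick runtime check (a sketch; the 4.56.0 minimum is an assumption based on the v4.56.0-Embedding-Gemma-preview tag mentioned above):

# Older transformers releases silently fall back to causal attention for this model.
from packaging import version
import transformers

assert version.parse(transformers.__version__) >= version.parse("4.56.0"), (
    f"transformers {transformers.__version__} is likely too old for EmbeddingGemma; "
    "try: pip install -U git+https://github.com/huggingface/transformers"
)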
- Tom Aarsen
I would consider the issue fixed, as its performance makes much more sense now (bfloat16 + the transformers branch), but like a LOT of embedding models it still shows poor cross-language performance.