# EmbeddingGemma
This model ([mlx-community/embeddinggemma-300m-qat-q4_0-unquantized-bf16](https://huggingface.co/mlx-community/embeddinggemma-300m-qat-q4_0-unquantized-bf16)) was converted to MLX format from [google/embeddinggemma-300m-qat-q4_0-unquantized](https://huggingface.co/google/embeddinggemma-300m-qat-q4_0-unquantized) using mlx-lm version 0.0.4.
```bash
pip install mlx-embeddings
```
```python
from mlx_embeddings import load
import mlx.core as mx

# Load the model and tokenizer
model, tokenizer = load("mlx-community/embeddinggemma-300m-qat-q4_0-unquantized-bf16")

# For text embedding: prepend the task-specific prefix to each input.
# Note that the last two sentences are identical, so their similarity will be ~1.0.
sentences = [
    "task: sentence similarity | query: Nothing really matters.",
    "task: sentence similarity | query: The dog is barking.",
    "task: sentence similarity | query: The dog is barking.",
]

# Tokenize the batch with padding so all sequences share one length
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="mlx")

# Compute sentence embeddings
input_ids = encoded_input["input_ids"]
attention_mask = encoded_input["attention_mask"]
output = model(input_ids, attention_mask)
embeddings = output.text_embeds  # Normalized embeddings

# Dot product between normalized embeddings equals cosine similarity
similarity_matrix = mx.matmul(embeddings, embeddings.T)
print("Similarity matrix between texts:")
print(similarity_matrix)
```
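To read the matrix, you can rank each row; a minimal sketch building on the variables above, where `mx.argsort` over the negated scores gives a descending ordering:

```python
# Rank all sentences by similarity to the first one (highest score first)
scores = similarity_matrix[0]
ranking = mx.argsort(-scores).tolist()
for idx in ranking:
    print(f"{scores[idx].item():.3f}  {sentences[idx]}")
```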
```python
# You can use these task-specific prefixes for different tasks
task_prefixes = {
    "BitextMining": "task: search result | query: ",
    "Clustering": "task: clustering | query: ",
    "Classification": "task: classification | query: ",
    "MultilabelClassification": "task: classification | query: ",
    "PairClassification": "task: sentence similarity | query: ",
    "InstructionRetrieval": "task: code retrieval | query: ",
    "Reranking": "task: search result | query: ",
    "Retrieval": "task: search result | query: ",
    "Retrieval-query": "task: search result | query: ",
    "Retrieval-document": "title: none | text: ",
    "STS": "task: sentence similarity | query: ",
    "Summarization": "task: summarization | query: ",
    "document": "title: none | text: ",
}
```
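For asymmetric retrieval, queries and documents take different prefixes (`Retrieval-query` vs. `Retrieval-document`). A minimal sketch reusing the `model` and `tokenizer` loaded above; the query and document strings are illustrative, not part of the model card:

```python
# Embed a query and a few documents with their respective prefixes,
# then rank the documents by dot-product similarity.
query = task_prefixes["Retrieval-query"] + "Which planet is known as the Red Planet?"
documents = [
    task_prefixes["Retrieval-document"] + "Mars, often called the Red Planet, is the fourth planet from the Sun.",
    task_prefixes["Retrieval-document"] + "Venus is sometimes called Earth's twin because of its similar size.",
]

batch = tokenizer([query] + documents, padding=True, truncation=True, return_tensors="mlx")
out = model(batch["input_ids"], batch["attention_mask"])
query_embed, doc_embeds = out.text_embeds[0], out.text_embeds[1:]

# Higher score = more relevant document (embeddings are already normalized)
scores = mx.matmul(doc_embeds, query_embed)
best = mx.argmax(scores).item()
print(f"Best match (score {scores[best].item():.3f}): {documents[best]}")
```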