
SaHa-Qwen2-VL-2B-Instruct

Model Summary

SaHa-Qwen2-VL-2B-Instruct is a universal multimodal embedding model built on the Qwen2-VL-2B-Instruct architecture. It was fine-tuned with Self-aware Hard Negative Sampling (SaHa), a strategy designed to efficiently adapt generative Multimodal Large Language Models (MLLMs) to discriminative embedding tasks.

Our approach first uses a hierarchical embedding prompt to elicit strong zero-shot embeddings from the MLLM, then fine-tunes the model with SaHa to improve performance on universal multimodal retrieval benchmarks. This significantly reduces the computational cost associated with conventional contrastive pre-training while delivering state-of-the-art results.

For more details, see our paper, From Generator to Embedder: Harnessing Innate Abilities of Multimodal LLMs via Building Zero-Shot Discriminative Embedding Model (arXiv:2508.00955), and our GitHub repository.

How to Use

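You can use this model with the transformers library to compute text and image embeddings for similarity and retrieval. Install transformers 4.46.1 or later, torch, and Pillow. The loading code below also requests FlashAttention-2, which requires the separately installed flash-attn package and a supported GPU.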

pip install "transformers>=4.46.1" torch pillow
pip install flash-attn --no-build-isolation
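The loading code sets `flash_attention_2`, which fails if the flash-attn package is missing. A small helper like the one below (a sketch; the function name is hypothetical) falls back to PyTorch's `sdpa` attention in that case. It only checks whether the package is installed, not whether the GPU supports FlashAttention-2, and similarity scores under `sdpa` may differ slightly from the ones printed in this card.

```python
import importlib.util


def pick_attn_implementation() -> str:
    """Hypothetical helper: prefer FlashAttention-2 when flash-attn is installed, else SDPA."""
    # find_spec only locates the package; it does not import it or test the GPU.
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


attn_impl = pick_attn_implementation()
print(attn_impl)
# In the loading code, use it in place of the hard-coded value:
# config._attn_implementation = attn_impl
# config.vision_config._attn_implementation = attn_impl
```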

Get Embeddings from Text or Image

Here's how to get embeddings for text or image inputs. Each input is wrapped in a chat-style prompt (shown below), and the final-layer hidden state of its last token serves as the embedding.
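The examples read the embedding at position -1, which is the last real token only when every row is left-padded. A padding-agnostic alternative, sketched below as a hypothetical helper, finds each row's last non-padding token from the attention mask that the processor returns.

```python
import torch


def last_token_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the hidden state of the last non-padding token in each row.

    hidden: (batch, seq_len, dim); attention_mask: (batch, seq_len), 1 = real token.
    Works for both left- and right-padded batches.
    """
    positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
    # Zero out padded positions, then the largest remaining index is the last real token.
    last_idx = (attention_mask.long() * positions).max(dim=1).values
    batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
    return hidden[batch_idx, last_idx]
```

With the model outputs from the embedding step, this would be called as `last_token_pool(outputs.hidden_states[-1], model_input["attention_mask"])`.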

Load Model

import torch
from transformers import AutoProcessor, AutoConfig, Qwen2VLForConditionalGeneration

# Load the model and processor
model_id = "Y-J-Ju/SaHa-Qwen2-VL-2B-Instruct"

# Use FlashAttention-2 in both the language model and the vision encoder
# (requires the flash-attn package; "sdpa" is an alternative without it)
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
config._attn_implementation = "flash_attention_2"
config.vision_config._attn_implementation = "flash_attention_2"

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, config=config, device_map="cuda:0"
)
# min_pixels / max_pixels bound the resized image area in units of 28x28 pixels,
# i.e. roughly 256 to 1280 visual tokens per image
processor = AutoProcessor.from_pretrained(
    model_id, trust_remote_code=True,
    min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28,
)
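In Qwen2-VL, each 28x28 block of the resized image becomes one visual token, so `min_pixels` and `max_pixels` control how many tokens an image costs. The sketch below approximates the processor's resize rule (`smart_resize`) as we understand it; the function name is hypothetical, and for exact numbers you should inspect `inputs["image_grid_thw"]` from the processor in your installed version.

```python
import math


def estimate_visual_tokens(height: int, width: int,
                           min_pixels: int = 256 * 28 * 28,
                           max_pixels: int = 1280 * 28 * 28,
                           factor: int = 28) -> int:
    """Approximate Qwen2-VL's image resize and return the number of visual tokens.

    Sides are snapped to multiples of `factor` (14 px patches merged 2x2 = 28 px),
    then rescaled, keeping the aspect ratio, so the area fits in [min_pixels, max_pixels].
    """
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same factor, rounding down.
        beta = math.sqrt(height * width / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too small: enlarge both sides by the same factor, rounding up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return (h_bar // factor) * (w_bar // factor)


print(estimate_visual_tokens(480, 640))    # 17 x 23 grid -> 391 tokens
print(estimate_visual_tokens(3000, 4000))  # 30 x 41 grid -> 1230 tokens
```

Larger `max_pixels` keeps more image detail at the cost of longer sequences and more GPU memory.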

Data Preparation and Prompting

# Text queries and the images they should retrieve (text i matches image i)
texts = [
    "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
    "Korea University",
]
images = [
    'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/thumb/7/74/Korea_University.jpg/960px-Korea_University.jpg',
]
# Task instruction placed before each query
task_instruction = 'Find an image that matches the given text.'

# System prompt shared by queries and candidates, plus a per-query prompt
# asking for a one-word representation
system_prompt = "Given an image, summarize the provided image in one word. Given only text, describe the text in one word."
represent_prompt = "Represent the given text in one word."

# Qwen2-VL chat templates; each prompt ends at the start of the assistant turn
query_form = '<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{task_instruction}\n{query}\n{represent_prompt}<|im_end|>\n<|im_start|>assistant\n'
candidate_form = '<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{cand}<|im_end|>\n<|im_start|>assistant\n'

# One prompt per text query
queries = [
    query_form.format(system_prompt=system_prompt, task_instruction=task_instruction, query=text, represent_prompt=represent_prompt)
    for text in texts
]
# One prompt per image; the processor expands each <|image_pad|> into that image's visual tokens
candidates = [
    candidate_form.format(system_prompt=system_prompt, cand='<|image_pad|>')
    for _ in images
]

Get Embeddings

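from PIL import Image
import io
from urllib import request
import torch
import torch.nn.functional as F

# Embeddings are read from the last position, so pad on the left to keep
# the final real token at index -1 in every row of the batch
processor.tokenizer.padding_side = "left"

def load_image(url):
    # Wikimedia may reject requests that lack a descriptive User-Agent
    req = request.Request(url, headers={"User-Agent": "SaHa-embedding-example/0.1"})
    with request.urlopen(req) as resp:
        return Image.open(io.BytesIO(resp.read())).convert("RGB")

with torch.no_grad():  # inference only: no gradients needed
    # Query (text)
    inputs = processor(text=queries, images=None, return_tensors="pt", padding=True)
    model_input = {k: v if isinstance(v, list) else v.to(model.device) for k, v in inputs.items()}
    outputs = model(**model_input, return_dict=True, output_hidden_states=True)
    # Final-layer hidden state of the last token is the embedding
    query_embed = outputs.hidden_states[-1][:, -1]

    # Candidate (image)
    pil_images = [load_image(url) for url in images]
    inputs = processor(text=candidates, images=pil_images, return_tensors="pt", padding=True)
    model_input = {k: v if isinstance(v, list) else v.to(model.device) for k, v in inputs.items()}
    outputs = model(**model_input, return_dict=True, output_hidden_states=True)
    cand_embed = outputs.hidden_states[-1][:, -1]

# L2-normalize so the dot product equals cosine similarity
query_embed = F.normalize(query_embed, p=2, dim=-1)
cand_embed = F.normalize(cand_embed, p=2, dim=-1)
print(query_embed @ cand_embed.T)  # rows: text queries, columns: images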

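Outputs (Similarity)

Each row is a text query and each column an image candidate; each query scores highest against its matching image (the diagonal).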

tensor([[0.3984, 0.0254],
        [0.0092, 0.2930]], device='cuda:0', dtype=torch.bfloat16)
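For retrieval, the similarity matrix is usually turned into a ranking by keeping the top-k candidates per query. The sketch below hard-codes the scores printed above for illustration; in practice you would pass `query_embed @ cand_embed.T` directly, and `k` is an arbitrary choice.

```python
import torch

# Similarity scores copied from the printed output (rows: text queries, columns: images).
scores = torch.tensor([[0.3984, 0.0254],
                       [0.0092, 0.2930]])

k = 1  # number of candidates to keep per query
top = scores.float().topk(k, dim=-1)  # highest scores along the candidate axis
for q, (idx, val) in enumerate(zip(top.indices.tolist(), top.values.tolist())):
    print(f"query {q}: best candidate(s) {idx}, score(s) {[round(v, 4) for v in val]}")
```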

Training and Evaluation

Training Data

The model was fine-tuned on the Massive Multimodal Embedding Benchmark (MMEB) training set, which consists of approximately 829,000 pairs from 20 in-domain datasets.
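The SaHa sampling procedure itself is described in the paper. For orientation only, the sketch below shows a generic InfoNCE contrastive loss over query-target pairs with in-batch negatives plus one explicit hard negative per query, the kind of objective that hard-negative mining commonly feeds into. It is not the authors' training code; the function name and the temperature of 0.05 are assumptions.

```python
import torch
import torch.nn.functional as F


def info_nce_with_hard_negatives(q: torch.Tensor, pos: torch.Tensor, hard_neg: torch.Tensor,
                                 temperature: float = 0.05) -> torch.Tensor:
    """Generic InfoNCE loss (hypothetical helper, not the SaHa implementation).

    q, pos, hard_neg: (batch, dim). Row i of `pos` is the positive for query i;
    the other positives in the batch and all hard negatives act as negatives.
    """
    q = F.normalize(q, dim=-1)
    candidates = F.normalize(torch.cat([pos, hard_neg], dim=0), dim=-1)  # (2*batch, dim)
    logits = q @ candidates.T / temperature                               # (batch, 2*batch)
    targets = torch.arange(q.shape[0], device=q.device)                   # positive i is column i
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
q = torch.randn(4, 8)
loss_random = info_nce_with_hard_negatives(q, torch.randn(4, 8), torch.randn(4, 8))
loss_aligned = info_nce_with_hard_negatives(q, q.clone(), torch.randn(4, 8))
print(loss_random.item(), loss_aligned.item())  # expect the aligned case to be much lower
```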

Evaluation Data

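The model was evaluated on the MMEB evaluation set: 36 datasets spanning four meta-tasks (Classification, Visual Question Answering (VQA), Retrieval, and Visual Grounding). Twenty of these are in-distribution (IND) datasets whose training splits are used for fine-tuning, and 16 are out-of-distribution (OND) datasets not seen during training.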

Performance

The SaHa-Qwen2-VL-2B-Instruct model achieves state-of-the-art performance in its parameter class on the MMEB benchmark, outperforming methods that rely on large-scale contrastive pre-training.

| Model | Params | Classification | Retrieval | VQA | Grounding | IND | OND | Overall Avg. |
|---|---|---|---|---|---|---|---|---|
| Ours (SaHa-Qwen2-VL-2B) | 2.2B | 65.4 | 70.0 | 59.1 | 83.0 | 71.2 | 62.1 | 67.1 |
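The Overall Avg. is not the plain mean of the four meta-task columns (that would be 69.4); it appears to be weighted by the number of datasets in each meta-task. The check below assumes the per-meta-task dataset counts reported for MMEB (10 classification, 12 retrieval, 10 VQA, 4 grounding, totaling 36); under that assumption the weighted mean reproduces the 67.1 in the table.

```python
# Meta-task scores from the table.
scores = {"Classification": 65.4, "Retrieval": 70.0, "VQA": 59.1, "Grounding": 83.0}
# Assumed number of MMEB evaluation datasets per meta-task (sums to 36).
num_datasets = {"Classification": 10, "Retrieval": 12, "VQA": 10, "Grounding": 4}

plain_mean = sum(scores.values()) / len(scores)
weighted_mean = sum(scores[t] * num_datasets[t] for t in scores) / sum(num_datasets.values())

print(f"plain mean:    {plain_mean:.1f}")     # 69.4
print(f"weighted mean: {weighted_mean:.1f}")  # 67.1
```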

Citation

If you find this model useful in your research, please cite our paper:

@misc{ju2025generatorembedder,
      title={From Generator to Embedder: Harnessing Innate Abilities of Multimodal LLMs via Building Zero-Shot Discriminative Embedding Model}, 
      author={Yeong-Joon Ju and Seong-Whan Lee},
      year={2025},
      eprint={2508.00955},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2508.00955}, 
}