Mismatch between vLLM and non-vLLM versions on image embedding

#3
by onecatperson - opened

I'm trying to audit results from this vLLM repo against the original non-vLLM repo, using exactly the same code the Hugging Face page suggests. The text embeddings (both query and retrieval) match between the two, but the image embeddings do NOT: the vector from transformers is completely different from the one produced by vLLM (a cosine-similarity check is sketched after the vLLM snippet below).

Env: H20
vllm==0.9.2
transformers==4.52.4

Transformers code (almost directly copied from the Hugging Face page):

from transformers import AutoModel
import torch

# Initialize the model
model = AutoModel.from_pretrained("path", trust_remote_code=True, torch_dtype=torch.float16)

model.to("cuda")

# Encode image/document
image_embeddings = model.encode_image(
    images=["sample.jpg"],
    task="retrieval",
)
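
To make the later comparison concrete, here is a quick inspection of what comes back. This is a sketch only: it assumes encode_image returns a sequence of per-image embeddings that torch can convert to a tensor; adjust if the actual return type differs.

# Sketch, assuming image_embeddings[0] is (convertible to) a torch tensor.
hf_image_vec = torch.as_tensor(image_embeddings[0]).float().flatten()
print(hf_image_vec.shape, hf_image_vec.dtype)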

vLLM code (almost directly copied from the Hugging Face page):

import time

import torch
from PIL import Image

from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs.data import TextPrompt

# Jina Embeddings v4: load the model with vLLM.
# pooling_type="ALL" with normalize=False returns the raw per-token embeddings,
# so pooling and normalization are done manually in get_embeddings() below.
ckpt = "path"
model_jina_embeddings_v4 = LLM(
    model=ckpt,
    task="embed",
    override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
    dtype="float16",
    gpu_memory_utilization=0.5,
    trust_remote_code=True,
)


# Create image prompt
image = Image.open("sample.jpg")
image_prompt = TextPrompt(
    prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",
    multi_modal_data={"image": image},
)

# Encode all prompts
prompts = [image_prompt]
outputs = model_jina_embeddings_v4.encode(prompts)

def get_embeddings(outputs):
    VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

    embeddings = []
    for output in outputs:
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            # Gather only vision tokens
            img_start_pos = torch.where(
                torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
            )[0][0]
            img_end_pos = torch.where(
                torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
            )[0][0]
            embeddings_tensor = output.outputs.data.detach().clone()[
                img_start_pos : img_end_pos + 1
            ]
        else:
            # Use all tokens for text-only prompts
            embeddings_tensor = output.outputs.data.detach().clone()
        
        # Pool and normalize embeddings
        pooled_output = (
            embeddings_tensor.sum(dim=0, dtype=torch.float32)
            / embeddings_tensor.shape[0]
        )
        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
    return embeddings

start_time = time.time()
embeddings = get_embeddings(outputs)
end_time = time.time()
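
To quantify the mismatch, a minimal comparison sketch (assuming both snippets above ran in the same session, that image_embeddings[0] from encode_image is convertible to a 1-D torch tensor, and that both vectors have the same dimensionality; adjust to the actual return type):

import torch.nn.functional as F

# Sketch only: cosine similarity between the two image embeddings.
# If both code paths agreed, this should be very close to 1.0.
hf_vec = torch.as_tensor(image_embeddings[0], dtype=torch.float32).flatten().cpu()
vllm_vec = embeddings[0].to(torch.float32).flatten().cpu()

cos = torch.dot(
    F.normalize(hf_vec, dim=-1),
    F.normalize(vllm_vec, dim=-1),
)
print(f"cosine similarity (transformers vs vLLM): {cos.item():.4f}")

A value well below 1.0 reproduces the mismatch described above.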
