#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Qwen3-Embedding-0.6B inference test (test_embedding.py)
Text-embedding inference via the RKLLM API
"""
import faulthandler
faulthandler.enable()
import os
os.environ["RKLLM_LOG_LEVEL"] = "1"
import numpy as np
import time
from typing import List, Dict, Any
from rkllm_binding import *
class Qwen3EmbeddingTester:
def __init__(self, model_path: str, library_path: str = "./librkllmrt.so"):
"""
        Initialize the Qwen3 embedding model tester.
        Args:
            model_path: path to the model file (.rkllm format)
            library_path: path to the RKLLM runtime library
"""
self.model_path = model_path
self.library_path = library_path
self.runtime = None
self.embeddings_buffer = []
self.current_result = None
def callback_function(self, result_ptr, userdata_ptr, state_enum):
"""
        Inference callback invoked by the RKLLM runtime.
        Args:
            result_ptr: pointer to the result structure
            userdata_ptr: pointer to user data
            state_enum: run-state enum value
"""
state = LLMCallState(state_enum)
if state == LLMCallState.RKLLM_RUN_NORMAL:
result = result_ptr.contents
print(f"result: {result}")
            # Use the last hidden-layer output as the embedding
if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:
embd_size = result.last_hidden_layer.embd_size
num_tokens = result.last_hidden_layer.num_tokens
print(f"获取到嵌入向量:维度={embd_size}, 令牌数={num_tokens}")
                # Convert the C array to a numpy array.
                # Here we take the last token's hidden state as the sentence embedding.
if num_tokens > 0:
                    # Extract the last token's embedding (shape: [embd_size])
last_token_embedding = np.array([
result.last_hidden_layer.hidden_states[(num_tokens-1) * embd_size + i]
for i in range(embd_size)
])
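                    # Note (sketch): if the binding exposes hidden_states as a plain
                    # c_float pointer (an assumption, not verified here), the same copy
                    # can be done in one vectorized call:
                    #   flat = np.ctypeslib.as_array(
                    #       result.last_hidden_layer.hidden_states,
                    #       shape=(num_tokens * embd_size,))
                    #   last_token_embedding = flat[-embd_size:].copy()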
self.current_result = {
'embedding': last_token_embedding,
'embd_size': embd_size,
'num_tokens': num_tokens
}
print(f"嵌入向量范数: {np.linalg.norm(last_token_embedding):.4f}")
print(f"嵌入向量前10维: {last_token_embedding[:10]}")
elif state == LLMCallState.RKLLM_RUN_ERROR:
print("推理过程发生错误")
def init_model(self):
"""初始化模型"""
try:
print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
self.runtime = RKLLMRuntime(self.library_path)
print("创建默认参数...")
params = self.runtime.create_default_param()
            # Configure parameters
            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024  # context length
            params.max_new_tokens = 1      # embedding does not need to generate new tokens
            params.temperature = 1.0       # sampling settings are irrelevant for embeddings
            params.top_k = 1               # no sampling needed for embeddings
            params.top_p = 1.0             # no sampling needed for embeddings
            # Extended parameter configuration
            params.extend_param.base_domain_id = 1      # recommended value 1 for models >1B
            params.extend_param.embed_flash = 0         # whether to keep the embedding table in flash
            params.extend_param.enabled_cpus_num = 4    # number of CPU cores to enable
            params.extend_param.enabled_cpus_mask = 0x0F  # CPU core mask (0x0F = cores 0-3)
print(f"初始化模型: {self.model_path}")
self.runtime.init(params, self.callback_function)
            self.runtime.set_chat_template("", "", "")  # clear the chat template so the raw prompt is embedded
print("模型初始化成功!")
except Exception as e:
print(f"模型初始化失败: {e}")
raise
def get_detailed_instruct(self, task_description: str, query: str) -> str:
"""
        Build an instruction-style prompt (following the usage shown in the README).
        Args:
            task_description: task description
            query: query text
        Returns:
            the formatted instruction prompt
"""
return f'Instruct: {task_description}\nQuery: {query}'
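    # Example of the prompt produced by get_detailed_instruct (illustrative values):
    #   get_detailed_instruct("Retrieve relevant passages", "What is the capital of China?")
    #   -> "Instruct: Retrieve relevant passages\nQuery: What is the capital of China?"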
def encode_text(self, text: str, task_description: str = None) -> np.ndarray:
"""
        Encode text into an embedding vector.
        Args:
            text: the text to encode
            task_description: optional task description; if provided, the instruction prompt format is used
        Returns:
            the L2-normalized embedding vector (numpy array)
"""
try:
            # Use the instruction prompt if a task description is provided
if task_description:
input_text = self.get_detailed_instruct(task_description, text)
else:
input_text = text
print(f"编码文本: {input_text[:100]}{'...' if len(input_text) > 100 else ''}")
# 准备输入
rk_input = RKLLMInput()
rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
c_prompt = input_text.encode('utf-8')
rk_input._union_data.prompt_input = c_prompt
            # Prepare inference parameters
infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER  # request hidden-layer output
            infer_params.keep_history = 0  # do not keep history
            # Clear the previous result and the KV cache
self.current_result = None
self.runtime.clear_kv_cache(False)
            # Run inference (the callback fills self.current_result)
start_time = time.time()
self.runtime.run(rk_input, infer_params)
end_time = time.time()
print(f"推理耗时: {end_time - start_time:.3f}秒")
if self.current_result and 'embedding' in self.current_result:
                # L2-normalize the embedding
embedding = self.current_result['embedding']
normalized_embedding = embedding / np.linalg.norm(embedding)
return normalized_embedding
else:
raise RuntimeError("未能获取到有效的嵌入向量")
except Exception as e:
print(f"编码文本时发生错误: {e}")
raise
def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
"""
        Compute the cosine similarity of two embedding vectors.
        Since encode_text returns L2-normalized vectors, this is simply their dot product.
        Args:
            emb1: first embedding vector
            emb2: second embedding vector
        Returns:
            cosine similarity value
"""
return np.dot(emb1, emb2)
def test_embedding_similarity(self):
"""测试嵌入相似度计算"""
print("\n" + "="*50)
print("测试嵌入相似度计算")
print("="*50)
        # Test texts
task_description = "Given a web search query, retrieve relevant passages that answer the query"
queries = [
"What is the capital of China?",
"Explain gravity"
]
documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
        # Encode queries (with the instruction prompt)
        print("\nEncoding query texts:")
query_embeddings = []
for i, query in enumerate(queries):
print(f"\n查询 {i+1}: {query}")
emb = self.encode_text(query, task_description)
query_embeddings.append(emb)
        # Encode documents (without the instruction prompt)
        print("\nEncoding document texts:")
doc_embeddings = []
for i, doc in enumerate(documents):
print(f"\n文档 {i+1}: {doc}")
emb = self.encode_text(doc)
doc_embeddings.append(emb)
        # Compute the similarity matrix
        print("\nComputing the similarity matrix:")
        print("Query vs. document similarity:")
print("-" * 30)
similarities = []
for i, q_emb in enumerate(query_embeddings):
row_similarities = []
for j, d_emb in enumerate(doc_embeddings):
sim = self.compute_similarity(q_emb, d_emb)
row_similarities.append(sim)
print(f"查询{i+1} vs 文档{j+1}: {sim:.4f}")
similarities.append(row_similarities)
print()
return similarities
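    # Sketch: the nested list returned above can be turned into a NumPy matrix to rank
    # documents per query (names are illustrative, not part of the test itself):
    #   sims = np.array(tester.test_embedding_similarity())
    #   best_doc_per_query = sims.argmax(axis=1)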
def test_multilingual_embedding(self):
"""测试多语言嵌入能力"""
print("\n" + "="*50)
print("测试多语言嵌入能力")
print("="*50)
        # Multilingual test texts (the same sentence in different languages)
        texts = {
            "English": "Hello, how are you?",
            "Chinese": "你好,你好吗?",
            "French": "Bonjour, comment allez-vous?",
            "Spanish": "Hola, ¿cómo estás?",
            "Japanese": "こんにちは、元気ですか?"
        }
embeddings = {}
print("\n编码多语言文本:")
for lang, text in texts.items():
print(f"\n{lang}: {text}")
emb = self.encode_text(text)
embeddings[lang] = emb
        # Compute cross-language similarity
        print("\nCross-language similarity:")
print("-" * 30)
languages = list(texts.keys())
for i, lang1 in enumerate(languages):
for j, lang2 in enumerate(languages):
if i <= j:
sim = self.compute_similarity(embeddings[lang1], embeddings[lang2])
print(f"{lang1} vs {lang2}: {sim:.4f}")
def test_code_embedding(self):
"""测试代码嵌入能力"""
print("\n" + "="*50)
print("测试代码嵌入能力")
print("="*50)
        # Code samples
codes = {
"Python函数": """
def fibonacci(n):
if n <= 1:
return n
return fibonacci(n-1) + fibonacci(n-2)
""",
"JavaScript函数": """
function fibonacci(n) {
if (n <= 1) return n;
return fibonacci(n-1) + fibonacci(n-2);
}
""",
"C++函数": """
int fibonacci(int n) {
if (n <= 1) return n;
return fibonacci(n-1) + fibonacci(n-2);
}
""",
"数组排序": """
def bubble_sort(arr):
n = len(arr)
for i in range(n):
for j in range(0, n-i-1):
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
"""
}
embeddings = {}
print("\n编码代码文本:")
for name, code in codes.items():
print(f"\n{name}:")
print(code[:100] + "..." if len(code) > 100 else code)
emb = self.encode_text(code)
embeddings[name] = emb
        # Compute code similarity
        print("\nCode similarity:")
print("-" * 30)
code_names = list(codes.keys())
for i, name1 in enumerate(code_names):
for j, name2 in enumerate(code_names):
if i <= j:
sim = self.compute_similarity(embeddings[name1], embeddings[name2])
print(f"{name1} vs {name2}: {sim:.4f}")
def cleanup(self):
"""清理资源"""
if self.runtime:
try:
self.runtime.destroy()
print("模型资源已清理")
except Exception as e:
print(f"清理资源时发生错误: {e}")
def main():
"""主函数"""
import argparse
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Qwen3-Embedding-0.6B inference test')
    parser.add_argument('model_path', help='path to the model file (.rkllm format)')
    parser.add_argument('--library_path', default="./librkllmrt.so", help='path to the RKLLM runtime library (default: ./librkllmrt.so)')
args = parser.parse_args()
    # Check that the required files exist
    if not os.path.exists(args.model_path):
        print(f"Error: model file not found: {args.model_path}")
        print("Please make sure that:")
        print("1. the Qwen3-Embedding-0.6B model has been downloaded")
        print("2. the model has been converted to .rkllm format with rkllm-convert.py")
        return
    if not os.path.exists(args.library_path):
        print(f"Error: RKLLM library not found: {args.library_path}")
        print("Make sure librkllmrt.so is in the current directory or on LD_LIBRARY_PATH")
return
print("Qwen3-Embedding-0.6B 推理测试")
print("=" * 50)
    # Create the tester
tester = Qwen3EmbeddingTester(args.model_path, args.library_path)
try:
        # Initialize the model
        tester.init_model()
        # Run the tests
        print("\nStarting embedding tests...")
        # Basic embedding similarity test
tester.test_embedding_similarity()
        # Multilingual embedding test
tester.test_multilingual_embedding()
        # Code embedding test
tester.test_code_embedding()
print("\n" + "="*50)
print("所有测试完成!")
print("="*50)
except KeyboardInterrupt:
print("\n测试被用户中断")
except Exception as e:
print(f"\n测试过程中发生错误: {e}")
import traceback
traceback.print_exc()
finally:
        # Release resources
tester.cleanup()
if __name__ == "__main__":
main()