|
|
|
|
|
""" |
|
Qwen3-Embedding-0.6B 推理测试代码 |
|
使用 RKLLM API 进行文本嵌入推理 |
|
""" |
|
import faulthandler
import os
import time
from typing import Any, Dict, List, Optional

# Enabled before the third-party/native imports below so that a hard crash
# while loading or running them still dumps a Python traceback.
faulthandler.enable()

# Set before the RKLLM binding is imported.
os.environ["RKLLM_LOG_LEVEL"] = "1"

import numpy as np

from rkllm_binding import *
|
|
|
|
|
class Qwen3EmbeddingTester:
    """Run a Qwen3 embedding model through the RKLLM runtime and exercise it.

    The tester loads the runtime library, initialises the model, turns text
    into L2-normalised embedding vectors (taken from the last token of the
    last hidden layer) and prints similarity tables for a few demo corpora.
    """

    def __init__(self, model_path: str, library_path: str = "./librkllmrt.so"):
        """Store configuration; nothing is loaded until ``init_model`` is called.

        Args:
            model_path: Path to the model file (``.rkllm`` format).
            library_path: Path to the RKLLM runtime shared library.
        """
        self.model_path = model_path
        self.library_path = library_path
        self.runtime = None
        # Not used by any method in this class; kept for compatibility.
        self.embeddings_buffer = []
        # Filled by ``callback_function`` during a run, read by ``encode_text``.
        self.current_result = None

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """Inference callback: capture the last token's hidden state.

        On a normal-state callback that carries hidden states, the vector of
        the final token is copied into ``self.current_result`` as a dict with
        keys ``'embedding'``, ``'embd_size'`` and ``'num_tokens'``.

        Args:
            result_ptr: Pointer to the runtime's result structure.
            userdata_ptr: User data pointer (unused).
            state_enum: Raw call-state value, converted to ``LLMCallState``.

        Returns:
            Always 0.
        """
        # NOTE(review): the RKLLM 1.2 C API declares the callback as returning
        # int (0 = continue inference) -- confirm against rkllm_binding. If the
        # binding declares a void callback the value is simply ignored.
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_ERROR:
            print("推理过程发生错误")
            return 0
        if state != LLMCallState.RKLLM_RUN_NORMAL:
            return 0

        result = result_ptr.contents
        print(f"result: {result}")

        layer = result.last_hidden_layer
        if not layer.hidden_states or not layer.embd_size > 0:
            return 0

        embd_size = layer.embd_size
        num_tokens = layer.num_tokens
        print(f"获取到嵌入向量:维度={embd_size}, 令牌数={num_tokens}")

        if num_tokens <= 0:
            return 0

        # hidden_states is indexed as a flat buffer of num_tokens * embd_size
        # values; the last token's vector starts at this offset. Copying into
        # a numpy array detaches the data from the runtime-owned buffer.
        hidden_states = layer.hidden_states
        offset = (num_tokens - 1) * embd_size
        last_token_embedding = np.array(
            [hidden_states[offset + i] for i in range(embd_size)]
        )

        self.current_result = {
            'embedding': last_token_embedding,
            'embd_size': embd_size,
            'num_tokens': num_tokens,
        }

        print(f"嵌入向量范数: {np.linalg.norm(last_token_embedding):.4f}")
        print(f"嵌入向量前10维: {last_token_embedding[:10]}")
        return 0

    def init_model(self):
        """Create the RKLLM runtime and initialise the model.

        Raises:
            Exception: Whatever the runtime raises; it is printed and
                re-raised unchanged.
        """
        try:
            print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("创建默认参数...")
            params = self.runtime.create_default_param()

            # Generated tokens are not used by this tester -- only the hidden
            # state captured in the callback -- so sampling is kept minimal.
            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024
            params.max_new_tokens = 1
            params.temperature = 1.0
            params.top_k = 1
            params.top_p = 1.0

            # 0x0F has the four low bits set, matching enabled_cpus_num = 4.
            params.extend_param.base_domain_id = 1
            params.extend_param.embed_flash = 0
            params.extend_param.enabled_cpus_num = 4
            params.extend_param.enabled_cpus_mask = 0x0F

            print(f"初始化模型: {self.model_path}")
            self.runtime.init(params, self.callback_function)
            # NOTE(review): presumably empty template parts make the runtime
            # pass the prompt through unwrapped -- confirm against the binding.
            self.runtime.set_chat_template("", "", "")
            print("模型初始化成功!")

        except Exception as e:
            print(f"模型初始化失败: {e}")
            raise

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Build the instruction-style prompt for a query.

        Args:
            task_description: Description of the retrieval task.
            query: The query text.

        Returns:
            ``'Instruct: <task_description>\\nQuery: <query>'``.
        """
        return f'Instruct: {task_description}\nQuery: {query}'

    @staticmethod
    def _l2_normalize(embedding) -> np.ndarray:
        """Return ``embedding`` scaled to unit L2 norm.

        Raises:
            RuntimeError: If the norm is zero or not finite, in which case
                dividing would silently produce NaN/inf components.
        """
        vector = np.asarray(embedding)
        norm = np.linalg.norm(vector)
        if not np.isfinite(norm) or norm == 0.0:
            raise RuntimeError("嵌入向量范数为零或非有限值,无法归一化")
        return vector / norm

    def encode_text(self, text: str, task_description: Optional[str] = None) -> np.ndarray:
        """Encode text into an L2-normalised embedding vector.

        Args:
            text: The text to encode.
            task_description: Optional task description; when given (and
                non-empty) the text is wrapped with ``get_detailed_instruct``.

        Returns:
            The normalised embedding as a numpy array.

        Raises:
            RuntimeError: If the callback produced no embedding, or the
                embedding cannot be normalised.
            Exception: Any runtime error is printed and re-raised unchanged.
        """
        try:
            if task_description:
                input_text = self.get_detailed_instruct(task_description, text)
            else:
                input_text = text

            print(f"编码文本: {input_text[:100]}{'...' if len(input_text) > 100 else ''}")

            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            # Keep the encoded bytes in a local so they stay alive for the run.
            c_prompt = input_text.encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
            infer_params.keep_history = 0

            # Drop any result from a previous call so a run that yields no
            # embedding is detected instead of returning stale data.
            self.current_result = None
            self.runtime.clear_kv_cache(False)

            # perf_counter is monotonic, so the duration cannot be skewed by
            # wall-clock adjustments.
            start_time = time.perf_counter()
            self.runtime.run(rk_input, infer_params)
            end_time = time.perf_counter()

            print(f"推理耗时: {end_time - start_time:.3f}秒")

            captured = self.current_result
            if not captured or 'embedding' not in captured:
                raise RuntimeError("未能获取到有效的嵌入向量")

            return self._l2_normalize(captured['embedding'])

        except Exception as e:
            print(f"编码文本时发生错误: {e}")
            raise

    def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """Compute the cosine similarity of two embedding vectors.

        The inputs do not have to be normalised. For unit vectors (such as
        those returned by ``encode_text``) this equals their dot product up
        to floating-point rounding.

        Args:
            emb1: First embedding vector.
            emb2: Second embedding vector.

        Returns:
            The cosine similarity as a Python float; 0.0 if either vector has
            zero norm.
        """
        vec_a = np.asarray(emb1, dtype=np.float64)
        vec_b = np.asarray(emb2, dtype=np.float64)
        denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
        if denominator == 0.0:
            return 0.0
        return float(np.dot(vec_a, vec_b) / denominator)

    def _print_pairwise_similarities(self, embeddings: Dict[str, np.ndarray]) -> None:
        """Print the similarity of every unordered pair (including self-pairs)."""
        names = list(embeddings)
        for i, name1 in enumerate(names):
            for name2 in names[i:]:
                sim = self.compute_similarity(embeddings[name1], embeddings[name2])
                print(f"{name1} vs {name2}: {sim:.4f}")

    def test_embedding_similarity(self) -> List[List[float]]:
        """Encode demo queries and documents and print their similarity matrix.

        Queries are encoded with a task instruction; documents are encoded
        as plain text.

        Returns:
            The similarity matrix, one row per query and one column per
            document.
        """
        print("\n" + "="*50)
        print("测试嵌入相似度计算")
        print("="*50)

        task_description = "Given a web search query, retrieve relevant passages that answer the query"

        queries = [
            "What is the capital of China?",
            "Explain gravity",
        ]

        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
        ]

        print("\n编码查询文本:")
        query_embeddings = []
        for i, query in enumerate(queries):
            print(f"\n查询 {i+1}: {query}")
            query_embeddings.append(self.encode_text(query, task_description))

        print("\n编码文档文本:")
        doc_embeddings = []
        for i, doc in enumerate(documents):
            print(f"\n文档 {i+1}: {doc}")
            doc_embeddings.append(self.encode_text(doc))

        print("\n计算相似度矩阵:")
        print("查询 vs 文档相似度:")
        print("-" * 30)

        similarities = []
        for i, q_emb in enumerate(query_embeddings):
            row_similarities = []
            for j, d_emb in enumerate(doc_embeddings):
                sim = self.compute_similarity(q_emb, d_emb)
                row_similarities.append(sim)
                print(f"查询{i+1} vs 文档{j+1}: {sim:.4f}")
            similarities.append(row_similarities)
            print()

        return similarities

    def test_multilingual_embedding(self):
        """Encode a greeting in several languages and print pairwise similarities."""
        print("\n" + "="*50)
        print("测试多语言嵌入能力")
        print("="*50)

        texts = {
            "英语": "Hello, how are you?",
            "中文": "你好,你好吗?",
            "法语": "Bonjour, comment allez-vous?",
            "西班牙语": "Hola, ¿cómo estás?",
            "日语": "こんにちは、元気ですか?",
        }

        embeddings = {}
        print("\n编码多语言文本:")
        for lang, text in texts.items():
            print(f"\n{lang}: {text}")
            embeddings[lang] = self.encode_text(text)

        print("\n跨语言相似度:")
        print("-" * 30)
        self._print_pairwise_similarities(embeddings)

    def test_code_embedding(self):
        """Encode a few code snippets and print pairwise similarities."""
        print("\n" + "="*50)
        print("测试代码嵌入能力")
        print("="*50)

        # The snippets are written flush-left so that the indentation inside
        # each string is exactly the snippet's own indentation.
        codes = {
            "Python函数": """
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)
""",
            "JavaScript函数": """
function fibonacci(n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "C++函数": """
int fibonacci(int n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "数组排序": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
""",
        }

        embeddings = {}
        print("\n编码代码文本:")
        for name, code in codes.items():
            print(f"\n{name}:")
            # Show at most the first 100 characters of each snippet.
            print((code[:100] + "...") if len(code) > 100 else code)
            embeddings[name] = self.encode_text(code)

        print("\n代码相似度:")
        print("-" * 30)
        self._print_pairwise_similarities(embeddings)

    def cleanup(self):
        """Release the runtime if one was created; safe to call repeatedly.

        Errors raised while destroying the runtime are printed, not raised.
        """
        if not self.runtime:
            return
        try:
            self.runtime.destroy()
            print("模型资源已清理")
        except Exception as e:
            print(f"清理资源时发生错误: {e}")
        finally:
            # Drop the handle so a second cleanup() cannot destroy it twice.
            self.runtime = None
|
|
|
def main(argv=None):
    """Command-line entry point: validate paths, run all embedding tests.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, argparse reads ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success, 1 if a required file is missing or
        a test raised, 130 if interrupted with Ctrl-C.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description='Qwen3-Embedding-0.6B 推理测试')
    parser.add_argument('model_path', help='模型文件路径(.rkllm格式)')
    parser.add_argument('--library_path', default="./librkllmrt.so", help='RKLLM库文件路径(默认为./librkllmrt.so)')
    args = parser.parse_args(argv)

    # Fail early, with guidance, before touching the native runtime.
    if not os.path.exists(args.model_path):
        print(f"错误: 模型文件不存在: {args.model_path}")
        print("请确保:")
        print("1. 已下载 Qwen3-Embedding-0.6B 模型")
        print("2. 已使用 rkllm-convert.py 将模型转换为 .rkllm 格式")
        return 1

    if not os.path.exists(args.library_path):
        print(f"错误: RKLLM 库文件不存在: {args.library_path}")
        print("请确保 librkllmrt.so 在当前目录或 LD_LIBRARY_PATH 中")
        return 1

    print("Qwen3-Embedding-0.6B 推理测试")
    print("=" * 50)

    tester = Qwen3EmbeddingTester(args.model_path, args.library_path)
    exit_code = 0

    try:
        tester.init_model()

        print("\n开始运行嵌入测试...")

        tester.test_embedding_similarity()
        tester.test_multilingual_embedding()
        tester.test_code_embedding()

        print("\n" + "="*50)
        print("所有测试完成!")
        print("="*50)

    except KeyboardInterrupt:
        print("\n测试被用户中断")
        exit_code = 130
    except Exception as e:
        print(f"\n测试过程中发生错误: {e}")
        traceback.print_exc()
        exit_code = 1
    finally:
        # Always release the runtime, whether the tests passed or failed.
        tester.cleanup()

    return exit_code


if __name__ == "__main__":
    # Propagate main()'s result as the process exit status.
    raise SystemExit(main())
|
|