#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Qwen3-Embedding-0.6B inference test
Text-embedding inference via the RKLLM API
"""

import faulthandler
faulthandler.enable()

import os
os.environ["RKLLM_LOG_LEVEL"] = "1"

import numpy as np
import time
from typing import Optional

from rkllm_binding import *


class Qwen3EmbeddingTester:
    def __init__(self, model_path: str, library_path: str = "./librkllmrt.so"):
        """
        Initialize the Qwen3 embedding model tester.

        Args:
            model_path: path to the model file (.rkllm format)
            library_path: path to the RKLLM runtime library
        """
        self.model_path = model_path
        self.library_path = library_path
        self.runtime = None
        self.embeddings_buffer = []
        self.current_result = None

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """
        Inference callback.

        Args:
            result_ptr: pointer to the result structure
            userdata_ptr: pointer to user data
            state_enum: run-state enum value
        """
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_NORMAL:
            result = result_ptr.contents
            print(f"result: {result}")
            # Use the last hidden layer output as the embedding
            if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:
                embd_size = result.last_hidden_layer.embd_size
                num_tokens = result.last_hidden_layer.num_tokens
                print(f"Got hidden states: dim={embd_size}, tokens={num_tokens}")

                # Convert the C array to a numpy array.
                # The last token's hidden state is used as the sentence embedding
                # (see the mean-pooling sketch further below for an alternative).
                if num_tokens > 0:
                    # Last token's embedding (shape: [embd_size])
                    last_token_embedding = np.array([
                        result.last_hidden_layer.hidden_states[(num_tokens - 1) * embd_size + i]
                        for i in range(embd_size)
                    ])

                    self.current_result = {
                        'embedding': last_token_embedding,
                        'embd_size': embd_size,
                        'num_tokens': num_tokens
                    }

                    print(f"Embedding norm: {np.linalg.norm(last_token_embedding):.4f}")
                    print(f"First 10 dims: {last_token_embedding[:10]}")

        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("An error occurred during inference")

    def init_model(self):
        """Initialize the model."""
        try:
            print(f"Initializing RKLLM runtime, library path: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("Creating default parameters...")
            params = self.runtime.create_default_param()

            # Basic parameters
            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024  # context length
            params.max_new_tokens = 1      # embedding task generates no new tokens
            params.temperature = 1.0       # sampling settings are irrelevant for embeddings
            params.top_k = 1
            params.top_p = 1.0

            # Extended parameters
            params.extend_param.base_domain_id = 1        # recommended to set to 1 for models > 1B
            params.extend_param.embed_flash = 0           # whether to keep embeddings in flash storage
            params.extend_param.enabled_cpus_num = 4      # number of enabled CPU cores
            params.extend_param.enabled_cpus_mask = 0x0F  # CPU core mask

            print(f"Initializing model: {self.model_path}")
            self.runtime.init(params, self.callback_function)
            self.runtime.set_chat_template("", "", "")
            print("Model initialized successfully!")

        except Exception as e:
            print(f"Model initialization failed: {e}")
            raise

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """
        Build an instruction prompt (following the usage in the model README).

        Args:
            task_description: task description
            query: query text

        Returns:
            The formatted instruction prompt.
        """
        return f'Instruct: {task_description}\nQuery: {query}'
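
    # --- Added sketch, not used by the tests below ---
    # Mean pooling over all token hidden states, as an alternative to the
    # last-token pooling done in callback_function. It assumes `hidden_states`
    # is the flat float array from result.last_hidden_layer with
    # num_tokens * embd_size elements; whether mean pooling suits
    # Qwen3-Embedding is an assumption here, since the official usage pools
    # the last token.
    @staticmethod
    def mean_pool_hidden_states(hidden_states, num_tokens: int, embd_size: int) -> np.ndarray:
        """Average per-token hidden states into a single vector (illustrative sketch)."""
        flat = np.array(
            [hidden_states[i] for i in range(num_tokens * embd_size)],
            dtype=np.float32,
        )
        return flat.reshape(num_tokens, embd_size).mean(axis=0)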

    def encode_text(self, text: str, task_description: Optional[str] = None) -> np.ndarray:
        """
        Encode a text into an embedding vector.

        Args:
            text: the text to encode
            task_description: if given, the text is wrapped in an instruction prompt

        Returns:
            The embedding vector (numpy array).
        """
        try:
            # Wrap the query in an instruction prompt if a task description is given
            if task_description:
                input_text = self.get_detailed_instruct(task_description, text)
            else:
                input_text = text

            print(f"Encoding text: {input_text[:100]}{'...' if len(input_text) > 100 else ''}")

            # Prepare the input
            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            c_prompt = input_text.encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            # Prepare the inference parameters
            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER  # request hidden states
            infer_params.keep_history = 0  # do not keep history

            # Clear the previous result and the KV cache
            self.current_result = None
            self.runtime.clear_kv_cache(False)

            # Run inference
            start_time = time.time()
            self.runtime.run(rk_input, infer_params)
            end_time = time.time()

            print(f"Inference time: {end_time - start_time:.3f}s")

            if self.current_result and 'embedding' in self.current_result:
                # L2-normalize the embedding
                embedding = self.current_result['embedding']
                normalized_embedding = embedding / np.linalg.norm(embedding)
                return normalized_embedding
            else:
                raise RuntimeError("Failed to obtain a valid embedding")

        except Exception as e:
            print(f"Error while encoding text: {e}")
            raise

    def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """
        Compute the cosine similarity of two embeddings.
        Both inputs are assumed to be L2-normalized, so the dot product
        equals the cosine similarity.

        Args:
            emb1: first embedding
            emb2: second embedding

        Returns:
            Cosine similarity.
        """
        return np.dot(emb1, emb2)

    def test_embedding_similarity(self):
        """Test query/document embedding similarity."""
        print("\n" + "=" * 50)
        print("Testing embedding similarity")
        print("=" * 50)

        # Test texts
        task_description = "Given a web search query, retrieve relevant passages that answer the query"

        queries = [
            "What is the capital of China?",
            "Explain gravity"
        ]

        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
        ]

        # Encode queries (with instruction)
        print("\nEncoding queries:")
        query_embeddings = []
        for i, query in enumerate(queries):
            print(f"\nQuery {i+1}: {query}")
            emb = self.encode_text(query, task_description)
            query_embeddings.append(emb)

        # Encode documents (without instruction)
        print("\nEncoding documents:")
        doc_embeddings = []
        for i, doc in enumerate(documents):
            print(f"\nDocument {i+1}: {doc}")
            emb = self.encode_text(doc)
            doc_embeddings.append(emb)

        # Similarity matrix
        print("\nSimilarity matrix:")
        print("Query vs document similarity:")
        print("-" * 30)

        similarities = []
        for i, q_emb in enumerate(query_embeddings):
            row_similarities = []
            for j, d_emb in enumerate(doc_embeddings):
                sim = self.compute_similarity(q_emb, d_emb)
                row_similarities.append(sim)
                print(f"Query {i+1} vs document {j+1}: {sim:.4f}")
            similarities.append(row_similarities)
            print()

        return similarities

    def test_multilingual_embedding(self):
        """Test multilingual embedding."""
        print("\n" + "=" * 50)
        print("Testing multilingual embedding")
        print("=" * 50)

        # The same sentence in different languages
        texts = {
            "English": "Hello, how are you?",
            "Chinese": "你好,你好吗?",
            "French": "Bonjour, comment allez-vous?",
            "Spanish": "Hola, ¿cómo estás?",
            "Japanese": "こんにちは、元気ですか?"
        }

        embeddings = {}
        print("\nEncoding multilingual texts:")
        for lang, text in texts.items():
            print(f"\n{lang}: {text}")
            emb = self.encode_text(text)
            embeddings[lang] = emb

        # Cross-lingual similarity
        print("\nCross-lingual similarity:")
        print("-" * 30)
        languages = list(texts.keys())
        for i, lang1 in enumerate(languages):
            for j, lang2 in enumerate(languages):
                if i <= j:
                    sim = self.compute_similarity(embeddings[lang1], embeddings[lang2])
                    print(f"{lang1} vs {lang2}: {sim:.4f}")

    def test_code_embedding(self):
        """Test code embedding."""
        print("\n" + "=" * 50)
        print("Testing code embedding")
        print("=" * 50)

        # Code samples
        codes = {
            "Python function": """
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)
""",
            "JavaScript function": """
function fibonacci(n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "C++ function": """
int fibonacci(int n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "Array sort": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
"""
        }

        embeddings = {}
        print("\nEncoding code snippets:")
        for name, code in codes.items():
            print(f"\n{name}:")
            print(code[:100] + "..." if len(code) > 100 else code)
            emb = self.encode_text(code)
            embeddings[name] = emb

        # Code similarity
        print("\nCode similarity:")
        print("-" * 30)
        code_names = list(codes.keys())
        for i, name1 in enumerate(code_names):
            for j, name2 in enumerate(code_names):
                if i <= j:
                    sim = self.compute_similarity(embeddings[name1], embeddings[name2])
                    print(f"{name1} vs {name2}: {sim:.4f}")

    def cleanup(self):
        """Release model resources."""
        if self.runtime:
            try:
                self.runtime.destroy()
                print("Model resources released")
            except Exception as e:
                print(f"Error while releasing resources: {e}")
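
# --- Added example (sketch, not invoked by main) ---
# A minimal programmatic usage sketch built only on the Qwen3EmbeddingTester
# API defined above: encode two sentences and return their cosine similarity.
# The sentence pair and the helper name are illustrative assumptions, not part
# of the original test flow.
def example_pair_similarity(model_path: str, library_path: str = "./librkllmrt.so") -> float:
    tester = Qwen3EmbeddingTester(model_path, library_path)
    try:
        tester.init_model()
        emb_a = tester.encode_text("The cat sits on the mat.")
        emb_b = tester.encode_text("A cat is resting on a rug.")
        return tester.compute_similarity(emb_a, emb_b)
    finally:
        tester.cleanup()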

def main():
    """Entry point."""
    import argparse

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Qwen3-Embedding-0.6B inference test')
    parser.add_argument('model_path', help='path to the model file (.rkllm format)')
    parser.add_argument('--library_path', default="./librkllmrt.so",
                        help='path to the RKLLM runtime library (default: ./librkllmrt.so)')
    args = parser.parse_args()

    # Check that the required files exist
    if not os.path.exists(args.model_path):
        print(f"Error: model file not found: {args.model_path}")
        print("Please make sure that:")
        print("1. The Qwen3-Embedding-0.6B model has been downloaded")
        print("2. The model has been converted to .rkllm format with rkllm-convert.py")
        return

    if not os.path.exists(args.library_path):
        print(f"Error: RKLLM library not found: {args.library_path}")
        print("Make sure librkllmrt.so is in the current directory or on LD_LIBRARY_PATH")
        return

    print("Qwen3-Embedding-0.6B inference test")
    print("=" * 50)

    # Create the tester
    tester = Qwen3EmbeddingTester(args.model_path, args.library_path)

    try:
        # Initialize the model
        tester.init_model()

        # Run the tests
        print("\nRunning embedding tests...")

        # Basic embedding similarity
        tester.test_embedding_similarity()

        # Multilingual embeddings
        tester.test_multilingual_embedding()

        # Code embeddings
        tester.test_code_embedding()

        print("\n" + "=" * 50)
        print("All tests completed!")
        print("=" * 50)

    except KeyboardInterrupt:
        print("\nTest interrupted by user")
    except Exception as e:
        print(f"\nError during testing: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Release resources
        tester.cleanup()


if __name__ == "__main__":
    main()