Upload 6 files
Browse files- .gitattributes +3 -0
- Qwen3-Embedding-0.6B_f16.rkllm +3 -0
- Qwen3-Embedding-0.6B_w8a8.rkllm +3 -0
- librkllmrt.so +3 -0
- rkllm-convert.py +74 -0
- rkllm_binding.py +658 -0
- test_embedding.py +396 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
librkllmrt.so filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Qwen3-Embedding-0.6B_f16.rkllm filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Qwen3-Embedding-0.6B_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
|
Qwen3-Embedding-0.6B_f16.rkllm
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a383a715837f432a8af06798e236ab70a30b80d55cdbe4d7e4d489de639310b
|
| 3 |
+
size 1524801182
|
Qwen3-Embedding-0.6B_w8a8.rkllm
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be4ffc5c46b6246c71459081c5fdb8b8984c52936fa23083e342f4a843945806
|
| 3 |
+
size 931372078
|
librkllmrt.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f6a9c2de93cf94bb524eb071c27190ad4c83401e01b562534f265dff4cb40da2
|
| 3 |
+
size 6710712
|
rkllm-convert.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from rkllm.api import RKLLM
|
| 3 |
+
|
| 4 |
+
def convert_model(model_path, output_name, do_quantization=False):
    """Convert one HuggingFace model directory into a .rkllm artifact.

    Runs the three RKLLM toolkit stages (load -> build -> export) and stops
    at the first failure.

    :param model_path: directory containing the HuggingFace model.
    :param output_name: file name for the exported .rkllm model.
    :param do_quantization: build a w8a8-quantized model when True,
        otherwise an f16 (non-quantized) one.
    :return: 0 on success, else the non-zero status of the failing stage.
    """
    llm = RKLLM()

    print(f"正在加载模型: {model_path}")
    status = llm.load_huggingface(model=model_path, model_lora=None, device='cpu')
    if status != 0:
        print(f'加载模型失败: {model_path}')
        return status

    print(f"正在构建模型: {output_name} (量化: {do_quantization})")
    # extra_qparams is intentionally left as None (no custom quantization params).
    status = llm.build(
        do_quantization=do_quantization,
        optimization_level=1,
        quantized_dtype='w8a8',
        quantized_algorithm='normal',
        target_platform='rk3588',
        num_npu_core=3,
        extra_qparams=None,
    )
    if status != 0:
        print(f'构建模型失败: {output_name}')
        return status

    # Export the built model into the .rkllm container.
    print(f"正在导出模型: {output_name}")
    status = llm.export_rkllm(output_name)
    if status != 0:
        print(f'导出模型失败: {output_name}')
        return status

    print(f"成功转换: {output_name}")
    return 0
| 32 |
+
|
| 33 |
+
def main():
    """Walk every non-hidden subdirectory of the CWD and convert each model
    folder into two .rkllm files: an f16 build and a w8a8-quantized build."""
    current_dir = '.'

    # Collect candidate model folders (skip hidden entries such as .git).
    subdirs = [
        entry
        for entry in os.listdir(current_dir)
        if os.path.isdir(os.path.join(current_dir, entry)) and not entry.startswith('.')
    ]

    print(f"找到 {len(subdirs)} 个模型文件夹: {subdirs}")

    for subdir in subdirs:
        model_path = os.path.join(current_dir, subdir)

        # Derive output file names from the folder name; path separators are
        # normalized so the names are valid flat file names.
        base_name = subdir.replace('/', '_').replace('\\', '_')
        quantized_output = f"{base_name}_w8a8.rkllm"
        unquantized_output = f"{base_name}_f16.rkllm"

        banner = '=' * 50
        print(f"\n{banner}")
        print(f"处理模型文件夹: {subdir}")
        print(f"{banner}")

        # Non-quantized (f16) build first.
        print(f"\n--- 转换非量化版本 ---")
        if convert_model(model_path, unquantized_output, do_quantization=False) != 0:
            print(f"非量化版本转换失败: {subdir}")
            continue

        # Quantized (w8a8) build.
        print(f"\n--- 转换量化版本 ---")
        if convert_model(model_path, quantized_output, do_quantization=True) != 0:
            print(f"量化版本转换失败: {subdir}")
            continue

        print(f"\n✓ {subdir} 模型转换完成!")
        print(f" - 非量化版本: {unquantized_output}")
        print(f" - 量化版本: {quantized_output}")
| 72 |
+
|
| 73 |
+
# Script entry point: convert every model folder found in the current directory.
if __name__ == "__main__":
    main()
|
rkllm_binding.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import enum
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Define constants from the header
# CPU affinity bit masks: bit i selects CPU core i. Intended for
# RKLLMExtendParam.enabled_cpus_mask (OR the bits together).
CPU0 = (1 << 0) # 0x01
CPU1 = (1 << 1) # 0x02
CPU2 = (1 << 2) # 0x04
CPU3 = (1 << 3) # 0x08
CPU4 = (1 << 4) # 0x10
CPU5 = (1 << 5) # 0x20
CPU6 = (1 << 6) # 0x40
CPU7 = (1 << 7) # 0x80
|
| 14 |
+
|
| 15 |
+
# --- Enums ---
|
| 16 |
+
class LLMCallState(enum.IntEnum):
    """Inference-callback state; mirrors the C enum of the same name.

    Passed to the result callback as a plain int (see LLMResultCallback).
    """
    RKLLM_RUN_NORMAL = 0   # a normal token/result chunk was produced
    RKLLM_RUN_WAITING = 1  # runtime is waiting (no result yet)
    RKLLM_RUN_FINISH = 2   # inference completed
    RKLLM_RUN_ERROR = 3    # inference failed
|
| 21 |
+
|
| 22 |
+
class RKLLMInputType(enum.IntEnum):
    """Discriminator for the RKLLMInput tagged union; mirrors the C enum."""
    RKLLM_INPUT_PROMPT = 0      # plain text prompt (char*)
    RKLLM_INPUT_TOKEN = 1       # pre-tokenized input ids
    RKLLM_INPUT_EMBED = 2       # raw embedding vectors
    RKLLM_INPUT_MULTIMODAL = 3  # prompt + image embeddings
|
| 27 |
+
|
| 28 |
+
class RKLLMInferMode(enum.IntEnum):
    """Inference mode for RKLLMInferParam.mode; mirrors the C enum."""
    RKLLM_INFER_GENERATE = 0               # normal autoregressive generation
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1  # return last hidden states (embeddings)
    RKLLM_INFER_GET_LOGITS = 2             # return raw logits
|
| 32 |
+
|
| 33 |
+
# --- Structures ---
|
| 34 |
+
class RKLLMExtendParam(ctypes.Structure):
    """Extended runtime parameters; mirrors the C struct of the same name.

    Field notes (translated from the original Chinese comments):
      base_domain_id    -- base iommu domain id; setting 1 is recommended
                           for models larger than ~1B parameters
      embed_flash       -- whether the embedding table is stored in flash
      enabled_cpus_num  -- number of CPU cores to enable
      enabled_cpus_mask -- bit mask of enabled CPU cores (see CPU0..CPU7)
      reserved          -- padding reserved by the C ABI; leave zeroed

    NOTE: the previous parallel block of bare class-level ctypes annotations
    was removed — ctypes ignores annotations (only _fields_ defines the
    layout), and duplicating them risks the two lists drifting apart.
    """
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("reserved", ctypes.c_uint8 * 106),
    ]
|
| 52 |
+
|
| 53 |
+
class RKLLMParam(ctypes.Structure):
    """Model/runtime configuration; mirrors the C struct of the same name.

    Obtain a pre-populated instance via RKLLMRuntime.create_default_param()
    and override fields before calling init(). The redundant class-level
    annotation block was removed: ctypes derives the layout solely from
    _fields_, and the duplicate list was dead code that could drift.
    """
    _fields_ = [
        ("model_path", ctypes.c_char_p),        # path to the .rkllm model file
        ("max_context_len", ctypes.c_int32),    # max tokens in the context window
        ("max_new_tokens", ctypes.c_int32),     # max newly generated tokens
        ("top_k", ctypes.c_int32),              # top-k sampling parameter
        ("n_keep", ctypes.c_int32),             # kv-cache entries kept when the window shifts
        ("top_p", ctypes.c_float),              # top-p (nucleus) sampling parameter
        ("temperature", ctypes.c_float),        # sampling temperature (token randomness)
        ("repeat_penalty", ctypes.c_float),     # penalty for repeated tokens
        ("frequency_penalty", ctypes.c_float),  # penalty on frequent tokens
        ("presence_penalty", ctypes.c_float),   # penalty on tokens already in the input
        ("mirostat", ctypes.c_int32),           # mirostat sampling flag (0 = disabled)
        ("mirostat_tau", ctypes.c_float),       # mirostat tau parameter
        ("mirostat_eta", ctypes.c_float),       # mirostat eta parameter
        ("skip_special_token", ctypes.c_bool),  # skip special tokens in output
        ("is_async", ctypes.c_bool),            # run inference asynchronously
        ("img_start", ctypes.c_char_p),         # image start token (multimodal input)
        ("img_end", ctypes.c_char_p),           # image end token (multimodal input)
        ("img_content", ctypes.c_char_p),       # pointer to image content
        ("extend_param", RKLLMExtendParam),     # extended parameters
    ]
|
| 114 |
+
|
| 115 |
+
class RKLLMLoraAdapter(ctypes.Structure):
    """Description of a LoRA adapter to load (see RKLLMRuntime.load_lora).

    Redundant class-level annotations removed; _fields_ alone defines the
    ctypes layout.
    """
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),  # path to the adapter file
        ("lora_adapter_name", ctypes.c_char_p),  # name used to select it at inference
        ("scale", ctypes.c_float),               # adapter scaling factor
    ]
|
| 125 |
+
|
| 126 |
+
class RKLLMEmbedInput(ctypes.Structure):
    """Embedding input for rkllm_run.

    `embed` points at a float buffer of shape [n_tokens, embed_size]
    (row-major). Redundant class-level annotations removed; _fields_ alone
    defines the ctypes layout.
    """
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t),
    ]
|
| 135 |
+
|
| 136 |
+
class RKLLMTokenInput(ctypes.Structure):
    """Pre-tokenized input for rkllm_run.

    `input_ids` points at an int32 buffer of shape [n_tokens]. Redundant
    class-level annotations removed; _fields_ alone defines the layout.
    """
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t),
    ]
|
| 145 |
+
|
| 146 |
+
class RKLLMMultiModelInput(ctypes.Structure):
    """Multimodal (prompt + image embeddings) input for rkllm_run.

    Redundant class-level annotations removed; _fields_ alone defines the
    ctypes layout.
    """
    _fields_ = [
        ("prompt", ctypes.c_char_p),                     # text part of the input
        ("image_embed", ctypes.POINTER(ctypes.c_float)), # image embedding buffer
        ("n_image_tokens", ctypes.c_size_t),             # tokens per image
        ("n_image", ctypes.c_size_t),                    # number of images
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t),
    ]
|
| 162 |
+
|
| 163 |
+
class _RKLLMInputUnion(ctypes.Union):
    """C union of the four possible input payloads; tagged by
    RKLLMInput.input_type. Redundant class-level annotations removed;
    _fields_ alone defines the layout.
    """
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput),
    ]
|
| 175 |
+
|
| 176 |
+
class RKLLMInput(ctypes.Structure):
    """Tagged-union input passed to rkllm_run.

    `input_type` (an RKLLMInputType value, stored as a plain C int) selects
    which member of the embedded union is valid. The properties below check
    the tag before reading or writing a union member, so set `input_type`
    first. Redundant class-level annotations removed; _fields_ alone
    defines the ctypes layout.
    """
    _fields_ = [
        ("input_type", ctypes.c_int),  # RKLLMInputType passed as int
        ("_union_data", _RKLLMInputUnion),
    ]

    # Properties to make accessing union members easier (and safer: they
    # raise AttributeError when the tag does not match the member).
    @property
    def prompt_input(self) -> bytes:
        """UTF-8 prompt bytes; valid only for RKLLM_INPUT_PROMPT."""
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            return self._union_data.prompt_input
        raise AttributeError("Not a prompt input")

    @prompt_input.setter
    def prompt_input(self, value: bytes):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            self._union_data.prompt_input = value
        else:
            raise AttributeError("Not a prompt input")

    @property
    def embed_input(self) -> RKLLMEmbedInput:
        """Embedding payload; valid only for RKLLM_INPUT_EMBED."""
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            return self._union_data.embed_input
        raise AttributeError("Not an embed input")

    @embed_input.setter
    def embed_input(self, value: RKLLMEmbedInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            self._union_data.embed_input = value
        else:
            raise AttributeError("Not an embed input")

    @property
    def token_input(self) -> RKLLMTokenInput:
        """Token-id payload; valid only for RKLLM_INPUT_TOKEN."""
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            return self._union_data.token_input
        raise AttributeError("Not a token input")

    @token_input.setter
    def token_input(self, value: RKLLMTokenInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            self._union_data.token_input = value
        else:
            raise AttributeError("Not a token input")

    @property
    def multimodal_input(self) -> RKLLMMultiModelInput:
        """Multimodal payload; valid only for RKLLM_INPUT_MULTIMODAL."""
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            return self._union_data.multimodal_input
        raise AttributeError("Not a multimodal input")

    @multimodal_input.setter
    def multimodal_input(self, value: RKLLMMultiModelInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            self._union_data.multimodal_input = value
        else:
            raise AttributeError("Not a multimodal input")
|
| 231 |
+
|
| 232 |
+
class RKLLMLoraParam(ctypes.Structure):  # For inference
    """Selects a previously loaded LoRA adapter by name for one run.

    Redundant class-level annotation removed; _fields_ defines the layout.
    """
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p),
    ]
|
| 238 |
+
|
| 239 |
+
class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
    """Prompt-cache save settings for one run.

    Redundant class-level annotations removed; _fields_ defines the layout.
    """
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),   # bool-like: non-zero saves the cache
        ("prompt_cache_path", ctypes.c_char_p),
    ]
|
| 247 |
+
|
| 248 |
+
class RKLLMInferParam(ctypes.Structure):
    """Per-run inference parameters for rkllm_run / rkllm_run_async.

    Redundant class-level annotations removed; _fields_ defines the layout.
    """
    _fields_ = [
        ("mode", ctypes.c_int),  # RKLLMInferMode passed as int
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),             # NULL if unused
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),  # NULL if unused
        ("keep_history", ctypes.c_int),  # bool-like: keep conversation history
    ]
|
| 260 |
+
|
| 261 |
+
class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Last hidden-layer output (RKLLM_INFER_GET_LAST_HIDDEN_LAYER mode).

    `hidden_states` points at floats of shape [num_tokens, embd_size].
    Redundant class-level annotations removed; _fields_ defines the layout.
    """
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),   # hidden layer width
        ("num_tokens", ctypes.c_int),  # number of output tokens
    ]
|
| 274 |
+
|
| 275 |
+
class RKLLMResultLogits(ctypes.Structure):
    """Raw logits output (RKLLM_INFER_GET_LOGITS mode).

    `logits` points at floats of shape [num_tokens, vocab_size].
    Redundant class-level annotations removed; _fields_ defines the layout.
    """
    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),  # vocabulary size
        ("num_tokens", ctypes.c_int),  # number of output tokens
    ]
|
| 288 |
+
|
| 289 |
+
class RKLLMResult(ctypes.Structure):
    """Result payload handed to the LLMResultCallback.

    Only the parts relevant to the active infer mode are populated (e.g.
    `text`/`token_id` for generation). Redundant class-level annotations
    removed; _fields_ defines the layout.
    """
    _fields_ = [
        ("text", ctypes.c_char_p),    # generated text chunk (may be NULL)
        ("token_id", ctypes.c_int32), # id of the generated token
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
        ("logits", RKLLMResultLogits),
    ]
|
| 301 |
+
|
| 302 |
+
# --- Typedefs ---
# Opaque handle to a loaded model instance (void* in the C API).
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
# C signature: void (*)(RKLLMResult*, void* userdata, int state)
# where `state` is an LLMCallState value delivered as a plain int.
LLMResultCallback = ctypes.CFUNCTYPE(
    None, # return type: void
    ctypes.POINTER(RKLLMResult),
    ctypes.c_void_p, # userdata
    ctypes.c_int # enum, will be passed as int. Changed LLMCallState to ctypes.c_int
)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class RKLLMRuntime:
    """Thin ctypes wrapper around the RKLLM C runtime (librkllmrt.so).

    Usage: construct, configure an RKLLMParam from create_default_param(),
    call init() with a Python callback, then run()/run_async(). Also usable
    as a context manager so destroy() is always called.
    """

    def __init__(self, library_path="./librkllmrt.so"):
        """Load the shared library and prepare C function signatures.

        :raises OSError: if the shared library cannot be loaded.
        """
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            # Chain the original loader error so the root cause is preserved.
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}") from e
        self._setup_functions()
        self.llm_handle = LLMHandle()
        self._c_callback = None  # To keep the callback object alive

    def _setup_functions(self):
        """Declare restype/argtypes for every C entry point we use."""
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        # Assuming async also takes userdata for the callback context
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle);
        self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p
        ]

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.
        :param param: RKLLMParam structure.
        :param callback_func: A Python function that matches the signature:
                              def my_callback(result_ptr, userdata_ptr, state_enum):
                                  result = result_ptr.contents # RKLLMResult
                                  # state = LLMCallState(state_enum)
        :return: 0 for success (raises RuntimeError otherwise).
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        # Keep a reference to the ctypes callback object to prevent it from
        # being garbage collected while the C runtime still holds it.
        self._c_callback = LLMResultCallback(callback_func)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a Lora adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources. Idempotent."""
        if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
            ret = self.lib.rkllm_destroy(self.llm_handle)
            self.llm_handle = LLMHandle()  # Reset handle so a second call is a no-op
            if ret != 0:
                # Don't raise here as it might be called in __del__
                print(f"Warning: rkllm_destroy failed with error code {ret}")
            return ret
        return 0  # Already destroyed or not initialized

    def _pack_userdata(self, userdata):
        """Build a c_void_p for the C userdata argument (or None).

        The py_object wrapper is stored on self so the memory the pointer
        refers to stays alive for the duration of the run; previously only
        the raw Python object was kept and the wrapper was a temporary,
        which could leave the C side with a dangling pointer.
        """
        if userdata is None:
            return None
        self._userdata_ref = ctypes.py_object(userdata)
        return ctypes.cast(ctypes.pointer(self._userdata_ref), ctypes.c_void_p)

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously."""
        c_userdata = self._pack_userdata(userdata)
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously."""
        c_userdata = self._pack_userdata(userdata)
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The C API returns 0 if running, non-zero otherwise — the inverse of
        # what a boolean "is_running" suggests, hence the == 0 test.
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool) -> int:
        """Clears the key-value cache."""
        ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0))
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM. None/empty args become ""."""
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        # __init__ may have raised before llm_handle/lib were assigned (e.g.
        # the shared library failed to load); guard so garbage collection
        # never raises AttributeError.
        if getattr(self, "llm_handle", None) is not None:
            self.destroy()
|
| 521 |
+
|
| 522 |
+
# --- Example Usage (Illustrative) ---
if __name__ == "__main__":
    # This is a placeholder for how you might use it.
    # You'll need a valid .rkllm model and librkllmrt.so in your path.

    # Global list to store results from callback for demonstration
    results_buffer = []

    def my_python_callback(result_ptr, userdata_ptr, state_enum):
        """
        Callback function to be called by the C library.

        Args:
            result_ptr: Pointer to the result structure produced by the runtime.
            userdata_ptr: Opaque user-data pointer (unused here).
            state_enum: Raw integer value of an LLMCallState member.
        """
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text:  # Check if the char_p is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')

        print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
        results_buffer.append(current_text)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            print("Inference finished.")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("Inference error.")

    # --- Attempt to use the wrapper ---
    # BUGFIX: model_file is defined BEFORE the try block. Previously it was
    # assigned inside the try, so an early failure (e.g. the library failing
    # to load in RKLLMRuntime()) left it unbound and the finally block then
    # raised NameError, masking the original error.
    model_file = "dummy_model.rkllm"
    try:
        print("Initializing RKLLMRuntime...")
        # Adjust library_path if librkllmrt.so is not in default search paths,
        # e.g. RKLLMRuntime(library_path="./path/to/librkllmrt.so").
        rk_llm = RKLLMRuntime()

        print("Creating default parameters...")
        params = rk_llm.create_default_param()

        # --- Configure parameters ---
        # THIS IS CRITICAL: model_path must point to an actual .rkllm file.
        if not os.path.exists(model_file):
            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
            # Create a dummy file so the example proceeds further; init will
            # still fail with a real library unless it's a valid model.
            with open(model_file, "w") as f:
                f.write("dummy content")

        params.model_path = model_file.encode('utf-8')
        params.max_context_len = 512
        params.max_new_tokens = 128
        params.top_k = 1  # Greedy
        params.temperature = 0.7
        params.repeat_penalty = 1.1
        # ... set other params as needed

        print(f"Initializing LLM with model: {params.model_path.decode()}...")
        # This will likely fail if dummy_model.rkllm is not a valid model.
        try:
            rk_llm.init(params, my_python_callback)
            print("LLM Initialized.")
        except RuntimeError as e:
            print(f"Error during LLM initialization: {e}")
            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
            # BUGFIX: raise SystemExit instead of calling exit(); exit() is a
            # site-module convenience that may be absent. SystemExit is a
            # BaseException, so (like before) it bypasses the `except
            # Exception` below while still running the finally block.
            raise SystemExit

        # --- Prepare input ---
        print("Preparing input...")
        rk_input = RKLLMInput()
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

        prompt_text = "Translate the following English text to French: 'Hello, world!'"
        c_prompt = prompt_text.encode('utf-8')
        rk_input._union_data.prompt_input = c_prompt  # Accessing union member directly

        # --- Prepare inference parameters ---
        print("Preparing inference parameters...")
        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        infer_params.keep_history = 1  # True
        # infer_params.lora_params / prompt_cache_params stay unset (no LoRA, no cache).

        # --- Run inference ---
        print(f"Running inference with prompt: '{prompt_text}'")
        results_buffer.clear()
        try:
            rk_llm.run(rk_input, infer_params)  # Userdata is None by default
            print("\n--- Full Response ---")
            print("".join(results_buffer))
            print("---------------------\n")
        except RuntimeError as e:
            print(f"Error during LLM run: {e}")

    except OSError as e:
        print(f"OSError: {e}. Could not load the RKLLM library.")
        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        # rk_llm may be unbound if construction failed; guard before touching it.
        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
            print("Destroying LLM instance...")
            rk_llm.destroy()
            print("LLM instance destroyed.")
        if os.path.exists(model_file) and model_file == "dummy_model.rkllm":
            os.remove(model_file)  # Clean up dummy file

    print("Example finished.")
|
test_embedding.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Qwen3-Embedding-0.6B 推理测试代码
|
| 5 |
+
使用 RKLLM API 进行文本嵌入推理
|
| 6 |
+
"""
|
| 7 |
+
import faulthandler
|
| 8 |
+
faulthandler.enable()
|
| 9 |
+
import os
|
| 10 |
+
os.environ["RKLLM_LOG_LEVEL"] = "1"
|
| 11 |
+
import numpy as np
|
| 12 |
+
import time
|
| 13 |
+
from typing import List, Dict, Any
|
| 14 |
+
from rkllm_binding import *
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Qwen3EmbeddingTester:
    """Drives a Qwen3-Embedding-0.6B .rkllm model through the RKLLM runtime
    and runs a small suite of embedding-quality checks.

    All user-facing output is printed in Chinese, matching the original script;
    those strings are runtime behavior and are kept verbatim.
    """

    def __init__(self, model_path: str, library_path: str = "./librkllmrt.so"):
        """
        Initialize the Qwen3 embedding-model tester.

        Args:
            model_path: Path to the model file (.rkllm format).
            library_path: Path to the RKLLM shared library.
        """
        self.model_path = model_path
        self.library_path = library_path
        self.runtime = None          # RKLLMRuntime instance, created in init_model()
        self.embeddings_buffer = []  # NOTE(review): never read or written below — appears unused
        self.current_result = None   # Latest embedding dict, set by callback_function()

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """
        Inference callback invoked by the C library.

        On RKLLM_RUN_NORMAL, copies the last token's hidden state out of the
        C buffer into self.current_result.

        Args:
            result_ptr: Pointer to the result structure.
            userdata_ptr: User-data pointer (unused).
            state_enum: Raw LLMCallState value.
        """
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_NORMAL:
            result = result_ptr.contents
            print(f"result: {result}")
            # Use the last-hidden-layer output as the embedding source.
            if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:
                embd_size = result.last_hidden_layer.embd_size
                num_tokens = result.last_hidden_layer.num_tokens

                print(f"获取到嵌入向量:维度={embd_size}, 令牌数={num_tokens}")

                # Convert the C array to a numpy array. The last token's hidden
                # state is taken as the sentence embedding; the indexing below
                # treats hidden_states as a flat [num_tokens * embd_size] buffer.
                if num_tokens > 0:
                    # Last token's embedding (shape: [embd_size]).
                    last_token_embedding = np.array([
                        result.last_hidden_layer.hidden_states[(num_tokens-1) * embd_size + i]
                        for i in range(embd_size)
                    ])

                    self.current_result = {
                        'embedding': last_token_embedding,
                        'embd_size': embd_size,
                        'num_tokens': num_tokens
                    }

                    print(f"嵌入向量范数: {np.linalg.norm(last_token_embedding):.4f}")
                    print(f"嵌入向量前10维: {last_token_embedding[:10]}")

        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("推理过程发生错误")

    def init_model(self):
        """Create the RKLLM runtime and load the model with embedding-oriented
        parameters (single new token, greedy sampling, empty chat template).

        Raises:
            Exception: Re-raises whatever the runtime raised during init.
        """
        try:
            print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("创建默认参数...")
            params = self.runtime.create_default_param()

            # Core parameters.
            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024  # context length
            params.max_new_tokens = 1      # embedding task generates no new tokens
            params.temperature = 1.0       # sampling knobs are irrelevant here...
            params.top_k = 1               # ...embedding does not sample
            params.top_p = 1.0

            # Extended parameter configuration.
            params.extend_param.base_domain_id = 1        # recommended value 1 for models >1B
            params.extend_param.embed_flash = 0           # whether to keep embeddings in flash
            params.extend_param.enabled_cpus_num = 4      # number of CPU cores enabled
            params.extend_param.enabled_cpus_mask = 0x0F  # CPU core mask

            print(f"初始化模型: {self.model_path}")
            self.runtime.init(params, self.callback_function)
            # Empty template so raw text is embedded without any chat wrapping.
            self.runtime.set_chat_template("","","")
            print("模型初始化成功!")

        except Exception as e:
            print(f"模型初始化失败: {e}")
            raise

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """
        Build an instruction prompt (following the usage shown in the model README).

        Args:
            task_description: Task description.
            query: Query text.

        Returns:
            The formatted instruction prompt.
        """
        return f'Instruct: {task_description}\nQuery: {query}'

    def encode_text(self, text: str, task_description: str = None) -> np.ndarray:
        """
        Encode text into an embedding vector.

        Args:
            text: Text to encode.
            task_description: Optional task description; when provided, the text
                is wrapped with the instruction prompt first.

        Returns:
            L2-normalized embedding vector (numpy array).

        Raises:
            RuntimeError: If the callback did not produce a valid embedding.
        """
        try:
            # Wrap with the instruction prompt when a task description is given.
            if task_description:
                input_text = self.get_detailed_instruct(task_description, text)
            else:
                input_text = text

            print(f"编码文本: {input_text[:100]}{'...' if len(input_text) > 100 else ''}")

            # Prepare the input structure.
            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            c_prompt = input_text.encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            # Request the last hidden layer instead of generated text.
            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
            infer_params.keep_history = 0  # do not keep history

            # Reset state so each encode is independent of the previous one.
            self.current_result = None
            self.runtime.clear_kv_cache(False)

            # Run inference; the callback fills self.current_result.
            start_time = time.time()
            self.runtime.run(rk_input, infer_params)
            end_time = time.time()

            print(f"推理耗时: {end_time - start_time:.3f}秒")

            if self.current_result and 'embedding' in self.current_result:
                # L2-normalize so that dot products become cosine similarities.
                embedding = self.current_result['embedding']
                normalized_embedding = embedding / np.linalg.norm(embedding)
                return normalized_embedding
            else:
                raise RuntimeError("未能获取到有效的嵌入向量")

        except Exception as e:
            print(f"编码文本时发生错误: {e}")
            raise

    def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """
        Compute the cosine similarity of two embedding vectors.

        Inputs are expected to be L2-normalized (encode_text normalizes them),
        so a plain dot product equals cosine similarity.

        Args:
            emb1: First embedding vector.
            emb2: Second embedding vector.

        Returns:
            Cosine similarity value.
        """
        return np.dot(emb1, emb2)

    def test_embedding_similarity(self):
        """Query-vs-document retrieval similarity test.

        Returns:
            Nested list: similarities[i][j] = similarity(query i, document j).
        """
        print("\n" + "="*50)
        print("测试嵌入相似度计算")
        print("="*50)

        # Test texts.
        task_description = "Given a web search query, retrieve relevant passages that answer the query"

        queries = [
            "What is the capital of China?",
            "Explain gravity"
        ]

        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
        ]

        # Encode queries (with the instruction prompt).
        print("\n编码查询文本:")
        query_embeddings = []
        for i, query in enumerate(queries):
            print(f"\n查询 {i+1}: {query}")
            emb = self.encode_text(query, task_description)
            query_embeddings.append(emb)

        # Encode documents (without the instruction prompt).
        print("\n编码文档文本:")
        doc_embeddings = []
        for i, doc in enumerate(documents):
            print(f"\n文档 {i+1}: {doc}")
            emb = self.encode_text(doc)
            doc_embeddings.append(emb)

        # Compute the full similarity matrix.
        print("\n计算相似度矩阵:")
        print("查询 vs 文档相似度:")
        print("-" * 30)

        similarities = []
        for i, q_emb in enumerate(query_embeddings):
            row_similarities = []
            for j, d_emb in enumerate(doc_embeddings):
                sim = self.compute_similarity(q_emb, d_emb)
                row_similarities.append(sim)
                print(f"查询{i+1} vs 文档{j+1}: {sim:.4f}")
            similarities.append(row_similarities)
            print()

        return similarities

    def test_multilingual_embedding(self):
        """Cross-lingual similarity test on semantically equivalent sentences."""
        print("\n" + "="*50)
        print("测试多语言嵌入能力")
        print("="*50)

        # Multilingual test texts (same meaning in different languages).
        texts = {
            "英语": "Hello, how are you?",
            "中文": "你好,你好吗?",
            "法语": "Bonjour, comment allez-vous?",
            "西班牙语": "Hola, ¿cómo estás?",
            "日语": "こんにちは、元気ですか?"
        }

        embeddings = {}
        print("\n编码多语言文本:")
        for lang, text in texts.items():
            print(f"\n{lang}: {text}")
            emb = self.encode_text(text)
            embeddings[lang] = emb

        # Print cross-language similarities (upper triangle only, i <= j).
        print("\n跨语言相似度:")
        print("-" * 30)

        languages = list(texts.keys())
        for i, lang1 in enumerate(languages):
            for j, lang2 in enumerate(languages):
                if i <= j:
                    sim = self.compute_similarity(embeddings[lang1], embeddings[lang2])
                    print(f"{lang1} vs {lang2}: {sim:.4f}")

    def test_code_embedding(self):
        """Code-snippet similarity test across languages/algorithms."""
        print("\n" + "="*50)
        print("测试代码嵌入能力")
        print("="*50)

        # Code samples (the triple-quoted contents are runtime data, not code).
        codes = {
            "Python函数": """
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)
""",
            "JavaScript函数": """
function fibonacci(n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "C++函数": """
int fibonacci(int n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "数组排序": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
"""
        }

        embeddings = {}
        print("\n编码代码文本:")
        for name, code in codes.items():
            print(f"\n{name}:")
            print(code[:100] + "..." if len(code) > 100 else code)
            emb = self.encode_text(code)
            embeddings[name] = emb

        # Print pairwise code similarities (upper triangle only, i <= j).
        print("\n代码相似度:")
        print("-" * 30)

        code_names = list(codes.keys())
        for i, name1 in enumerate(code_names):
            for j, name2 in enumerate(code_names):
                if i <= j:
                    sim = self.compute_similarity(embeddings[name1], embeddings[name2])
                    print(f"{name1} vs {name2}: {sim:.4f}")

    def cleanup(self):
        """Release runtime resources; safe to call even if init never succeeded."""
        if self.runtime:
            try:
                self.runtime.destroy()
                print("模型资源已清理")
            except Exception as e:
                print(f"清理资源时发生错误: {e}")
|
| 334 |
+
|
| 335 |
+
def main():
    """Command-line entry point: parse arguments, validate the model and
    library paths, then run the full embedding test suite."""
    import argparse

    arg_parser = argparse.ArgumentParser(description='Qwen3-Embedding-0.6B 推理测试')
    arg_parser.add_argument('model_path', help='模型文件路径(.rkllm格式)')
    arg_parser.add_argument('--library_path', default="./librkllmrt.so", help='RKLLM库文件路径(默认为./librkllmrt.so)')
    opts = arg_parser.parse_args()

    # Guard clause: bail out early when the model file is missing.
    if not os.path.exists(opts.model_path):
        print(f"错误: 模型文件不存在: {opts.model_path}")
        print("请确保:")
        print("1. 已下载 Qwen3-Embedding-0.6B 模型")
        print("2. 已使用 rkllm-convert.py 将模型转换为 .rkllm 格式")
        return

    # Guard clause: bail out early when the runtime library is missing.
    if not os.path.exists(opts.library_path):
        print(f"错误: RKLLM 库文件不存在: {opts.library_path}")
        print("请确保 librkllmrt.so 在当前目录或 LD_LIBRARY_PATH 中")
        return

    print("Qwen3-Embedding-0.6B 推理测试")
    print("=" * 50)

    tester = Qwen3EmbeddingTester(opts.model_path, opts.library_path)

    try:
        tester.init_model()

        print("\n开始运行嵌入测试...")

        # Run the three test groups in order: basic similarity,
        # multilingual, then code embeddings.
        tester.test_embedding_similarity()
        tester.test_multilingual_embedding()
        tester.test_code_embedding()

        print("\n" + "="*50)
        print("所有测试完成!")
        print("="*50)

    except KeyboardInterrupt:
        print("\n测试被用户中断")
    except Exception as e:
        print(f"\n测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release native resources, even after an error or Ctrl-C.
        tester.cleanup()
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
if __name__ == "__main__":
|
| 396 |
+
main()
|