Upload 6 files
Browse files- .gitattributes +3 -0
- Qwen3-Embedding-0.6B_f16.rkllm +3 -0
- Qwen3-Embedding-0.6B_w8a8.rkllm +3 -0
- librkllmrt.so +3 -0
- rkllm-convert.py +74 -0
- rkllm_binding.py +658 -0
- test_embedding.py +396 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
librkllmrt.so filter=lfs diff=lfs merge=lfs -text
|
37 |
+
Qwen3-Embedding-0.6B_f16.rkllm filter=lfs diff=lfs merge=lfs -text
|
38 |
+
Qwen3-Embedding-0.6B_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
|
Qwen3-Embedding-0.6B_f16.rkllm
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4a383a715837f432a8af06798e236ab70a30b80d55cdbe4d7e4d489de639310b
|
3 |
+
size 1524801182
|
Qwen3-Embedding-0.6B_w8a8.rkllm
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:be4ffc5c46b6246c71459081c5fdb8b8984c52936fa23083e342f4a843945806
|
3 |
+
size 931372078
|
librkllmrt.so
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6a9c2de93cf94bb524eb071c27190ad4c83401e01b562534f265dff4cb40da2
|
3 |
+
size 6710712
|
rkllm-convert.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from rkllm.api import RKLLM
|
3 |
+
|
4 |
+
def convert_model(model_path, output_name, do_quantization=False):
    """Convert one HuggingFace model directory into an .rkllm artifact.

    Loads the model on CPU, builds it for the rk3588 NPU (optionally
    quantizing to w8a8), and exports the result to ``output_name``.

    :param model_path: Directory containing the HuggingFace model.
    :param output_name: Destination .rkllm file name.
    :param do_quantization: When True, build a w8a8-quantized model.
    :return: 0 on success, otherwise the non-zero code of the failed step.
    """
    converter = RKLLM()

    print(f"正在加载模型: {model_path}")
    status = converter.load_huggingface(model=model_path, model_lora=None, device='cpu')
    if status != 0:
        print(f'加载模型失败: {model_path}')
        return status

    print(f"正在构建模型: {output_name} (量化: {do_quantization})")
    extra_qparams = None
    status = converter.build(
        do_quantization=do_quantization,
        optimization_level=1,
        quantized_dtype='w8a8',
        quantized_algorithm='normal',
        target_platform='rk3588',
        num_npu_core=3,
        extra_qparams=extra_qparams,
    )
    if status != 0:
        print(f'构建模型失败: {output_name}')
        return status

    # Export the built model to an .rkllm file.
    print(f"正在导出模型: {output_name}")
    status = converter.export_rkllm(output_name)
    if status != 0:
        print(f'导出模型失败: {output_name}')
        return status

    print(f"成功转换: {output_name}")
    return 0
|
32 |
+
|
33 |
+
def main():
    """Scan every non-hidden subdirectory of the current directory and
    convert each one to both an unquantized (f16) and a quantized (w8a8)
    .rkllm file."""
    current_dir = '.'

    # Each non-hidden subdirectory is assumed to contain one model.
    model_dirs = [entry for entry in os.listdir(current_dir)
                  if os.path.isdir(os.path.join(current_dir, entry)) and not entry.startswith('.')]

    print(f"找到 {len(model_dirs)} 个模型文件夹: {model_dirs}")

    for model_dir in model_dirs:
        model_path = os.path.join(current_dir, model_dir)

        # Derive a filesystem-safe output prefix from the folder name.
        safe_name = model_dir.replace('/', '_').replace('\\', '_')
        quantized_output = f"{safe_name}_w8a8.rkllm"
        unquantized_output = f"{safe_name}_f16.rkllm"

        banner = '=' * 50
        print(f"\n{banner}")
        print(f"处理模型文件夹: {model_dir}")
        print(f"{banner}")

        # Unquantized (f16) conversion first; skip the folder on failure.
        print(f"\n--- 转换非量化版本 ---")
        if convert_model(model_path, unquantized_output, do_quantization=False) != 0:
            print(f"非量化版本转换失败: {model_dir}")
            continue

        # Then the quantized (w8a8) conversion.
        print(f"\n--- 转换量化版本 ---")
        if convert_model(model_path, quantized_output, do_quantization=True) != 0:
            print(f"量化版本转换失败: {model_dir}")
            continue

        print(f"\n✓ {model_dir} 模型转换完成!")
        print(f" - 非量化版本: {unquantized_output}")
        print(f" - 量化版本: {quantized_output}")

if __name__ == "__main__":
    main()
|
rkllm_binding.py
ADDED
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ctypes
|
2 |
+
import enum
|
3 |
+
import os
|
4 |
+
|
5 |
+
# Define constants from the header
|
6 |
+
# Bit masks selecting individual CPU cores: bit i enables core i
# (CPU0 = 0x01 ... CPU7 = 0x80). Combine with | for enabled_cpus_mask.
CPU0, CPU1, CPU2, CPU3, CPU4, CPU5, CPU6, CPU7 = (1 << i for i in range(8))
|
14 |
+
|
15 |
+
# --- Enums ---
|
16 |
+
# Progress state delivered to the result callback on each invocation.
LLMCallState = enum.IntEnum(
    "LLMCallState",
    [
        ("RKLLM_RUN_NORMAL", 0),
        ("RKLLM_RUN_WAITING", 1),
        ("RKLLM_RUN_FINISH", 2),
        ("RKLLM_RUN_ERROR", 3),
    ],
)
|
21 |
+
|
22 |
+
# Discriminator for the RKLLMInput tagged union.
RKLLMInputType = enum.IntEnum(
    "RKLLMInputType",
    [
        ("RKLLM_INPUT_PROMPT", 0),
        ("RKLLM_INPUT_TOKEN", 1),
        ("RKLLM_INPUT_EMBED", 2),
        ("RKLLM_INPUT_MULTIMODAL", 3),
    ],
)
|
27 |
+
|
28 |
+
# Inference mode: generate text, or return hidden states / logits instead.
RKLLMInferMode = enum.IntEnum(
    "RKLLMInferMode",
    [
        ("RKLLM_INFER_GENERATE", 0),
        ("RKLLM_INFER_GET_LAST_HIDDEN_LAYER", 1),
        ("RKLLM_INFER_GET_LOGITS", 2),
    ],
)
|
32 |
+
|
33 |
+
# --- Structures ---
|
34 |
+
class RKLLMExtendParam(ctypes.Structure):
    """Extended runtime options (mirrors C struct RKLLMExtendParam)."""
    # Base iommu domain ID; recommended to set to 1 for models larger than 1B.
    base_domain_id: ctypes.c_int32
    # Whether the embedding table is kept in flash storage.
    embed_flash: ctypes.c_int8
    # Number of CPU cores enabled.
    enabled_cpus_num: ctypes.c_int8
    # Bit mask of the CPU cores enabled (see CPU0..CPU7 constants).
    enabled_cpus_mask: ctypes.c_uint32
    # Reserved padding required by the C ABI; do not use.
    reserved: ctypes.c_uint8 * 106

    # ABI layout: field order and types must match the C header exactly.
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("reserved", ctypes.c_uint8 * 106)
    ]
|
52 |
+
|
53 |
+
class RKLLMParam(ctypes.Structure):
    """Model and sampling configuration passed to rkllm_init
    (mirrors C struct RKLLMParam)."""
    # Path to the .rkllm model file.
    model_path: ctypes.c_char_p
    # Maximum number of tokens in the context window.
    max_context_len: ctypes.c_int32
    # Maximum number of newly generated tokens.
    max_new_tokens: ctypes.c_int32
    # Top-K sampling parameter.
    top_k: ctypes.c_int32
    # Number of KV-cache entries kept when the context window slides.
    n_keep: ctypes.c_int32
    # Top-P (nucleus) sampling parameter.
    top_p: ctypes.c_float
    # Sampling temperature; controls randomness of token selection.
    temperature: ctypes.c_float
    # Penalty applied to repeated tokens.
    repeat_penalty: ctypes.c_float
    # Penalty applied to frequent tokens.
    frequency_penalty: ctypes.c_float
    # Penalty applied to tokens already present in the input.
    presence_penalty: ctypes.c_float
    # Mirostat sampling strategy flag (0 disables it).
    mirostat: ctypes.c_int32
    # Mirostat sampling Tau parameter.
    mirostat_tau: ctypes.c_float
    # Mirostat sampling Eta parameter.
    mirostat_eta: ctypes.c_float
    # Whether special tokens are skipped in the output.
    skip_special_token: ctypes.c_bool
    # Whether inference runs asynchronously.
    is_async: ctypes.c_bool
    # Start token marking an image in multimodal input.
    img_start: ctypes.c_char_p
    # End token marking an image in multimodal input.
    img_end: ctypes.c_char_p
    # Pointer to the image content.
    img_content: ctypes.c_char_p
    # Extended parameters (see RKLLMExtendParam).
    extend_param: RKLLMExtendParam

    # ABI layout: field order and types must match the C header exactly.
    _fields_ = [
        ("model_path", ctypes.c_char_p),        # Path to the .rkllm model file
        ("max_context_len", ctypes.c_int32),    # Max tokens in the context window
        ("max_new_tokens", ctypes.c_int32),     # Max newly generated tokens
        ("top_k", ctypes.c_int32),              # Top-K sampling parameter
        ("n_keep", ctypes.c_int32),             # KV entries kept when the window slides
        ("top_p", ctypes.c_float),              # Top-P (nucleus) sampling parameter
        ("temperature", ctypes.c_float),        # Sampling temperature
        ("repeat_penalty", ctypes.c_float),     # Repeated-token penalty
        ("frequency_penalty", ctypes.c_float),  # Frequent-token penalty
        ("presence_penalty", ctypes.c_float),   # Penalty for tokens already in the input
        ("mirostat", ctypes.c_int32),           # Mirostat flag (0 = disabled)
        ("mirostat_tau", ctypes.c_float),       # Mirostat Tau parameter
        ("mirostat_eta", ctypes.c_float),       # Mirostat Eta parameter
        ("skip_special_token", ctypes.c_bool),  # Whether to skip special tokens
        ("is_async", ctypes.c_bool),            # Whether inference is asynchronous
        ("img_start", ctypes.c_char_p),         # Image start token (multimodal input)
        ("img_end", ctypes.c_char_p),           # Image end token (multimodal input)
        ("img_content", ctypes.c_char_p),       # Image content pointer
        ("extend_param", RKLLMExtendParam)      # Extended parameters
    ]
|
114 |
+
|
115 |
+
class RKLLMLoraAdapter(ctypes.Structure):
    """A LoRA adapter to load via rkllm_load_lora
    (mirrors C struct RKLLMLoraAdapter)."""
    # Filesystem path of the adapter.
    lora_adapter_path: ctypes.c_char_p
    # Adapter name; presumably matched by RKLLMLoraParam at inference time — confirm against the C header.
    lora_adapter_name: ctypes.c_char_p
    # Scale factor for the adapter.
    scale: ctypes.c_float

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]
|
125 |
+
|
126 |
+
class RKLLMEmbedInput(ctypes.Structure):
    """Pre-computed embedding input (mirrors C struct RKLLMEmbedInput)."""
    # Flattened float buffer of shape [n_tokens, embed_size].
    embed: ctypes.POINTER(ctypes.c_float)
    # Number of token embeddings stored in `embed`.
    n_tokens: ctypes.c_size_t

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]
|
135 |
+
|
136 |
+
class RKLLMTokenInput(ctypes.Structure):
    """Pre-tokenized input (mirrors C struct RKLLMTokenInput)."""
    # Token-ID buffer of shape [n_tokens].
    input_ids: ctypes.POINTER(ctypes.c_int32)
    # Number of token IDs stored in `input_ids`.
    n_tokens: ctypes.c_size_t

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]
|
145 |
+
|
146 |
+
class RKLLMMultiModelInput(ctypes.Structure):
    """Combined text + image-embedding input
    (mirrors C struct RKLLMMultiModelInput)."""
    # Text prompt.
    prompt: ctypes.c_char_p
    # Flattened image-embedding buffer.
    image_embed: ctypes.POINTER(ctypes.c_float)
    # Number of embedding tokens per image.
    n_image_tokens: ctypes.c_size_t
    # Number of images.
    n_image: ctypes.c_size_t
    # Image width — NOTE(review): units (pixels vs patches) not visible here; confirm against the C header.
    image_width: ctypes.c_size_t
    # Image height — same caveat as image_width.
    image_height: ctypes.c_size_t

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t)
    ]
|
162 |
+
|
163 |
+
class _RKLLMInputUnion(ctypes.Union):
    """Payload union inside RKLLMInput; which member is valid is decided
    by RKLLMInput.input_type (see the RKLLMInputType enum)."""
    # Valid when input_type == RKLLM_INPUT_PROMPT.
    prompt_input: ctypes.c_char_p
    # Valid when input_type == RKLLM_INPUT_EMBED.
    embed_input: RKLLMEmbedInput
    # Valid when input_type == RKLLM_INPUT_TOKEN.
    token_input: RKLLMTokenInput
    # Valid when input_type == RKLLM_INPUT_MULTIMODAL.
    multimodal_input: RKLLMMultiModelInput

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]
|
175 |
+
|
176 |
+
class RKLLMInput(ctypes.Structure):
    """Tagged-union input for rkllm_run (mirrors C struct RKLLMInput).

    `input_type` selects which member of the embedded union is valid; the
    properties below guard access so the wrong member cannot be read or
    written by mistake (they raise AttributeError on a type mismatch).
    """
    # Discriminator: one of RKLLMInputType, stored as a plain C int.
    input_type: ctypes.c_int
    # The C union holding the payload; prefer the guarded properties below.
    _union_data: _RKLLMInputUnion

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("input_type", ctypes.c_int), # Enum will be passed as int, changed RKLLMInputType to ctypes.c_int
        ("_union_data", _RKLLMInputUnion)
    ]
    # Properties to make accessing union members easier
    @property
    def prompt_input(self) -> bytes: # Assuming c_char_p maps to bytes
        """Prompt payload; only valid when input_type is RKLLM_INPUT_PROMPT."""
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            return self._union_data.prompt_input
        raise AttributeError("Not a prompt input")
    @prompt_input.setter
    def prompt_input(self, value: bytes): # Assuming c_char_p maps to bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            self._union_data.prompt_input = value
        else:
            raise AttributeError("Not a prompt input")
    @property
    def embed_input(self) -> RKLLMEmbedInput:
        """Embedding payload; only valid when input_type is RKLLM_INPUT_EMBED."""
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            return self._union_data.embed_input
        raise AttributeError("Not an embed input")
    @embed_input.setter
    def embed_input(self, value: RKLLMEmbedInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            self._union_data.embed_input = value
        else:
            raise AttributeError("Not an embed input")

    @property
    def token_input(self) -> RKLLMTokenInput:
        """Token-ID payload; only valid when input_type is RKLLM_INPUT_TOKEN."""
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            return self._union_data.token_input
        raise AttributeError("Not a token input")
    @token_input.setter
    def token_input(self, value: RKLLMTokenInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            self._union_data.token_input = value
        else:
            raise AttributeError("Not a token input")

    @property
    def multimodal_input(self) -> RKLLMMultiModelInput:
        """Multimodal payload; only valid when input_type is RKLLM_INPUT_MULTIMODAL."""
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            return self._union_data.multimodal_input
        raise AttributeError("Not a multimodal input")
    @multimodal_input.setter
    def multimodal_input(self, value: RKLLMMultiModelInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            self._union_data.multimodal_input = value
        else:
            raise AttributeError("Not a multimodal input")
|
231 |
+
|
232 |
+
class RKLLMLoraParam(ctypes.Structure): # For inference
    """Selects a previously loaded LoRA adapter for an inference call
    (mirrors C struct RKLLMLoraParam)."""
    # Name of the adapter to apply (as given in RKLLMLoraAdapter).
    lora_adapter_name: ctypes.c_char_p

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]
|
238 |
+
|
239 |
+
class RKLLMPromptCacheParam(ctypes.Structure): # For inference
    """Prompt-cache options for an inference call
    (mirrors C struct RKLLMPromptCacheParam)."""
    # Whether to save the prompt cache (bool passed as int).
    save_prompt_cache: ctypes.c_int # bool-like
    # Path where the prompt cache is stored.
    prompt_cache_path: ctypes.c_char_p

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int), # bool-like
        ("prompt_cache_path", ctypes.c_char_p)
    ]
|
247 |
+
|
248 |
+
class RKLLMInferParam(ctypes.Structure):
    """Per-call inference settings for rkllm_run / rkllm_run_async
    (mirrors C struct RKLLMInferParam)."""
    # One of RKLLMInferMode, stored as a plain C int.
    mode: ctypes.c_int
    # Optional LoRA selection; NULL pointer when unused.
    lora_params: ctypes.POINTER(RKLLMLoraParam)
    # Optional prompt-cache settings; NULL pointer when unused.
    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
    # Whether to keep conversation history (bool passed as int).
    keep_history: ctypes.c_int # bool-like

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("mode", ctypes.c_int), # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int) # bool-like
    ]
|
260 |
+
|
261 |
+
class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Hidden-state output returned in RKLLMResult
    (mirrors C struct RKLLMResultLastHiddenLayer)."""
    # Flattened float buffer of shape [num_tokens, embd_size].
    hidden_states: ctypes.POINTER(ctypes.c_float)
    # Hidden layer (embedding) size.
    embd_size: ctypes.c_int
    # Number of output tokens.
    num_tokens: ctypes.c_int

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]
|
274 |
+
|
275 |
+
class RKLLMResultLogits(ctypes.Structure):
    """Logits output returned in RKLLMResult
    (mirrors C struct RKLLMResultLogits)."""
    # Flattened float buffer of shape [num_tokens, vocab_size].
    logits: ctypes.POINTER(ctypes.c_float)
    # Vocabulary size.
    vocab_size: ctypes.c_int
    # Number of output tokens.
    num_tokens: ctypes.c_int

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]
|
288 |
+
|
289 |
+
class RKLLMResult(ctypes.Structure):
    """Payload handed to the result callback on each invocation
    (mirrors C struct RKLLMResult)."""
    # Generated text fragment; may be NULL (callers must check before decoding).
    text: ctypes.c_char_p
    # ID of the generated token.
    token_id: ctypes.c_int32
    # Hidden-state output — presumably populated only in
    # RKLLM_INFER_GET_LAST_HIDDEN_LAYER mode; confirm against the C header.
    last_hidden_layer: RKLLMResultLastHiddenLayer
    # Logits output — presumably populated only in RKLLM_INFER_GET_LOGITS mode.
    logits: RKLLMResultLogits

    # ABI layout: must match the C header exactly.
    _fields_ = [
        ("text", ctypes.c_char_p),
        ("token_id", ctypes.c_int32),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
        ("logits", RKLLMResultLogits)
    ]
|
301 |
+
|
302 |
+
# --- Typedefs ---
# Opaque handle to a loaded LLM instance (void* on the C side).
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
# C signature: void (*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state)
LLMResultCallback = ctypes.CFUNCTYPE(
    None, # return type: void
    ctypes.POINTER(RKLLMResult),
    ctypes.c_void_p, # userdata
    ctypes.c_int # enum, will be passed as int. Changed LLMCallState to ctypes.c_int
)
|
312 |
+
|
313 |
+
|
314 |
+
class RKLLMRuntime:
    """Python wrapper around the RKLLM C runtime (librkllmrt.so).

    Handles shared-library loading, C function-signature declaration,
    handle management, and error-code checking. Usable as a context
    manager; resources are released via destroy() on exit or garbage
    collection.

    Fix vs. the previous version: all attributes that destroy()/__del__
    touch are initialized before anything that can raise, and destroy()
    uses getattr guards, so a failed construction (e.g. CDLL OSError) no
    longer triggers AttributeError during garbage collection.
    """

    def __init__(self, library_path="./librkllmrt.so"):
        """Load the shared library and declare C function signatures.

        :param library_path: Path to librkllmrt.so.
        :raises OSError: If the shared library cannot be loaded.
        """
        # Initialize every attribute destroy()/__del__ reads BEFORE any
        # statement that can raise (see class docstring).
        self.lib = None
        self.llm_handle = LLMHandle()   # NULL handle until init() succeeds
        self._c_callback = None         # Keeps the ctypes callback alive
        self._userdata_ref = None       # Keeps userdata (and wrapper) alive
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
        self._setup_functions()

    def _setup_functions(self):
        """Declare restype/argtypes for every C entry point used."""
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_run_async(...); same shape as rkllm_run.
        # Assuming async also takes userdata for the callback context.
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle); returns 0 while running.
        self.lib.rkllm_is_running.restype = ctypes.c_int
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p
        ]

    def _marshal_userdata(self, userdata):
        """Convert a Python object to a c_void_p for the C API.

        Keeps the py_object wrapper referenced on self so the pointer
        handed to C remains valid at least until the next run call
        (previously only the raw object was kept, letting the wrapper
        the pointer targets be garbage collected).
        """
        if userdata is None:
            return None
        self._userdata_ref = ctypes.py_object(userdata)
        return ctypes.cast(ctypes.pointer(self._userdata_ref), ctypes.c_void_p)

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.
        :param param: RKLLMParam structure.
        :param callback_func: Python function matching the signature
            (result_ptr, userdata_ptr, state_enum) -> None, where
            result_ptr.contents is an RKLLMResult and state_enum maps to
            LLMCallState.
        :return: 0 for success.
        :raises ValueError: If callback_func is not callable.
        :raises RuntimeError: If the C call returns a non-zero code.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        # Keep a reference to the ctypes callback object so it is not
        # garbage collected while the C side still holds the pointer.
        self._c_callback = LLMResultCallback(callback_func)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a Lora adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources.

        Idempotent, and safe to call on a partially constructed object
        (e.g. from __del__ after __init__ raised).
        """
        lib = getattr(self, "lib", None)
        handle = getattr(self, "llm_handle", None)
        if lib is None or handle is None or not handle.value:
            return 0  # Already destroyed or never initialized
        ret = lib.rkllm_destroy(handle)
        self.llm_handle = LLMHandle()  # Reset handle
        if ret != 0:
            # Don't raise here as it might be called in __del__
            print(f"Warning: rkllm_destroy failed with error code {ret}")
        return ret

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously."""
        c_userdata = self._marshal_userdata(userdata)
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously."""
        c_userdata = self._marshal_userdata(userdata)
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The C API returns 0 if running, non-zero otherwise — note the
        # inverted convention relative to a boolean "is_running".
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool) -> int:
        """Clears the key-value cache."""
        ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0))
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM."""
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        # Best-effort cleanup; must never raise during interpreter shutdown.
        try:
            self.destroy()
        except Exception:
            pass
|
521 |
+
|
522 |
+
# --- Example Usage (Illustrative) ---
if __name__ == "__main__":
    # This is a placeholder for how you might use it.
    # You'll need a valid .rkllm model and librkllmrt.so in your path.

    # Global list to store results from callback for demonstration
    results_buffer = []

    def my_python_callback(result_ptr, userdata_ptr, state_enum):
        """
        Callback function to be called by the C library.
        Appends each streamed text fragment to results_buffer.
        """
        global results_buffer
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text: # Check if the char_p is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')

        print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
        results_buffer.append(current_text)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            print("Inference finished.")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("Inference error.")

        # Example: Accessing logits if available (and if mode was set to get logits)
        # if result.logits.logits and result.logits.vocab_size > 0:
        #     print(f" Logits (first 5 of vocab_size {result.logits.vocab_size}):")
        #     for i in range(min(5, result.logits.vocab_size)):
        #         print(f" {result.logits.logits[i]:.4f}", end=" ")
        #     print()


    # --- Attempt to use the wrapper ---
    try:
        print("Initializing RKLLMRuntime...")
        # Adjust library_path if librkllmrt.so is not in default search paths
        # e.g., library_path="./path/to/librkllmrt.so"
        rk_llm = RKLLMRuntime()

        print("Creating default parameters...")
        params = rk_llm.create_default_param()

        # --- Configure parameters ---
        # THIS IS CRITICAL: model_path must point to an actual .rkllm file
        # For this example to run, you need a model file.
        # Let's assume a dummy path for now, this will fail at init if not valid.
        model_file = "dummy_model.rkllm"
        if not os.path.exists(model_file):
            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
            # Create a dummy file for the example to proceed further, though init will still fail
            # with a real library unless it's a valid model.
            with open(model_file, "w") as f:
                f.write("dummy content")

        params.model_path = model_file.encode('utf-8')
        params.max_context_len = 512
        params.max_new_tokens = 128
        params.top_k = 1 # Greedy
        params.temperature = 0.7
        params.repeat_penalty = 1.1
        # ... set other params as needed

        print(f"Initializing LLM with model: {params.model_path.decode()}...")
        # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
        try:
            rk_llm.init(params, my_python_callback)
            print("LLM Initialized.")
        except RuntimeError as e:
            print(f"Error during LLM initialization: {e}")
            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
            exit()


        # --- Prepare input ---
        print("Preparing input...")
        rk_input = RKLLMInput()
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

        prompt_text = "Translate the following English text to French: 'Hello, world!'"
        # Keep c_prompt referenced: the union stores only the pointer.
        c_prompt = prompt_text.encode('utf-8')
        rk_input._union_data.prompt_input = c_prompt # Accessing union member directly

        # --- Prepare inference parameters ---
        print("Preparing inference parameters...")
        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        infer_params.keep_history = 1 # True
        # infer_params.lora_params = None # or set up RKLLMLoraParam if using LoRA
        # infer_params.prompt_cache_params = None # or set up RKLLMPromptCacheParam

        # --- Run inference ---
        print(f"Running inference with prompt: '{prompt_text}'")
        results_buffer.clear()
        try:
            rk_llm.run(rk_input, infer_params) # Userdata is None by default
            print("\n--- Full Response ---")
            print("".join(results_buffer))
            print("---------------------\n")
        except RuntimeError as e:
            print(f"Error during LLM run: {e}")


        # --- Example: Set chat template (if model supports it) ---
        # print("Setting chat template...")
        # try:
        #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
        #     print("Chat template set.")
        # except RuntimeError as e:
        #     print(f"Error setting chat template: {e}")

        # --- Example: Clear KV Cache ---
        # print("Clearing KV cache (keeping system prompt if any)...")
        # try:
        #     rk_llm.clear_kv_cache(keep_system_prompt=True)
        #     print("KV cache cleared.")
        # except RuntimeError as e:
        #     print(f"Error clearing KV cache: {e}")

    except OSError as e:
        print(f"OSError: {e}. Could not load the RKLLM library.")
        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
            print("Destroying LLM instance...")
            rk_llm.destroy()
            print("LLM instance destroyed.")
        # Guard 'model_file' too: it is unbound if RKLLMRuntime() raised
        # OSError before the assignment (previously a NameError here).
        if 'model_file' in locals() and model_file == "dummy_model.rkllm" and os.path.exists(model_file):
            os.remove(model_file) # Clean up dummy file

    print("Example finished.")
|
test_embedding.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Qwen3-Embedding-0.6B 推理测试代码
|
5 |
+
使用 RKLLM API 进行文本嵌入推理
|
6 |
+
"""
|
7 |
+
import faulthandler
|
8 |
+
faulthandler.enable()
|
9 |
+
import os
|
10 |
+
os.environ["RKLLM_LOG_LEVEL"] = "1"
|
11 |
+
import numpy as np
|
12 |
+
import time
|
13 |
+
from typing import List, Dict, Any
|
14 |
+
from rkllm_binding import *
|
15 |
+
|
16 |
+
|
17 |
+
class Qwen3EmbeddingTester:
    """Embedding-inference test harness for Qwen3-Embedding-0.6B via the RKLLM API.

    Loads an ``.rkllm`` model through the RKLLM C runtime, runs prompts in
    "last hidden layer" mode, and provides helpers to encode text into
    L2-normalized embedding vectors and compare them with cosine similarity.
    """

    def __init__(self, model_path: str, library_path: str = "./librkllmrt.so"):
        """Initialize the tester.

        Args:
            model_path: Path to the model file (.rkllm format).
            library_path: Path to the RKLLM shared library.
        """
        self.model_path = model_path
        self.library_path = library_path
        self.runtime = None          # RKLLMRuntime, created in init_model()
        self.embeddings_buffer = []  # kept for interface compatibility (unused)
        self.current_result = None   # latest embedding captured by the callback

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """Inference callback invoked by the RKLLM runtime.

        On a normal step, copies the last token's hidden state out of the
        C buffer into a numpy array and stores it in ``self.current_result``.

        Args:
            result_ptr: ctypes pointer to the result structure.
            userdata_ptr: Opaque user-data pointer (unused).
            state_enum: Raw LLMCallState value for this callback.
        """
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_NORMAL:
            result = result_ptr.contents
            print(f"result: {result}")
            # Use the last hidden layer output as the embedding source.
            if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:
                embd_size = result.last_hidden_layer.embd_size
                num_tokens = result.last_hidden_layer.num_tokens

                print(f"获取到嵌入向量:维度={embd_size}, 令牌数={num_tokens}")

                # Take the last token's hidden state as the sentence embedding.
                if num_tokens > 0:
                    # Fix: hoist the row offset — the original recomputed
                    # (num_tokens-1) * embd_size once per element.
                    base = (num_tokens - 1) * embd_size
                    last_token_embedding = np.array([
                        result.last_hidden_layer.hidden_states[base + i]
                        for i in range(embd_size)
                    ])

                    self.current_result = {
                        'embedding': last_token_embedding,
                        'embd_size': embd_size,
                        'num_tokens': num_tokens
                    }

                    print(f"嵌入向量范数: {np.linalg.norm(last_token_embedding):.4f}")
                    print(f"嵌入向量前10维: {last_token_embedding[:10]}")

        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("推理过程发生错误")

    def init_model(self):
        """Create the RKLLM runtime and load the model.

        Raises:
            Exception: Re-raises any runtime/load failure after logging it.
        """
        try:
            print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("创建默认参数...")
            params = self.runtime.create_default_param()

            # Core parameters.
            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024  # context window size
            params.max_new_tokens = 1      # embedding task generates no new tokens
            params.temperature = 1.0       # sampling settings are irrelevant here
            params.top_k = 1
            params.top_p = 1.0

            # Extended parameters.
            params.extend_param.base_domain_id = 1     # recommended 1 for >1B models
            params.extend_param.embed_flash = 0        # 0: keep embedding table in RAM
            params.extend_param.enabled_cpus_num = 4   # number of CPU cores to use
            params.extend_param.enabled_cpus_mask = 0x0F  # CPU core affinity mask

            print(f"初始化模型: {self.model_path}")
            self.runtime.init(params, self.callback_function)
            # Empty chat template: feed raw text with no chat markers.
            self.runtime.set_chat_template("","","")
            print("模型初始化成功!")

        except Exception as e:
            print(f"模型初始化失败: {e}")
            raise

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Build an instruction prompt (usage pattern from the model README).

        Args:
            task_description: Description of the retrieval task.
            query: Query text.

        Returns:
            The formatted instruction prompt.
        """
        return f'Instruct: {task_description}\nQuery: {query}'

    def encode_text(self, text: str, task_description: str = None) -> np.ndarray:
        """Encode text into an L2-normalized embedding vector.

        Args:
            text: Text to encode.
            task_description: If given, wrap the text in an instruction prompt.

        Returns:
            The normalized embedding vector (numpy array).

        Raises:
            RuntimeError: If the callback produced no embedding.
        """
        try:
            # With a task description, use the instruction-prompt form.
            if task_description:
                input_text = self.get_detailed_instruct(task_description, text)
            else:
                input_text = text

            print(f"编码文本: {input_text[:100]}{'...' if len(input_text) > 100 else ''}")

            # Build the prompt input.
            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            c_prompt = input_text.encode('utf-8')  # local ref keeps buffer alive for run()
            rk_input._union_data.prompt_input = c_prompt

            # Request the last hidden layer instead of generated text.
            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
            infer_params.keep_history = 0  # no conversation history

            # Reset per-call state before running.
            self.current_result = None
            self.runtime.clear_kv_cache(False)

            start_time = time.time()
            self.runtime.run(rk_input, infer_params)
            end_time = time.time()

            print(f"推理耗时: {end_time - start_time:.3f}秒")

            if self.current_result and 'embedding' in self.current_result:
                # L2-normalize so dot products equal cosine similarities.
                embedding = self.current_result['embedding']
                normalized_embedding = embedding / np.linalg.norm(embedding)
                return normalized_embedding
            else:
                raise RuntimeError("未能获取到有效的嵌入向量")

        except Exception as e:
            print(f"编码文本时发生错误: {e}")
            raise

    def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """Cosine similarity of two embeddings.

        Inputs are assumed to be L2-normalized already (encode_text
        guarantees this), so the plain dot product is the cosine.

        Args:
            emb1: First embedding vector.
            emb2: Second embedding vector.

        Returns:
            The cosine similarity value.
        """
        return np.dot(emb1, emb2)

    def test_embedding_similarity(self):
        """Query/document retrieval similarity test.

        Returns:
            Row-major list of query-vs-document similarity scores.
        """
        print("\n" + "="*50)
        print("测试嵌入相似度计算")
        print("="*50)

        task_description = "Given a web search query, retrieve relevant passages that answer the query"

        queries = [
            "What is the capital of China?",
            "Explain gravity"
        ]

        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
        ]

        # Queries get the instruction prefix...
        print("\n编码查询文本:")
        query_embeddings = []
        for i, query in enumerate(queries):
            print(f"\n查询 {i+1}: {query}")
            emb = self.encode_text(query, task_description)
            query_embeddings.append(emb)

        # ...documents do not (asymmetric retrieval setup).
        print("\n编码文档文本:")
        doc_embeddings = []
        for i, doc in enumerate(documents):
            print(f"\n文档 {i+1}: {doc}")
            emb = self.encode_text(doc)
            doc_embeddings.append(emb)

        print("\n计算相似度矩阵:")
        print("查询 vs 文档相似度:")
        print("-" * 30)

        similarities = []
        for i, q_emb in enumerate(query_embeddings):
            row_similarities = []
            for j, d_emb in enumerate(doc_embeddings):
                sim = self.compute_similarity(q_emb, d_emb)
                row_similarities.append(sim)
                print(f"查询{i+1} vs 文档{j+1}: {sim:.4f}")
            similarities.append(row_similarities)
            print()

        return similarities

    def test_multilingual_embedding(self):
        """Cross-lingual test: same sentence meaning across several languages."""
        print("\n" + "="*50)
        print("测试多语言嵌入能力")
        print("="*50)

        # Same meaning, different languages.
        texts = {
            "英语": "Hello, how are you?",
            "中文": "你好,你好吗?",
            "法语": "Bonjour, comment allez-vous?",
            "西班牙语": "Hola, ¿cómo estás?",
            "日语": "こんにちは、元気ですか?"
        }

        embeddings = {}
        print("\n编码多语言文本:")
        for lang, text in texts.items():
            print(f"\n{lang}: {text}")
            emb = self.encode_text(text)
            embeddings[lang] = emb

        print("\n跨语言相似度:")
        print("-" * 30)

        languages = list(texts.keys())
        for i, lang1 in enumerate(languages):
            for j, lang2 in enumerate(languages):
                if i <= j:  # upper triangle only; the matrix is symmetric
                    sim = self.compute_similarity(embeddings[lang1], embeddings[lang2])
                    print(f"{lang1} vs {lang2}: {sim:.4f}")

    def test_code_embedding(self):
        """Code-embedding test: same algorithm in several languages."""
        print("\n" + "="*50)
        print("测试代码嵌入能力")
        print("="*50)

        # Code samples (equivalent fibonacci in 3 languages + an unrelated sort).
        codes = {
            "Python函数": """
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)
""",
            "JavaScript函数": """
function fibonacci(n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "C++函数": """
int fibonacci(int n) {
    if (n <= 1) return n;
    return fibonacci(n-1) + fibonacci(n-2);
}
""",
            "数组排序": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
"""
        }

        embeddings = {}
        print("\n编码代码文本:")
        for name, code in codes.items():
            print(f"\n{name}:")
            print(code[:100] + "..." if len(code) > 100 else code)
            emb = self.encode_text(code)
            embeddings[name] = emb

        print("\n代码相似度:")
        print("-" * 30)

        code_names = list(codes.keys())
        for i, name1 in enumerate(code_names):
            for j, name2 in enumerate(code_names):
                if i <= j:
                    sim = self.compute_similarity(embeddings[name1], embeddings[name2])
                    print(f"{name1} vs {name2}: {sim:.4f}")

    def cleanup(self):
        """Release the RKLLM runtime; safe to call more than once."""
        if self.runtime:
            try:
                self.runtime.destroy()
                print("模型资源已清理")
            except Exception as e:
                print(f"清理资源时发生错误: {e}")
            finally:
                # Fix: drop the handle so a second cleanup() is a no-op
                # instead of destroying the runtime twice.
                self.runtime = None
|
334 |
+
|
335 |
+
def main():
    """CLI entry point: validate paths, then run the full embedding test suite.

    Exits with status 1 when the model file or runtime library is missing
    (the original returned and exited 0, hiding the failure from the shell).
    """
    import argparse

    # Command-line arguments.
    parser = argparse.ArgumentParser(description='Qwen3-Embedding-0.6B 推理测试')
    parser.add_argument('model_path', help='模型文件路径(.rkllm格式)')
    parser.add_argument('--library_path', default="./librkllmrt.so", help='RKLLM库文件路径(默认为./librkllmrt.so)')
    args = parser.parse_args()

    # Validate inputs up front.
    if not os.path.exists(args.model_path):
        print(f"错误: 模型文件不存在: {args.model_path}")
        print("请确保:")
        print("1. 已下载 Qwen3-Embedding-0.6B 模型")
        print("2. 已使用 rkllm-convert.py 将模型转换为 .rkllm 格式")
        # Fix: signal failure to the shell instead of exiting 0.
        raise SystemExit(1)

    if not os.path.exists(args.library_path):
        print(f"错误: RKLLM 库文件不存在: {args.library_path}")
        print("请确保 librkllmrt.so 在当前目录或 LD_LIBRARY_PATH 中")
        raise SystemExit(1)

    print("Qwen3-Embedding-0.6B 推理测试")
    print("=" * 50)

    tester = Qwen3EmbeddingTester(args.model_path, args.library_path)

    try:
        tester.init_model()

        print("\n开始运行嵌入测试...")

        tester.test_embedding_similarity()
        tester.test_multilingual_embedding()
        tester.test_code_embedding()

        print("\n" + "="*50)
        print("所有测试完成!")
        print("="*50)

    except KeyboardInterrupt:
        print("\n测试被用户中断")
    except Exception as e:
        print(f"\n测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the runtime, even after errors or interrupts.
        tester.cleanup()
|
393 |
+
|
394 |
+
|
395 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|