happyme531 commited on
Commit
c8faba2
·
verified ·
1 Parent(s): 40dc221

Upload 6 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ librkllmrt.so filter=lfs diff=lfs merge=lfs -text
37
+ Qwen3-Embedding-0.6B_f16.rkllm filter=lfs diff=lfs merge=lfs -text
38
+ Qwen3-Embedding-0.6B_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
Qwen3-Embedding-0.6B_f16.rkllm ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a383a715837f432a8af06798e236ab70a30b80d55cdbe4d7e4d489de639310b
3
+ size 1524801182
Qwen3-Embedding-0.6B_w8a8.rkllm ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be4ffc5c46b6246c71459081c5fdb8b8984c52936fa23083e342f4a843945806
3
+ size 931372078
librkllmrt.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6a9c2de93cf94bb524eb071c27190ad4c83401e01b562534f265dff4cb40da2
3
+ size 6710712
rkllm-convert.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from rkllm.api import RKLLM
3
+
4
+ def convert_model(model_path, output_name, do_quantization=False):
5
+ """转换单个模型"""
6
+ llm = RKLLM()
7
+
8
+ print(f"正在加载模型: {model_path}")
9
+ ret = llm.load_huggingface(model=model_path, model_lora=None, device='cpu')
10
+ if ret != 0:
11
+ print(f'加载模型失败: {model_path}')
12
+ return ret
13
+
14
+ print(f"正在构建模型: {output_name} (量化: {do_quantization})")
15
+ qparams = None
16
+ ret = llm.build(do_quantization=do_quantization, optimization_level=1, quantized_dtype='w8a8',
17
+ quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)
18
+
19
+ if ret != 0:
20
+ print(f'构建模型失败: {output_name}')
21
+ return ret
22
+
23
+ # 导出rkllm模型
24
+ print(f"正在导出模型: {output_name}")
25
+ ret = llm.export_rkllm(output_name)
26
+ if ret != 0:
27
+ print(f'导出模型失败: {output_name}')
28
+ return ret
29
+
30
+ print(f"成功转换: {output_name}")
31
+ return 0
32
+
33
+ def main():
34
+ """主函数:遍历所有子文件夹并转换模型"""
35
+ current_dir = '.'
36
+
37
+ # 获取所有子文件夹
38
+ subdirs = [d for d in os.listdir(current_dir)
39
+ if os.path.isdir(os.path.join(current_dir, d)) and not d.startswith('.')]
40
+
41
+ print(f"找到 {len(subdirs)} 个模型文件夹: {subdirs}")
42
+
43
+ for subdir in subdirs:
44
+ model_path = os.path.join(current_dir, subdir)
45
+
46
+ # 生成输出文件名
47
+ base_name = subdir.replace('/', '_').replace('\\', '_')
48
+ quantized_output = f"{base_name}_w8a8.rkllm"
49
+ unquantized_output = f"{base_name}_f16.rkllm"
50
+
51
+ print(f"\n{'='*50}")
52
+ print(f"处理模型文件夹: {subdir}")
53
+ print(f"{'='*50}")
54
+
55
+ # 转换非量化版本
56
+ print(f"\n--- 转换非量化版本 ---")
57
+ ret = convert_model(model_path, unquantized_output, do_quantization=False)
58
+ if ret != 0:
59
+ print(f"非量化版本转换失败: {subdir}")
60
+ continue
61
+
62
+ # 转换量化版本
63
+ print(f"\n--- 转换量化版本 ---")
64
+ ret = convert_model(model_path, quantized_output, do_quantization=True)
65
+ if ret != 0:
66
+ print(f"量化版本转换失败: {subdir}")
67
+ continue
68
+
69
+ print(f"\n✓ {subdir} 模型转换完成!")
70
+ print(f" - 非量化版本: {unquantized_output}")
71
+ print(f" - 量化版本: {quantized_output}")
72
+
73
+ if __name__ == "__main__":
74
+ main()
rkllm_binding.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import enum
3
+ import os
4
+
5
+ # Define constants from the header
6
+ CPU0 = (1 << 0) # 0x01
7
+ CPU1 = (1 << 1) # 0x02
8
+ CPU2 = (1 << 2) # 0x04
9
+ CPU3 = (1 << 3) # 0x08
10
+ CPU4 = (1 << 4) # 0x10
11
+ CPU5 = (1 << 5) # 0x20
12
+ CPU6 = (1 << 6) # 0x40
13
+ CPU7 = (1 << 7) # 0x80
14
+
15
+ # --- Enums ---
16
+ class LLMCallState(enum.IntEnum):
17
+ RKLLM_RUN_NORMAL = 0
18
+ RKLLM_RUN_WAITING = 1
19
+ RKLLM_RUN_FINISH = 2
20
+ RKLLM_RUN_ERROR = 3
21
+
22
+ class RKLLMInputType(enum.IntEnum):
23
+ RKLLM_INPUT_PROMPT = 0
24
+ RKLLM_INPUT_TOKEN = 1
25
+ RKLLM_INPUT_EMBED = 2
26
+ RKLLM_INPUT_MULTIMODAL = 3
27
+
28
+ class RKLLMInferMode(enum.IntEnum):
29
+ RKLLM_INFER_GENERATE = 0
30
+ RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
31
+ RKLLM_INFER_GET_LOGITS = 2
32
+
33
+ # --- Structures ---
34
+ class RKLLMExtendParam(ctypes.Structure):
35
+ # 基础iommu domain ID, 对>1b的模型建议设置为1
36
+ base_domain_id: ctypes.c_int32
37
+ # 是否使用flash存储Embedding
38
+ embed_flash: ctypes.c_int8
39
+ # 启用的cpu核心数
40
+ enabled_cpus_num: ctypes.c_int8
41
+ # 启用的cpu核心掩码
42
+ enabled_cpus_mask: ctypes.c_uint32
43
+ reserved: ctypes.c_uint8 * 106
44
+
45
+ _fields_ = [
46
+ ("base_domain_id", ctypes.c_int32),
47
+ ("embed_flash", ctypes.c_int8),
48
+ ("enabled_cpus_num", ctypes.c_int8),
49
+ ("enabled_cpus_mask", ctypes.c_uint32),
50
+ ("reserved", ctypes.c_uint8 * 106)
51
+ ]
52
+
53
+ class RKLLMParam(ctypes.Structure):
54
+ # 模型文件路径
55
+ model_path: ctypes.c_char_p
56
+ # 上下文窗口最大token数
57
+ max_context_len: ctypes.c_int32
58
+ # 最大生成新token数
59
+ max_new_tokens: ctypes.c_int32
60
+ # Top-K采样参数
61
+ top_k: ctypes.c_int32
62
+ # 上下文窗口移动时保留的kv缓存数量
63
+ n_keep: ctypes.c_int32
64
+ # Top-P采样参数
65
+ top_p: ctypes.c_float
66
+ # 采样温度,影响token选择的随机性
67
+ temperature: ctypes.c_float
68
+ # 重复token惩罚
69
+ repeat_penalty: ctypes.c_float
70
+ # 频繁token惩罚
71
+ frequency_penalty: ctypes.c_float
72
+ # 输入中已存在token的惩罚
73
+ presence_penalty: ctypes.c_float
74
+ # Mirostat采样策略标志(0表示禁用)
75
+ mirostat: ctypes.c_int32
76
+ # Mirostat采样Tau参数
77
+ mirostat_tau: ctypes.c_float
78
+ # Mirostat采样Eta参数
79
+ mirostat_eta: ctypes.c_float
80
+ # 是否跳过特殊token
81
+ skip_special_token: ctypes.c_bool
82
+ # 是否异步推理
83
+ is_async: ctypes.c_bool
84
+ # 多模态输入中图像的起始Token
85
+ img_start: ctypes.c_char_p
86
+ # 多模态输入中图像的结束Token
87
+ img_end: ctypes.c_char_p
88
+ # 图像内容指针
89
+ img_content: ctypes.c_char_p
90
+ # 扩展参数
91
+ extend_param: RKLLMExtendParam
92
+
93
+ _fields_ = [
94
+ ("model_path", ctypes.c_char_p), # 模型文件路径
95
+ ("max_context_len", ctypes.c_int32), # 上下文窗口最大token数
96
+ ("max_new_tokens", ctypes.c_int32), # 最大生成新token数
97
+ ("top_k", ctypes.c_int32), # Top-K采样参数
98
+ ("n_keep", ctypes.c_int32), # 上下文窗口移动时保留的kv缓存数量
99
+ ("top_p", ctypes.c_float), # Top-P(nucleus)采样参数
100
+ ("temperature", ctypes.c_float), # 采样温度,影响token选择的随机性
101
+ ("repeat_penalty", ctypes.c_float), # 重复token惩罚
102
+ ("frequency_penalty", ctypes.c_float), # 频繁token惩罚
103
+ ("presence_penalty", ctypes.c_float), # 输入中已存在token的惩罚
104
+ ("mirostat", ctypes.c_int32), # Mirostat采样策略标志(0表示禁用)
105
+ ("mirostat_tau", ctypes.c_float), # Mirostat采样Tau参数
106
+ ("mirostat_eta", ctypes.c_float), # Mirostat采样Eta参数
107
+ ("skip_special_token", ctypes.c_bool), # 是否跳过特殊token
108
+ ("is_async", ctypes.c_bool), # 是否异步推理
109
+ ("img_start", ctypes.c_char_p), # 多模态输入中图像的起始Token
110
+ ("img_end", ctypes.c_char_p), # 多模态输入中图像的结束Token
111
+ ("img_content", ctypes.c_char_p), # 图像内容指针
112
+ ("extend_param", RKLLMExtendParam) # 扩展参数
113
+ ]
114
+
115
+ class RKLLMLoraAdapter(ctypes.Structure):
116
+ lora_adapter_path: ctypes.c_char_p
117
+ lora_adapter_name: ctypes.c_char_p
118
+ scale: ctypes.c_float
119
+
120
+ _fields_ = [
121
+ ("lora_adapter_path", ctypes.c_char_p),
122
+ ("lora_adapter_name", ctypes.c_char_p),
123
+ ("scale", ctypes.c_float)
124
+ ]
125
+
126
+ class RKLLMEmbedInput(ctypes.Structure):
127
+ # Shape: [n_tokens, embed_size]
128
+ embed: ctypes.POINTER(ctypes.c_float)
129
+ n_tokens: ctypes.c_size_t
130
+
131
+ _fields_ = [
132
+ ("embed", ctypes.POINTER(ctypes.c_float)),
133
+ ("n_tokens", ctypes.c_size_t)
134
+ ]
135
+
136
+ class RKLLMTokenInput(ctypes.Structure):
137
+ # Shape: [n_tokens]
138
+ input_ids: ctypes.POINTER(ctypes.c_int32)
139
+ n_tokens: ctypes.c_size_t
140
+
141
+ _fields_ = [
142
+ ("input_ids", ctypes.POINTER(ctypes.c_int32)),
143
+ ("n_tokens", ctypes.c_size_t)
144
+ ]
145
+
146
+ class RKLLMMultiModelInput(ctypes.Structure):
147
+ prompt: ctypes.c_char_p
148
+ image_embed: ctypes.POINTER(ctypes.c_float)
149
+ n_image_tokens: ctypes.c_size_t
150
+ n_image: ctypes.c_size_t
151
+ image_width: ctypes.c_size_t
152
+ image_height: ctypes.c_size_t
153
+
154
+ _fields_ = [
155
+ ("prompt", ctypes.c_char_p),
156
+ ("image_embed", ctypes.POINTER(ctypes.c_float)),
157
+ ("n_image_tokens", ctypes.c_size_t),
158
+ ("n_image", ctypes.c_size_t),
159
+ ("image_width", ctypes.c_size_t),
160
+ ("image_height", ctypes.c_size_t)
161
+ ]
162
+
163
+ class _RKLLMInputUnion(ctypes.Union):
164
+ prompt_input: ctypes.c_char_p
165
+ embed_input: RKLLMEmbedInput
166
+ token_input: RKLLMTokenInput
167
+ multimodal_input: RKLLMMultiModelInput
168
+
169
+ _fields_ = [
170
+ ("prompt_input", ctypes.c_char_p),
171
+ ("embed_input", RKLLMEmbedInput),
172
+ ("token_input", RKLLMTokenInput),
173
+ ("multimodal_input", RKLLMMultiModelInput)
174
+ ]
175
+
176
+ class RKLLMInput(ctypes.Structure):
177
+ input_type: ctypes.c_int
178
+ _union_data: _RKLLMInputUnion
179
+
180
+ _fields_ = [
181
+ ("input_type", ctypes.c_int), # Enum will be passed as int, changed RKLLMInputType to ctypes.c_int
182
+ ("_union_data", _RKLLMInputUnion)
183
+ ]
184
+ # Properties to make accessing union members easier
185
+ @property
186
+ def prompt_input(self) -> bytes: # Assuming c_char_p maps to bytes
187
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
188
+ return self._union_data.prompt_input
189
+ raise AttributeError("Not a prompt input")
190
+ @prompt_input.setter
191
+ def prompt_input(self, value: bytes): # Assuming c_char_p maps to bytes
192
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
193
+ self._union_data.prompt_input = value
194
+ else:
195
+ raise AttributeError("Not a prompt input")
196
+ @property
197
+ def embed_input(self) -> RKLLMEmbedInput:
198
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
199
+ return self._union_data.embed_input
200
+ raise AttributeError("Not an embed input")
201
+ @embed_input.setter
202
+ def embed_input(self, value: RKLLMEmbedInput):
203
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
204
+ self._union_data.embed_input = value
205
+ else:
206
+ raise AttributeError("Not an embed input")
207
+
208
+ @property
209
+ def token_input(self) -> RKLLMTokenInput:
210
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
211
+ return self._union_data.token_input
212
+ raise AttributeError("Not a token input")
213
+ @token_input.setter
214
+ def token_input(self, value: RKLLMTokenInput):
215
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
216
+ self._union_data.token_input = value
217
+ else:
218
+ raise AttributeError("Not a token input")
219
+
220
+ @property
221
+ def multimodal_input(self) -> RKLLMMultiModelInput:
222
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
223
+ return self._union_data.multimodal_input
224
+ raise AttributeError("Not a multimodal input")
225
+ @multimodal_input.setter
226
+ def multimodal_input(self, value: RKLLMMultiModelInput):
227
+ if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
228
+ self._union_data.multimodal_input = value
229
+ else:
230
+ raise AttributeError("Not a multimodal input")
231
+
232
+ class RKLLMLoraParam(ctypes.Structure): # For inference
233
+ lora_adapter_name: ctypes.c_char_p
234
+
235
+ _fields_ = [
236
+ ("lora_adapter_name", ctypes.c_char_p)
237
+ ]
238
+
239
+ class RKLLMPromptCacheParam(ctypes.Structure): # For inference
240
+ save_prompt_cache: ctypes.c_int # bool-like
241
+ prompt_cache_path: ctypes.c_char_p
242
+
243
+ _fields_ = [
244
+ ("save_prompt_cache", ctypes.c_int), # bool-like
245
+ ("prompt_cache_path", ctypes.c_char_p)
246
+ ]
247
+
248
+ class RKLLMInferParam(ctypes.Structure):
249
+ mode: ctypes.c_int
250
+ lora_params: ctypes.POINTER(RKLLMLoraParam)
251
+ prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
252
+ keep_history: ctypes.c_int # bool-like
253
+
254
+ _fields_ = [
255
+ ("mode", ctypes.c_int), # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
256
+ ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
257
+ ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
258
+ ("keep_history", ctypes.c_int) # bool-like
259
+ ]
260
+
261
+ class RKLLMResultLastHiddenLayer(ctypes.Structure):
262
+ # Shape: [num_tokens, embd_size]
263
+ hidden_states: ctypes.POINTER(ctypes.c_float)
264
+ # 隐藏层大小
265
+ embd_size: ctypes.c_int
266
+ # 输出token数
267
+ num_tokens: ctypes.c_int
268
+
269
+ _fields_ = [
270
+ ("hidden_states", ctypes.POINTER(ctypes.c_float)),
271
+ ("embd_size", ctypes.c_int),
272
+ ("num_tokens", ctypes.c_int)
273
+ ]
274
+
275
+ class RKLLMResultLogits(ctypes.Structure):
276
+ # Shape: [num_tokens, vocab_size]
277
+ logits: ctypes.POINTER(ctypes.c_float)
278
+ # 词汇表大小
279
+ vocab_size: ctypes.c_int
280
+ # 输出token数
281
+ num_tokens: ctypes.c_int
282
+
283
+ _fields_ = [
284
+ ("logits", ctypes.POINTER(ctypes.c_float)),
285
+ ("vocab_size", ctypes.c_int),
286
+ ("num_tokens", ctypes.c_int)
287
+ ]
288
+
289
+ class RKLLMResult(ctypes.Structure):
290
+ text: ctypes.c_char_p
291
+ token_id: ctypes.c_int32
292
+ last_hidden_layer: RKLLMResultLastHiddenLayer
293
+ logits: RKLLMResultLogits
294
+
295
+ _fields_ = [
296
+ ("text", ctypes.c_char_p),
297
+ ("token_id", ctypes.c_int32),
298
+ ("last_hidden_layer", RKLLMResultLastHiddenLayer),
299
+ ("logits", RKLLMResultLogits)
300
+ ]
301
+
302
+ # --- Typedefs ---
303
+ LLMHandle = ctypes.c_void_p
304
+
305
+ # --- Callback Function Type ---
306
+ LLMResultCallback = ctypes.CFUNCTYPE(
307
+ None, # return type: void
308
+ ctypes.POINTER(RKLLMResult),
309
+ ctypes.c_void_p, # userdata
310
+ ctypes.c_int # enum, will be passed as int. Changed LLMCallState to ctypes.c_int
311
+ )
312
+
313
+
314
+ class RKLLMRuntime:
315
+ def __init__(self, library_path="./librkllmrt.so"):
316
+ try:
317
+ self.lib = ctypes.CDLL(library_path)
318
+ except OSError as e:
319
+ raise OSError(f"Failed to load RKLLM library from {library_path}. "
320
+ f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
321
+ self._setup_functions()
322
+ self.llm_handle = LLMHandle()
323
+ self._c_callback = None # To keep the callback object alive
324
+
325
+ def _setup_functions(self):
326
+ # RKLLMParam rkllm_createDefaultParam();
327
+ self.lib.rkllm_createDefaultParam.restype = RKLLMParam
328
+ self.lib.rkllm_createDefaultParam.argtypes = []
329
+
330
+ # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
331
+ self.lib.rkllm_init.restype = ctypes.c_int
332
+ self.lib.rkllm_init.argtypes = [
333
+ ctypes.POINTER(LLMHandle),
334
+ ctypes.POINTER(RKLLMParam),
335
+ LLMResultCallback
336
+ ]
337
+
338
+ # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
339
+ self.lib.rkllm_load_lora.restype = ctypes.c_int
340
+ self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
341
+
342
+ # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
343
+ self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
344
+ self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
345
+
346
+ # int rkllm_release_prompt_cache(LLMHandle handle);
347
+ self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
348
+ self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]
349
+
350
+ # int rkllm_destroy(LLMHandle handle);
351
+ self.lib.rkllm_destroy.restype = ctypes.c_int
352
+ self.lib.rkllm_destroy.argtypes = [LLMHandle]
353
+
354
+ # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
355
+ self.lib.rkllm_run.restype = ctypes.c_int
356
+ self.lib.rkllm_run.argtypes = [
357
+ LLMHandle,
358
+ ctypes.POINTER(RKLLMInput),
359
+ ctypes.POINTER(RKLLMInferParam),
360
+ ctypes.c_void_p # userdata
361
+ ]
362
+
363
+ # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
364
+ # Assuming async also takes userdata for the callback context
365
+ self.lib.rkllm_run_async.restype = ctypes.c_int
366
+ self.lib.rkllm_run_async.argtypes = [
367
+ LLMHandle,
368
+ ctypes.POINTER(RKLLMInput),
369
+ ctypes.POINTER(RKLLMInferParam),
370
+ ctypes.c_void_p # userdata
371
+ ]
372
+
373
+ # int rkllm_abort(LLMHandle handle);
374
+ self.lib.rkllm_abort.restype = ctypes.c_int
375
+ self.lib.rkllm_abort.argtypes = [LLMHandle]
376
+
377
+ # int rkllm_is_running(LLMHandle handle);
378
+ self.lib.rkllm_is_running.restype = ctypes.c_int # 0 if running, non-zero otherwise
379
+ self.lib.rkllm_is_running.argtypes = [LLMHandle]
380
+
381
+ # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt);
382
+ self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
383
+ self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int]
384
+
385
+ # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
386
+ self.lib.rkllm_set_chat_template.restype = ctypes.c_int
387
+ self.lib.rkllm_set_chat_template.argtypes = [
388
+ LLMHandle,
389
+ ctypes.c_char_p,
390
+ ctypes.c_char_p,
391
+ ctypes.c_char_p
392
+ ]
393
+
394
+ def create_default_param(self) -> RKLLMParam:
395
+ """Creates a default RKLLMParam structure."""
396
+ return self.lib.rkllm_createDefaultParam()
397
+
398
+ def init(self, param: RKLLMParam, callback_func) -> int:
399
+ """
400
+ Initializes the LLM.
401
+ :param param: RKLLMParam structure.
402
+ :param callback_func: A Python function that matches the signature:
403
+ def my_callback(result_ptr, userdata_ptr, state_enum):
404
+ result = result_ptr.contents # RKLLMResult
405
+ # Process result
406
+ # userdata can be retrieved if passed during run, or ignored
407
+ # state = LLMCallState(state_enum)
408
+ :return: 0 for success, non-zero for failure.
409
+ """
410
+ if not callable(callback_func):
411
+ raise ValueError("callback_func must be a callable Python function.")
412
+
413
+ # Keep a reference to the ctypes callback object to prevent it from being garbage collected
414
+ self._c_callback = LLMResultCallback(callback_func)
415
+
416
+ ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
417
+ if ret != 0:
418
+ raise RuntimeError(f"rkllm_init failed with error code {ret}")
419
+ return ret
420
+
421
+ def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
422
+ """Loads a Lora adapter."""
423
+ ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
424
+ if ret != 0:
425
+ raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
426
+ return ret
427
+
428
+ def load_prompt_cache(self, prompt_cache_path: str) -> int:
429
+ """Loads a prompt cache from a file."""
430
+ c_path = prompt_cache_path.encode('utf-8')
431
+ ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
432
+ if ret != 0:
433
+ raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
434
+ return ret
435
+
436
+ def release_prompt_cache(self) -> int:
437
+ """Releases the prompt cache from memory."""
438
+ ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
439
+ if ret != 0:
440
+ raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
441
+ return ret
442
+
443
+ def destroy(self) -> int:
444
+ """Destroys the LLM instance and releases resources."""
445
+ if self.llm_handle and self.llm_handle.value: # Check if handle is not NULL
446
+ ret = self.lib.rkllm_destroy(self.llm_handle)
447
+ self.llm_handle = LLMHandle() # Reset handle
448
+ if ret != 0:
449
+ # Don't raise here as it might be called in __del__
450
+ print(f"Warning: rkllm_destroy failed with error code {ret}")
451
+ return ret
452
+ return 0 # Already destroyed or not initialized
453
+
454
+ def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
455
+ """Runs an LLM inference task synchronously."""
456
+ # userdata can be a ctypes.py_object if you want to pass Python objects,
457
+ # then cast to c_void_p. Or simply None.
458
+ if userdata is not None:
459
+ # Store the userdata object to keep it alive during the call
460
+ self._userdata_ref = userdata
461
+ c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
462
+ else:
463
+ c_userdata = None
464
+ ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
465
+ if ret != 0:
466
+ raise RuntimeError(f"rkllm_run failed with error code {ret}")
467
+ return ret
468
+
469
+ def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
470
+ """Runs an LLM inference task asynchronously."""
471
+ if userdata is not None:
472
+ # Store the userdata object to keep it alive during the call
473
+ self._userdata_ref = userdata
474
+ c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
475
+ else:
476
+ c_userdata = None
477
+ ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
478
+ if ret != 0:
479
+ raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
480
+ return ret
481
+
482
+ def abort(self) -> int:
483
+ """Aborts an ongoing LLM task."""
484
+ ret = self.lib.rkllm_abort(self.llm_handle)
485
+ if ret != 0:
486
+ raise RuntimeError(f"rkllm_abort failed with error code {ret}")
487
+ return ret
488
+
489
+ def is_running(self) -> bool:
490
+ """Checks if an LLM task is currently running. Returns True if running."""
491
+ # The C API returns 0 if running, non-zero otherwise.
492
+ # This is a bit counter-intuitive for a boolean "is_running".
493
+ return self.lib.rkllm_is_running(self.llm_handle) == 0
494
+
495
+ def clear_kv_cache(self, keep_system_prompt: bool) -> int:
496
+ """Clears the key-value cache."""
497
+ ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0))
498
+ if ret != 0:
499
+ raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
500
+ return ret
501
+
502
+ def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
503
+ """Sets the chat template for the LLM."""
504
+ c_system = system_prompt.encode('utf-8') if system_prompt else b""
505
+ c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
506
+ c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""
507
+
508
+ ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
509
+ if ret != 0:
510
+ raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
511
+ return ret
512
+
513
+ def __enter__(self):
514
+ return self
515
+
516
+ def __exit__(self, exc_type, exc_val, exc_tb):
517
+ self.destroy()
518
+
519
+ def __del__(self):
520
+ self.destroy() # Ensure resources are freed if object is garbage collected
521
+
522
+ # --- Example Usage (Illustrative) ---
523
+ if __name__ == "__main__":
524
+ # This is a placeholder for how you might use it.
525
+ # You'll need a valid .rkllm model and librkllmrt.so in your path.
526
+
527
+ # Global list to store results from callback for demonstration
528
+ results_buffer = []
529
+
530
+ def my_python_callback(result_ptr, userdata_ptr, state_enum):
531
+ """
532
+ Callback function to be called by the C library.
533
+ """
534
+ global results_buffer
535
+ state = LLMCallState(state_enum)
536
+ result = result_ptr.contents
537
+
538
+ current_text = ""
539
+ if result.text: # Check if the char_p is not NULL
540
+ current_text = result.text.decode('utf-8', errors='ignore')
541
+
542
+ print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
543
+ results_buffer.append(current_text)
544
+
545
+ if state == LLMCallState.RKLLM_RUN_FINISH:
546
+ print("Inference finished.")
547
+ elif state == LLMCallState.RKLLM_RUN_ERROR:
548
+ print("Inference error.")
549
+
550
+ # Example: Accessing logits if available (and if mode was set to get logits)
551
+ # if result.logits.logits and result.logits.vocab_size > 0:
552
+ # print(f" Logits (first 5 of vocab_size {result.logits.vocab_size}):")
553
+ # for i in range(min(5, result.logits.vocab_size)):
554
+ # print(f" {result.logits.logits[i]:.4f}", end=" ")
555
+ # print()
556
+
557
+
558
+ # --- Attempt to use the wrapper ---
559
+ try:
560
+ print("Initializing RKLLMRuntime...")
561
+ # Adjust library_path if librkllmrt.so is not in default search paths
562
+ # e.g., library_path="./path/to/librkllmrt.so"
563
+ rk_llm = RKLLMRuntime()
564
+
565
+ print("Creating default parameters...")
566
+ params = rk_llm.create_default_param()
567
+
568
+ # --- Configure parameters ---
569
+ # THIS IS CRITICAL: model_path must point to an actual .rkllm file
570
+ # For this example to run, you need a model file.
571
+ # Let's assume a dummy path for now, this will fail at init if not valid.
572
+ model_file = "dummy_model.rkllm"
573
+ if not os.path.exists(model_file):
574
+ print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
575
+ # Create a dummy file for the example to proceed further, though init will still fail
576
+ # with a real library unless it's a valid model.
577
+ with open(model_file, "w") as f:
578
+ f.write("dummy content")
579
+
580
+ params.model_path = model_file.encode('utf-8')
581
+ params.max_context_len = 512
582
+ params.max_new_tokens = 128
583
+ params.top_k = 1 # Greedy
584
+ params.temperature = 0.7
585
+ params.repeat_penalty = 1.1
586
+ # ... set other params as needed
587
+
588
+ print(f"Initializing LLM with model: {params.model_path.decode()}...")
589
+ # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
590
+ try:
591
+ rk_llm.init(params, my_python_callback)
592
+ print("LLM Initialized.")
593
+ except RuntimeError as e:
594
+ print(f"Error during LLM initialization: {e}")
595
+ print("This is expected if 'dummy_model.rkllm' is not a valid model.")
596
+ print("Replace 'dummy_model.rkllm' with a real model path to test further.")
597
+ exit()
598
+
599
+
600
+ # --- Prepare input ---
601
+ print("Preparing input...")
602
+ rk_input = RKLLMInput()
603
+ rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
604
+
605
+ prompt_text = "Translate the following English text to French: 'Hello, world!'"
606
+ c_prompt = prompt_text.encode('utf-8')
607
+ rk_input._union_data.prompt_input = c_prompt # Accessing union member directly
608
+
609
+ # --- Prepare inference parameters ---
610
+ print("Preparing inference parameters...")
611
+ infer_params = RKLLMInferParam()
612
+ infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
613
+ infer_params.keep_history = 1 # True
614
+ # infer_params.lora_params = None # or set up RKLLMLoraParam if using LoRA
615
+ # infer_params.prompt_cache_params = None # or set up RKLLMPromptCacheParam
616
+
617
+ # --- Run inference ---
618
+ print(f"Running inference with prompt: '{prompt_text}'")
619
+ results_buffer.clear()
620
+ try:
621
+ rk_llm.run(rk_input, infer_params) # Userdata is None by default
622
+ print("\n--- Full Response ---")
623
+ print("".join(results_buffer))
624
+ print("---------------------\n")
625
+ except RuntimeError as e:
626
+ print(f"Error during LLM run: {e}")
627
+
628
+
629
+ # --- Example: Set chat template (if model supports it) ---
630
+ # print("Setting chat template...")
631
+ # try:
632
+ # rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
633
+ # print("Chat template set.")
634
+ # except RuntimeError as e:
635
+ # print(f"Error setting chat template: {e}")
636
+
637
+ # --- Example: Clear KV Cache ---
638
+ # print("Clearing KV cache (keeping system prompt if any)...")
639
+ # try:
640
+ # rk_llm.clear_kv_cache(keep_system_prompt=True)
641
+ # print("KV cache cleared.")
642
+ # except RuntimeError as e:
643
+ # print(f"Error clearing KV cache: {e}")
644
+
645
+ except OSError as e:
646
+ print(f"OSError: {e}. Could not load the RKLLM library.")
647
+ print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
648
+ except Exception as e:
649
+ print(f"An unexpected error occurred: {e}")
650
+ finally:
651
+ if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
652
+ print("Destroying LLM instance...")
653
+ rk_llm.destroy()
654
+ print("LLM instance destroyed.")
655
+ if os.path.exists(model_file) and model_file == "dummy_model.rkllm":
656
+ os.remove(model_file) # Clean up dummy file
657
+
658
+ print("Example finished.")
test_embedding.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Qwen3-Embedding-0.6B 推理测试代码
5
+ 使用 RKLLM API 进行文本嵌入推理
6
+ """
7
+ import faulthandler
8
+ faulthandler.enable()
9
+ import os
10
+ os.environ["RKLLM_LOG_LEVEL"] = "1"
11
+ import numpy as np
12
+ import time
13
+ from typing import List, Dict, Any
14
+ from rkllm_binding import *
15
+
16
+
17
+ class Qwen3EmbeddingTester:
18
+ def __init__(self, model_path: str, library_path: str = "./librkllmrt.so"):
19
+ """
20
+ 初始化 Qwen3 嵌入模型测试器
21
+
22
+ Args:
23
+ model_path: 模型文件路径(.rkllm 格式)
24
+ library_path: RKLLM 库文件路径
25
+ """
26
+ self.model_path = model_path
27
+ self.library_path = library_path
28
+ self.runtime = None
29
+ self.embeddings_buffer = []
30
+ self.current_result = None
31
+
32
+ def callback_function(self, result_ptr, userdata_ptr, state_enum):
33
+ """
34
+ 推理回调函数
35
+
36
+ Args:
37
+ result_ptr: 结果指针
38
+ userdata_ptr: 用户数据指针
39
+ state_enum: 状态枚举
40
+ """
41
+ state = LLMCallState(state_enum)
42
+
43
+ if state == LLMCallState.RKLLM_RUN_NORMAL:
44
+ result = result_ptr.contents
45
+ print(f"result: {result}")
46
+ # 获取最后隐藏层输出作为嵌入
47
+ if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:
48
+ embd_size = result.last_hidden_layer.embd_size
49
+ num_tokens = result.last_hidden_layer.num_tokens
50
+
51
+ print(f"获取到嵌入向量:维度={embd_size}, 令牌数={num_tokens}")
52
+
53
+ # 将 C 数组转换为 numpy 数组
54
+ # 这里我们取最后一个 token 的隐藏状态作为句子嵌入
55
+ if num_tokens > 0:
56
+ # 获取最后一个 token 的嵌入(shape: [embd_size])
57
+ last_token_embedding = np.array([
58
+ result.last_hidden_layer.hidden_states[(num_tokens-1) * embd_size + i]
59
+ for i in range(embd_size)
60
+ ])
61
+
62
+ self.current_result = {
63
+ 'embedding': last_token_embedding,
64
+ 'embd_size': embd_size,
65
+ 'num_tokens': num_tokens
66
+ }
67
+
68
+ print(f"嵌入向量范数: {np.linalg.norm(last_token_embedding):.4f}")
69
+ print(f"嵌入向量前10维: {last_token_embedding[:10]}")
70
+
71
+ elif state == LLMCallState.RKLLM_RUN_ERROR:
72
+ print("推理过程发生错误")
73
+
74
+ def init_model(self):
75
+ """初始化模型"""
76
+ try:
77
+ print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
78
+ self.runtime = RKLLMRuntime(self.library_path)
79
+
80
+ print("创建默认参数...")
81
+ params = self.runtime.create_default_param()
82
+
83
+ # 配置参数
84
+ params.model_path = self.model_path.encode('utf-8')
85
+ params.max_context_len = 1024 # 设置上下文长度
86
+ params.max_new_tokens = 1 # 嵌入任务不需要生成新token
87
+ params.temperature = 1.0 # 嵌入任务温度设置
88
+ params.top_k = 1 # 嵌入任务不需要采样
89
+ params.top_p = 1.0 # 嵌入任务不需要采样
90
+
91
+ # 扩展参数配置
92
+ params.extend_param.base_domain_id = 1 # 建议为 >1B 模型设置为1
93
+ params.extend_param.embed_flash = 0 # 是否使用flash存储Embedding
94
+ params.extend_param.enabled_cpus_num = 4 # 启用的CPU核心数
95
+ params.extend_param.enabled_cpus_mask = 0x0F # CPU核心掩码
96
+
97
+ print(f"初始化模型: {self.model_path}")
98
+ self.runtime.init(params, self.callback_function)
99
+ self.runtime.set_chat_template("","","")
100
+ print("模型初始化成功!")
101
+
102
+ except Exception as e:
103
+ print(f"模型初始化失败: {e}")
104
+ raise
105
+
106
+ def get_detailed_instruct(self, task_description: str, query: str) -> str:
107
+ """
108
+ 构建指令提示词(参考 README 中的用法)
109
+
110
+ Args:
111
+ task_description: 任务描述
112
+ query: 查询文本
113
+
114
+ Returns:
115
+ 格式化的指令提示词
116
+ """
117
+ return f'Instruct: {task_description}\nQuery: {query}'
118
+
119
+ def encode_text(self, text: str, task_description: str = None) -> np.ndarray:
120
+ """
121
+ 编码文本为嵌入向量
122
+
123
+ Args:
124
+ text: 要编码的文本
125
+ task_description: 任务描述,如果提供则使用指令提示
126
+
127
+ Returns:
128
+ 嵌入向量(numpy数组)
129
+ """
130
+ try:
131
+ # 如果提供了任务描述,则使用指令提示
132
+ if task_description:
133
+ input_text = self.get_detailed_instruct(task_description, text)
134
+ else:
135
+ input_text = text
136
+
137
+ print(f"编码文本: {input_text[:100]}{'...' if len(input_text) > 100 else ''}")
138
+
139
+ # 准备输入
140
+ rk_input = RKLLMInput()
141
+ rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
142
+ c_prompt = input_text.encode('utf-8')
143
+ rk_input._union_data.prompt_input = c_prompt
144
+
145
+ # 准备推理参数
146
+ infer_params = RKLLMInferParam()
147
+ infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER # 获取隐藏层输出
148
+ infer_params.keep_history = 0 # 不保留历史
149
+
150
+ # 清空之前的结果
151
+ self.current_result = None
152
+ self.runtime.clear_kv_cache(False)
153
+
154
+ # 执行推理
155
+ start_time = time.time()
156
+ self.runtime.run(rk_input, infer_params)
157
+ end_time = time.time()
158
+
159
+ print(f"推理耗时: {end_time - start_time:.3f}秒")
160
+
161
+ if self.current_result and 'embedding' in self.current_result:
162
+ # 对嵌入向量进行L2标准化
163
+ embedding = self.current_result['embedding']
164
+ normalized_embedding = embedding / np.linalg.norm(embedding)
165
+ return normalized_embedding
166
+ else:
167
+ raise RuntimeError("未能获取到有效的嵌入向量")
168
+
169
+ except Exception as e:
170
+ print(f"编码文本时发生错误: {e}")
171
+ raise
172
+
173
+ def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
174
+ """
175
+ 计算两个嵌入向量的余弦相似度
176
+
177
+ Args:
178
+ emb1: 第一个嵌入向量
179
+ emb2: 第二个嵌入向量
180
+
181
+ Returns:
182
+ 余弦相似度值
183
+ """
184
+ return np.dot(emb1, emb2)
185
+
186
+ def test_embedding_similarity(self):
187
+ """测试嵌入相似度计算"""
188
+ print("\n" + "="*50)
189
+ print("测试嵌入相似度计算")
190
+ print("="*50)
191
+
192
+ # 测试文本
193
+ task_description = "Given a web search query, retrieve relevant passages that answer the query"
194
+
195
+ queries = [
196
+ "What is the capital of China?",
197
+ "Explain gravity"
198
+ ]
199
+
200
+ documents = [
201
+ "The capital of China is Beijing.",
202
+ "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
203
+ ]
204
+
205
+ # 编码查询(使用指令)
206
+ print("\n编码查询文本:")
207
+ query_embeddings = []
208
+ for i, query in enumerate(queries):
209
+ print(f"\n查询 {i+1}: {query}")
210
+ emb = self.encode_text(query, task_description)
211
+ query_embeddings.append(emb)
212
+
213
+ # 编码文档(不使用指令)
214
+ print("\n编码文档文本:")
215
+ doc_embeddings = []
216
+ for i, doc in enumerate(documents):
217
+ print(f"\n文档 {i+1}: {doc}")
218
+ emb = self.encode_text(doc)
219
+ doc_embeddings.append(emb)
220
+
221
+ # 计算相似度矩阵
222
+ print("\n计算相似度矩阵:")
223
+ print("查询 vs 文档相似度:")
224
+ print("-" * 30)
225
+
226
+ similarities = []
227
+ for i, q_emb in enumerate(query_embeddings):
228
+ row_similarities = []
229
+ for j, d_emb in enumerate(doc_embeddings):
230
+ sim = self.compute_similarity(q_emb, d_emb)
231
+ row_similarities.append(sim)
232
+ print(f"查询{i+1} vs 文档{j+1}: {sim:.4f}")
233
+ similarities.append(row_similarities)
234
+ print()
235
+
236
+ return similarities
237
+
238
+ def test_multilingual_embedding(self):
239
+ """测试多语言嵌入能力"""
240
+ print("\n" + "="*50)
241
+ print("测试多语言嵌入能力")
242
+ print("="*50)
243
+
244
+ # 多语言测试文本(相同含义的不同语言)
245
+ texts = {
246
+ "英语": "Hello, how are you?",
247
+ "中文": "你好,你好吗?",
248
+ "法语": "Bonjour, comment allez-vous?",
249
+ "西班牙语": "Hola, ¿cómo estás?",
250
+ "日语": "こんにちは、元気ですか?"
251
+ }
252
+
253
+ embeddings = {}
254
+ print("\n编码多语言文本:")
255
+ for lang, text in texts.items():
256
+ print(f"\n{lang}: {text}")
257
+ emb = self.encode_text(text)
258
+ embeddings[lang] = emb
259
+
260
+ # 计算跨语言相似度
261
+ print("\n跨语言相似度:")
262
+ print("-" * 30)
263
+
264
+ languages = list(texts.keys())
265
+ for i, lang1 in enumerate(languages):
266
+ for j, lang2 in enumerate(languages):
267
+ if i <= j:
268
+ sim = self.compute_similarity(embeddings[lang1], embeddings[lang2])
269
+ print(f"{lang1} vs {lang2}: {sim:.4f}")
270
+
271
+ def test_code_embedding(self):
272
+ """测试代码嵌入能力"""
273
+ print("\n" + "="*50)
274
+ print("测试代码嵌入能力")
275
+ print("="*50)
276
+
277
+ # 代码示例
278
+ codes = {
279
+ "Python函数": """
280
+ def fibonacci(n):
281
+ if n <= 1:
282
+ return n
283
+ return fibonacci(n-1) + fibonacci(n-2)
284
+ """,
285
+ "JavaScript函数": """
286
+ function fibonacci(n) {
287
+ if (n <= 1) return n;
288
+ return fibonacci(n-1) + fibonacci(n-2);
289
+ }
290
+ """,
291
+ "C++函数": """
292
+ int fibonacci(int n) {
293
+ if (n <= 1) return n;
294
+ return fibonacci(n-1) + fibonacci(n-2);
295
+ }
296
+ """,
297
+ "数组排序": """
298
+ def bubble_sort(arr):
299
+ n = len(arr)
300
+ for i in range(n):
301
+ for j in range(0, n-i-1):
302
+ if arr[j] > arr[j+1]:
303
+ arr[j], arr[j+1] = arr[j+1], arr[j]
304
+ """
305
+ }
306
+
307
+ embeddings = {}
308
+ print("\n编码代码文本:")
309
+ for name, code in codes.items():
310
+ print(f"\n{name}:")
311
+ print(code[:100] + "..." if len(code) > 100 else code)
312
+ emb = self.encode_text(code)
313
+ embeddings[name] = emb
314
+
315
+ # 计算代码相似度
316
+ print("\n代码相似度:")
317
+ print("-" * 30)
318
+
319
+ code_names = list(codes.keys())
320
+ for i, name1 in enumerate(code_names):
321
+ for j, name2 in enumerate(code_names):
322
+ if i <= j:
323
+ sim = self.compute_similarity(embeddings[name1], embeddings[name2])
324
+ print(f"{name1} vs {name2}: {sim:.4f}")
325
+
326
+ def cleanup(self):
327
+ """清理资源"""
328
+ if self.runtime:
329
+ try:
330
+ self.runtime.destroy()
331
+ print("模型资源已清理")
332
+ except Exception as e:
333
+ print(f"清理资源时发生错误: {e}")
334
+
335
+ def main():
336
+ """主函数"""
337
+ import argparse
338
+
339
+ # 解析命令行参数
340
+ parser = argparse.ArgumentParser(description='Qwen3-Embedding-0.6B 推理测试')
341
+ parser.add_argument('model_path', help='模型文件路径(.rkllm格式)')
342
+ parser.add_argument('--library_path', default="./librkllmrt.so", help='RKLLM库文件路径(默认为./librkllmrt.so)')
343
+ args = parser.parse_args()
344
+
345
+ # 检查文件是否存在
346
+ if not os.path.exists(args.model_path):
347
+ print(f"错误: 模型文件不存在: {args.model_path}")
348
+ print("请确保:")
349
+ print("1. 已下载 Qwen3-Embedding-0.6B 模型")
350
+ print("2. 已使用 rkllm-convert.py 将模型转换为 .rkllm 格式")
351
+ return
352
+
353
+ if not os.path.exists(args.library_path):
354
+ print(f"错误: RKLLM 库文件不存在: {args.library_path}")
355
+ print("请确保 librkllmrt.so 在当前目录或 LD_LIBRARY_PATH 中")
356
+ return
357
+
358
+ print("Qwen3-Embedding-0.6B 推理测试")
359
+ print("=" * 50)
360
+
361
+ # 创建测试器
362
+ tester = Qwen3EmbeddingTester(args.model_path, args.library_path)
363
+
364
+ try:
365
+ # 初始化模型
366
+ tester.init_model()
367
+
368
+ # 运行测试
369
+ print("\n开始运行嵌入测试...")
370
+
371
+ # 测试基础嵌入相似度
372
+ tester.test_embedding_similarity()
373
+
374
+ # 测试多语言嵌入
375
+ tester.test_multilingual_embedding()
376
+
377
+ # 测试代码嵌入
378
+ tester.test_code_embedding()
379
+
380
+ print("\n" + "="*50)
381
+ print("所有测试完成!")
382
+ print("="*50)
383
+
384
+ except KeyboardInterrupt:
385
+ print("\n测试被用户中断")
386
+ except Exception as e:
387
+ print(f"\n测试过程中发生错误: {e}")
388
+ import traceback
389
+ traceback.print_exc()
390
+ finally:
391
+ # 清理资源
392
+ tester.cleanup()
393
+
394
+
395
+ if __name__ == "__main__":
396
+ main()