happyme531 committed
Commit 549ab86 · verified · 1 Parent(s): a49d7c1

Upload inference code

.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 language_model_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
 vision_encoder.rknn filter=lfs diff=lfs merge=lfs -text
+librkllmrt.so filter=lfs diff=lfs merge=lfs -text
export_vision_onnx.py ADDED
@@ -0,0 +1,89 @@
+import numpy as np
+import os
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer
+import torch.nn.functional as F
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--step', type=int, help='export step', required=True)
+parser.add_argument('--path', type=str, default='.', help='model path', required=False)
+parser.add_argument('--batch', type=int, default=1, help='batch size', required=False)
+parser.add_argument('--height', type=int, default=448, help='image height', required=False)
+parser.add_argument('--width', type=int, default=448, help='image width', required=False)
+parser.add_argument('--savepath', type=str, default='vision_encoder.onnx', help='save path', required=False)
+args = parser.parse_args()
+
+step = args.step
+# Load the local model
+path = args.path
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    path,
+    torch_dtype=torch.float32,  # Note the dtype here: RKNN currently only supports float32, so it must be specified; if you constrained the dtype when loading the weights, set the "use_flash_attn" option in config.json to false yourself
+    low_cpu_mem_usage=True,
+    trust_remote_code=True).eval()
+tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
+
+N = args.batch  # batch size
+channel = 3  # 3 for RGB
+H = args.height  # image height, must be divisible by (merge_size * patch_size)
+W = args.width  # image width, must be divisible by (merge_size * patch_size)
+merge_size = 2
+temporal_patch_size = 2
+patch_size = 14
+grid_t = N // temporal_patch_size if N % temporal_patch_size == 0 else N // temporal_patch_size + 1
+grid_h = H // patch_size
+grid_w = W // patch_size
+
+def export_onnx(image):
+    if N == 1:
+        images = image.repeat(temporal_patch_size, 1, 1, 1)
+    elif N % temporal_patch_size != 0:
+        repeat_time = temporal_patch_size - N % temporal_patch_size
+        repeat_image = image[-1:, ...].repeat(repeat_time, 1, 1, 1)
+        images = torch.cat((image, repeat_image), dim=0)
+    else:
+        images = image  # batch already a multiple of temporal_patch_size
+    patches = images.reshape(grid_t, temporal_patch_size, channel, grid_h//merge_size, merge_size, patch_size, grid_w//merge_size, merge_size, patch_size)
+    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
+    flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size)
+    model.visual.forward = forward_new(model.visual)
+    if step == 1:
+        feature = model.visual(flatten_patches, torch.tensor([grid_t, grid_h, grid_w]).unsqueeze(0))
+    else:
+        feature = model.visual(flatten_patches)
+    return feature
+
+def forward_new(self):
+    def tmp(hidden_states, grid_thw=None):
+        hidden_states = self.patch_embed(hidden_states)
+        if grid_thw is not None:
+            rotary_pos_emb = self.rot_pos_emb(grid_thw)
+            cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+                dim=0, dtype=torch.int32
+            )
+            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+            np.save("./rotary_pos_emb.npy", rotary_pos_emb.cpu().detach().numpy())
+            np.save("./cu_seqlens.npy", cu_seqlens.cpu().detach().numpy())
+        else:
+            rotary_pos_emb = torch.from_numpy(np.load("./rotary_pos_emb.npy")).to(dtype=hidden_states.dtype, device=hidden_states.device)
+            cu_seqlens = torch.from_numpy(np.load("./cu_seqlens.npy")).to(dtype=torch.int32, device=hidden_states.device)
+
+        for blk in self.blocks:
+            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+        return self.merger(hidden_states)
+    return tmp
+
+# Export the ONNX model for the vision part; e.g. an input of 2x3x392x392 becomes (28x28)x(3x2x14x14)
+# pixel_values = torch.randn(784, 1176, device="cuda", dtype=torch.float32)
+pixel_values = torch.randn(N, channel, H, W, device="cpu", dtype=torch.float32)
+model.forward = export_onnx
+model = model.to(torch.float32).eval()
+if step == 1:
+    print("==========================================================")
+    print("Generating the rotary_pos_emb and cu_seqlens...")
+    feature = model(pixel_values)
+else:
+    print("==========================================================")
+    print(f"Exporting the vision part of {path} to onnx format.")
+    # os.makedirs(os.path.dirname(args.savepath), exist_ok=True)
+    torch.onnx.export(model, pixel_values, args.savepath, opset_version=17)
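
The export is split into two steps because rotary_pos_emb and cu_seqlens are derived from grid_thw with data-dependent ops (repeat_interleave, cumsum): step 1 runs the vision tower once and dumps them to rotary_pos_emb.npy / cu_seqlens.npy, and step 2 reloads them as fixed tensors so the traced ONNX graph stays static. A typical invocation for the default 448x448 input (paths illustrative, flags as defined above):

    python export_vision_onnx.py --step 1 --path . --height 448 --width 448
    python export_vision_onnx.py --step 2 --path . --height 448 --width 448 --savepath vision_encoder.onnx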
librkllmrt.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7e6f87f07bbb08058cad4871cc74e8069a054fe4f6259b43c29a4738b0affdd
+size 7461896
rkllm-convert.py ADDED
@@ -0,0 +1,23 @@
+from rkllm.api import RKLLM
+
+modelpath = '.'
+llm = RKLLM()
+
+ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
+if ret != 0:
+    print('Load model failed!')
+    exit(ret)
+
+qparams = None
+ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
+                quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)
+
+if ret != 0:
+    print('Build model failed!')
+    exit(ret)
+
+# Export rkllm model
+ret = llm.export_rkllm("./language_model.rkllm")
+if ret != 0:
+    print('Export model failed!')
+    exit(ret)
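
As written, the converter loads the Hugging Face checkpoint from the current directory (modelpath = '.'), builds a w8a8-quantized model for the 3-core rk3588 NPU, and writes ./language_model.rkllm; with the rkllm-toolkit installed it is run from the repository root:

    python rkllm-convert.py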
rkllm_binding.py ADDED
@@ -0,0 +1,873 @@
+import ctypes
+import enum
+import os
+
+# Define constants from the header
+CPU0 = (1 << 0)  # 0x01
+CPU1 = (1 << 1)  # 0x02
+CPU2 = (1 << 2)  # 0x04
+CPU3 = (1 << 3)  # 0x08
+CPU4 = (1 << 4)  # 0x10
+CPU5 = (1 << 5)  # 0x20
+CPU6 = (1 << 6)  # 0x40
+CPU7 = (1 << 7)  # 0x80
+
+# --- Enums ---
+class LLMCallState(enum.IntEnum):
+    RKLLM_RUN_NORMAL = 0
+    RKLLM_RUN_WAITING = 1
+    RKLLM_RUN_FINISH = 2
+    RKLLM_RUN_ERROR = 3
+
+class RKLLMInputType(enum.IntEnum):
+    RKLLM_INPUT_PROMPT = 0
+    RKLLM_INPUT_TOKEN = 1
+    RKLLM_INPUT_EMBED = 2
+    RKLLM_INPUT_MULTIMODAL = 3
+
+class RKLLMInferMode(enum.IntEnum):
+    RKLLM_INFER_GENERATE = 0
+    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
+    RKLLM_INFER_GET_LOGITS = 2
+
+# --- Structures ---
+class RKLLMExtendParam(ctypes.Structure):
+    base_domain_id: ctypes.c_int32
+    embed_flash: ctypes.c_int8
+    enabled_cpus_num: ctypes.c_int8
+    enabled_cpus_mask: ctypes.c_uint32
+    n_batch: ctypes.c_uint8
+    use_cross_attn: ctypes.c_int8
+    reserved: ctypes.c_uint8 * 104
+
+    _fields_ = [
+        ("base_domain_id", ctypes.c_int32),      # base domain ID
+        ("embed_flash", ctypes.c_int8),          # whether to query word embeddings from flash (1 enable, 0 disable)
+        ("enabled_cpus_num", ctypes.c_int8),     # number of CPUs enabled for inference
+        ("enabled_cpus_mask", ctypes.c_uint32),  # bitmask indicating which CPUs are enabled
+        ("n_batch", ctypes.c_uint8),             # number of input samples processed concurrently in one forward pass; set > 1 to enable batched inference, defaults to 1
+        ("use_cross_attn", ctypes.c_int8),       # whether to enable cross attention (non-zero enable, 0 disable)
+        ("reserved", ctypes.c_uint8 * 104)       # reserved field
+    ]
+
+class RKLLMParam(ctypes.Structure):
+    model_path: ctypes.c_char_p
+    max_context_len: ctypes.c_int32
+    max_new_tokens: ctypes.c_int32
+    top_k: ctypes.c_int32
+    n_keep: ctypes.c_int32
+    top_p: ctypes.c_float
+    temperature: ctypes.c_float
+    repeat_penalty: ctypes.c_float
+    frequency_penalty: ctypes.c_float
+    presence_penalty: ctypes.c_float
+    mirostat: ctypes.c_int32
+    mirostat_tau: ctypes.c_float
+    mirostat_eta: ctypes.c_float
+    skip_special_token: ctypes.c_bool
+    is_async: ctypes.c_bool
+    img_start: ctypes.c_char_p
+    img_end: ctypes.c_char_p
+    img_content: ctypes.c_char_p
+    extend_param: RKLLMExtendParam
+
+    _fields_ = [
+        ("model_path", ctypes.c_char_p),        # path to the model file
+        ("max_context_len", ctypes.c_int32),    # maximum number of tokens in the context window
+        ("max_new_tokens", ctypes.c_int32),     # maximum number of new tokens to generate
+        ("top_k", ctypes.c_int32),              # top-K sampling parameter
+        ("n_keep", ctypes.c_int32),             # number of KV-cache entries to keep when the context window shifts
+        ("top_p", ctypes.c_float),              # top-P (nucleus) sampling parameter
+        ("temperature", ctypes.c_float),        # sampling temperature, affects randomness of token selection
+        ("repeat_penalty", ctypes.c_float),     # penalty for repeated tokens
+        ("frequency_penalty", ctypes.c_float),  # penalty for frequent tokens
+        ("presence_penalty", ctypes.c_float),   # penalty for tokens already present in the input
+        ("mirostat", ctypes.c_int32),           # Mirostat sampling strategy flag (0 disables it)
+        ("mirostat_tau", ctypes.c_float),       # Mirostat sampling tau parameter
+        ("mirostat_eta", ctypes.c_float),       # Mirostat sampling eta parameter
+        ("skip_special_token", ctypes.c_bool),  # whether to skip special tokens
+        ("is_async", ctypes.c_bool),            # whether inference is asynchronous
+        ("img_start", ctypes.c_char_p),         # start marker of the image in multimodal input
+        ("img_end", ctypes.c_char_p),           # end marker of the image in multimodal input
+        ("img_content", ctypes.c_char_p),       # pointer to the image content
+        ("extend_param", RKLLMExtendParam)      # extended parameters
+    ]
+
+class RKLLMLoraAdapter(ctypes.Structure):
+    lora_adapter_path: ctypes.c_char_p
+    lora_adapter_name: ctypes.c_char_p
+    scale: ctypes.c_float
+
+    _fields_ = [
+        ("lora_adapter_path", ctypes.c_char_p),
+        ("lora_adapter_name", ctypes.c_char_p),
+        ("scale", ctypes.c_float)
+    ]
+
+class RKLLMEmbedInput(ctypes.Structure):
+    embed: ctypes.POINTER(ctypes.c_float)
+    n_tokens: ctypes.c_size_t
+
+    _fields_ = [
+        ("embed", ctypes.POINTER(ctypes.c_float)),
+        ("n_tokens", ctypes.c_size_t)
+    ]
+
+class RKLLMTokenInput(ctypes.Structure):
+    input_ids: ctypes.POINTER(ctypes.c_int32)
+    n_tokens: ctypes.c_size_t
+
+    _fields_ = [
+        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
+        ("n_tokens", ctypes.c_size_t)
+    ]
+
+class RKLLMMultiModelInput(ctypes.Structure):
+    prompt: ctypes.c_char_p
+    image_embed: ctypes.POINTER(ctypes.c_float)
+    n_image_tokens: ctypes.c_size_t
+    n_image: ctypes.c_size_t
+    image_width: ctypes.c_size_t
+    image_height: ctypes.c_size_t
+
+    _fields_ = [
+        ("prompt", ctypes.c_char_p),
+        ("image_embed", ctypes.POINTER(ctypes.c_float)),
+        ("n_image_tokens", ctypes.c_size_t),
+        ("n_image", ctypes.c_size_t),
+        ("image_width", ctypes.c_size_t),
+        ("image_height", ctypes.c_size_t)
+    ]
+
+class RKLLMCrossAttnParam(ctypes.Structure):
+    """
+    Cross-attention parameter structure.
+
+    Used when performing cross attention in the decoder.
+    It provides the encoder outputs (key/value caches), position indices and attention mask.
+
+    - encoder_k_cache must be stored in contiguous memory with layout:
+      [num_layers][num_tokens][num_kv_heads][head_dim]
+    - encoder_v_cache must be stored in contiguous memory with layout:
+      [num_layers][num_kv_heads][head_dim][num_tokens]
+    """
+    encoder_k_cache: ctypes.POINTER(ctypes.c_float)
+    encoder_v_cache: ctypes.POINTER(ctypes.c_float)
+    encoder_mask: ctypes.POINTER(ctypes.c_float)
+    encoder_pos: ctypes.POINTER(ctypes.c_int32)
+    num_tokens: ctypes.c_int
+
+    _fields_ = [
+        ("encoder_k_cache", ctypes.POINTER(ctypes.c_float)),  # pointer to the encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim)
+        ("encoder_v_cache", ctypes.POINTER(ctypes.c_float)),  # pointer to the encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens)
+        ("encoder_mask", ctypes.POINTER(ctypes.c_float)),     # pointer to the encoder attention mask (array of size num_tokens)
+        ("encoder_pos", ctypes.POINTER(ctypes.c_int32)),      # pointer to the encoder token positions (array of size num_tokens)
+        ("num_tokens", ctypes.c_int)                          # number of tokens in the encoder sequence
+    ]
+
+class RKLLMPerfStat(ctypes.Structure):
+    """
+    Performance statistics structure.
+
+    Holds performance statistics for the prefill and generation stages.
+    """
+    prefill_time_ms: ctypes.c_float
+    prefill_tokens: ctypes.c_int
+    generate_time_ms: ctypes.c_float
+    generate_tokens: ctypes.c_int
+    memory_usage_mb: ctypes.c_float
+
+    _fields_ = [
+        ("prefill_time_ms", ctypes.c_float),   # total prefill time (ms)
+        ("prefill_tokens", ctypes.c_int),      # number of tokens processed during prefill
+        ("generate_time_ms", ctypes.c_float),  # total generation time (ms)
+        ("generate_tokens", ctypes.c_int),     # number of tokens processed during generation
+        ("memory_usage_mb", ctypes.c_float)    # VmHWM resident memory usage during inference (MB)
+    ]
+
+class _RKLLMInputUnion(ctypes.Union):
+    prompt_input: ctypes.c_char_p
+    embed_input: RKLLMEmbedInput
+    token_input: RKLLMTokenInput
+    multimodal_input: RKLLMMultiModelInput
+
+    _fields_ = [
+        ("prompt_input", ctypes.c_char_p),
+        ("embed_input", RKLLMEmbedInput),
+        ("token_input", RKLLMTokenInput),
+        ("multimodal_input", RKLLMMultiModelInput)
+    ]
+
+class RKLLMInput(ctypes.Structure):
+    """
+    LLM input structure.
+
+    Represents the different kinds of LLM input through a union.
+    """
+    role: ctypes.c_char_p
+    enable_thinking: ctypes.c_bool
+    input_type: ctypes.c_int
+    _union_data: _RKLLMInputUnion
+
+    _fields_ = [
+        ("role", ctypes.c_char_p),           # message role: "user" (user input), "tool" (function results)
+        ("enable_thinking", ctypes.c_bool),  # controls whether "thinking mode" is enabled for Qwen3 models
+        ("input_type", ctypes.c_int),        # enum specifying the input type (prompt, token, embed, multimodal)
+        ("_union_data", _RKLLMInputUnion)    # union data
+    ]
+    # Properties to make accessing union members easier
+    @property
+    def prompt_input(self) -> bytes:  # Assuming c_char_p maps to bytes
+        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
+            return self._union_data.prompt_input
+        raise AttributeError("Not a prompt input")
+    @prompt_input.setter
+    def prompt_input(self, value: bytes):  # Assuming c_char_p maps to bytes
+        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
+            self._union_data.prompt_input = value
+        else:
+            raise AttributeError("Not a prompt input")
+    @property
+    def embed_input(self) -> RKLLMEmbedInput:
+        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
+            return self._union_data.embed_input
+        raise AttributeError("Not an embed input")
+    @embed_input.setter
+    def embed_input(self, value: RKLLMEmbedInput):
+        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
+            self._union_data.embed_input = value
+        else:
+            raise AttributeError("Not an embed input")
+
+    @property
+    def token_input(self) -> RKLLMTokenInput:
+        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
+            return self._union_data.token_input
+        raise AttributeError("Not a token input")
+    @token_input.setter
+    def token_input(self, value: RKLLMTokenInput):
+        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
+            self._union_data.token_input = value
+        else:
+            raise AttributeError("Not a token input")
+
+    @property
+    def multimodal_input(self) -> RKLLMMultiModelInput:
+        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
+            return self._union_data.multimodal_input
+        raise AttributeError("Not a multimodal input")
+    @multimodal_input.setter
+    def multimodal_input(self, value: RKLLMMultiModelInput):
+        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
+            self._union_data.multimodal_input = value
+        else:
+            raise AttributeError("Not a multimodal input")
+
+class RKLLMLoraParam(ctypes.Structure):  # For inference
+    lora_adapter_name: ctypes.c_char_p
+
+    _fields_ = [
+        ("lora_adapter_name", ctypes.c_char_p)
+    ]
+
+class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
+    save_prompt_cache: ctypes.c_int  # bool-like
+    prompt_cache_path: ctypes.c_char_p
+
+    _fields_ = [
+        ("save_prompt_cache", ctypes.c_int),  # bool-like
+        ("prompt_cache_path", ctypes.c_char_p)
+    ]
+
+class RKLLMInferParam(ctypes.Structure):
+    mode: ctypes.c_int
+    lora_params: ctypes.POINTER(RKLLMLoraParam)
+    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
+    keep_history: ctypes.c_int  # bool-like
+
+    _fields_ = [
+        ("mode", ctypes.c_int),  # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
+        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
+        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
+        ("keep_history", ctypes.c_int)  # bool-like
+    ]
+
+class RKLLMResultLastHiddenLayer(ctypes.Structure):
+    hidden_states: ctypes.POINTER(ctypes.c_float)
+    embd_size: ctypes.c_int
+    num_tokens: ctypes.c_int
+
+    _fields_ = [
+        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
+        ("embd_size", ctypes.c_int),
+        ("num_tokens", ctypes.c_int)
+    ]
+
+class RKLLMResultLogits(ctypes.Structure):
+    logits: ctypes.POINTER(ctypes.c_float)
+    vocab_size: ctypes.c_int
+    num_tokens: ctypes.c_int
+
+    _fields_ = [
+        ("logits", ctypes.POINTER(ctypes.c_float)),
+        ("vocab_size", ctypes.c_int),
+        ("num_tokens", ctypes.c_int)
+    ]
+
+class RKLLMResult(ctypes.Structure):
+    """
+    LLM inference result structure.
+
+    Represents the result of LLM inference: generated text, token ID,
+    last hidden states, logits and performance statistics.
+    """
+    text: ctypes.c_char_p
+    token_id: ctypes.c_int32
+    last_hidden_layer: RKLLMResultLastHiddenLayer
+    logits: RKLLMResultLogits
+    perf: RKLLMPerfStat
+
+    _fields_ = [
+        ("text", ctypes.c_char_p),                          # generated text
+        ("token_id", ctypes.c_int32),                       # generated token ID
+        ("last_hidden_layer", RKLLMResultLastHiddenLayer),  # hidden states of the last layer (if requested)
+        ("logits", RKLLMResultLogits),                      # logits produced by the model
+        ("perf", RKLLMPerfStat)                             # performance statistics (prefill and generation)
+    ]
+
+# --- Typedefs ---
+LLMHandle = ctypes.c_void_p
+
+# --- Callback Function Type ---
+LLMResultCallback = ctypes.CFUNCTYPE(
+    ctypes.c_int,                 # return type: int, the handling status
+    ctypes.POINTER(RKLLMResult),  # pointer to the LLM result
+    ctypes.c_void_p,              # user data pointer
+    ctypes.c_int                  # LLM call state (an LLMCallState value)
+)
+"""
+Callback function type definition.
+
+Callback used to handle LLM results.
+
+Parameters:
+- result: pointer to the LLM result
+- userdata: user data pointer for the callback
+- state: LLM call state (e.g. finished, error)
+
+Return value:
+- 0: continue inference normally
+- 1: pause inference. If the user wants to modify or intervene in the result
+  (e.g. edit the output, inject a new prompt), return 1 to pause the current
+  inference. Later, call rkllm_run with the updated content to resume.
+"""
+
+class RKLLMRuntime:
+    def __init__(self, library_path="./librkllmrt.so"):
+        try:
+            self.lib = ctypes.CDLL(library_path)
+        except OSError as e:
+            raise OSError(f"Failed to load RKLLM library from {library_path}. "
+                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
+        self._setup_functions()
+        self.llm_handle = LLMHandle()
+        self._c_callback = None  # To keep the callback object alive
+
+    def _setup_functions(self):
+        # RKLLMParam rkllm_createDefaultParam();
+        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
+        self.lib.rkllm_createDefaultParam.argtypes = []
+
+        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
+        self.lib.rkllm_init.restype = ctypes.c_int
+        self.lib.rkllm_init.argtypes = [
+            ctypes.POINTER(LLMHandle),
+            ctypes.POINTER(RKLLMParam),
+            LLMResultCallback
+        ]
+
+        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
+        self.lib.rkllm_load_lora.restype = ctypes.c_int
+        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
+
+        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
+        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
+        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
+
+        # int rkllm_release_prompt_cache(LLMHandle handle);
+        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
+        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]
+
+        # int rkllm_destroy(LLMHandle handle);
+        self.lib.rkllm_destroy.restype = ctypes.c_int
+        self.lib.rkllm_destroy.argtypes = [LLMHandle]
+
+        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
+        self.lib.rkllm_run.restype = ctypes.c_int
+        self.lib.rkllm_run.argtypes = [
+            LLMHandle,
+            ctypes.POINTER(RKLLMInput),
+            ctypes.POINTER(RKLLMInferParam),
+            ctypes.c_void_p  # userdata
+        ]
+
+        # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
+        # Assuming async also takes userdata for the callback context
+        self.lib.rkllm_run_async.restype = ctypes.c_int
+        self.lib.rkllm_run_async.argtypes = [
+            LLMHandle,
+            ctypes.POINTER(RKLLMInput),
+            ctypes.POINTER(RKLLMInferParam),
+            ctypes.c_void_p  # userdata
+        ]
+
+        # int rkllm_abort(LLMHandle handle);
+        self.lib.rkllm_abort.restype = ctypes.c_int
+        self.lib.rkllm_abort.argtypes = [LLMHandle]
+
+        # int rkllm_is_running(LLMHandle handle);
+        self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
+        self.lib.rkllm_is_running.argtypes = [LLMHandle]
+
+        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
+        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
+        self.lib.rkllm_clear_kv_cache.argtypes = [
+            LLMHandle,
+            ctypes.c_int,
+            ctypes.POINTER(ctypes.c_int),  # start_pos
+            ctypes.POINTER(ctypes.c_int)   # end_pos
+        ]
+
+        # int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
+        self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
+        self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]
+
+        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
+        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
+        self.lib.rkllm_set_chat_template.argtypes = [
+            LLMHandle,
+            ctypes.c_char_p,
+            ctypes.c_char_p,
+            ctypes.c_char_p
+        ]
+
+        # int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
+        self.lib.rkllm_set_function_tools.restype = ctypes.c_int
+        self.lib.rkllm_set_function_tools.argtypes = [
+            LLMHandle,
+            ctypes.c_char_p,  # system_prompt
+            ctypes.c_char_p,  # tools
+            ctypes.c_char_p   # tool_response_str
+        ]
+
+        # int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
+        self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
+        self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]
+
+    def create_default_param(self) -> RKLLMParam:
+        """Creates a default RKLLMParam structure."""
+        return self.lib.rkllm_createDefaultParam()
+
+    def init(self, param: RKLLMParam, callback_func) -> int:
+        """
+        Initializes the LLM.
+        :param param: RKLLMParam structure.
+        :param callback_func: A Python function that matches the signature:
+                              def my_callback(result_ptr, userdata_ptr, state_enum):
+                                  result = result_ptr.contents  # RKLLMResult
+                                  # Process result
+                                  # userdata can be retrieved if passed during run, or ignored
+                                  # state = LLMCallState(state_enum)
+        :return: 0 for success, non-zero for failure.
+        """
+        if not callable(callback_func):
+            raise ValueError("callback_func must be a callable Python function.")
+
+        # Keep a reference to the ctypes callback object to prevent it from being garbage collected
+        self._c_callback = LLMResultCallback(callback_func)
+
+        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_init failed with error code {ret}")
+        return ret
+
+    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
+        """Loads a Lora adapter."""
+        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
+        if ret != 0:
+            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
+        return ret
+
+    def load_prompt_cache(self, prompt_cache_path: str) -> int:
+        """Loads a prompt cache from a file."""
+        c_path = prompt_cache_path.encode('utf-8')
+        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
+        return ret
+
+    def release_prompt_cache(self) -> int:
+        """Releases the prompt cache from memory."""
+        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
+        return ret
+
+    def destroy(self) -> int:
+        """Destroys the LLM instance and releases resources."""
+        if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
+            ret = self.lib.rkllm_destroy(self.llm_handle)
+            self.llm_handle = LLMHandle()  # Reset handle
+            if ret != 0:
+                # Don't raise here as it might be called in __del__
+                print(f"Warning: rkllm_destroy failed with error code {ret}")
+            return ret
+        return 0  # Already destroyed or not initialized
+
+    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
+        """Runs an LLM inference task synchronously."""
+        # userdata can be a ctypes.py_object if you want to pass Python objects,
+        # then cast to c_void_p. Or simply None.
+        if userdata is not None:
+            # Store the userdata object to keep it alive during the call
+            self._userdata_ref = userdata
+            c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
+        else:
+            c_userdata = None
+        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_run failed with error code {ret}")
+        return ret
+
+    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
+        """Runs an LLM inference task asynchronously."""
+        if userdata is not None:
+            # Store the userdata object to keep it alive during the call
+            self._userdata_ref = userdata
+            c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
+        else:
+            c_userdata = None
+        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
+        return ret
+
+    def abort(self) -> int:
+        """Aborts an ongoing LLM task."""
+        ret = self.lib.rkllm_abort(self.llm_handle)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
+        return ret
+
+    def is_running(self) -> bool:
+        """Checks if an LLM task is currently running. Returns True if running."""
+        # The C API returns 0 if running, non-zero otherwise.
+        # This is a bit counter-intuitive for a boolean "is_running".
+        return self.lib.rkllm_is_running(self.llm_handle) == 0
+
+    def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
+        """
+        Clears the key-value cache.
+
+        This function clears part or all of the KV cache.
+
+        Parameters:
+        - keep_system_prompt: whether to keep the system prompt in the cache (True keeps it, False clears it).
+          This flag is ignored if an explicit range [start_pos, end_pos) is given.
+        - start_pos: array of start positions (inclusive) of the KV-cache ranges to clear, one per batch
+        - end_pos: array of end positions (exclusive) of the KV-cache ranges to clear, one per batch.
+          If both start_pos and end_pos are None, the whole cache is cleared and keep_system_prompt takes effect.
+          If start_pos[i] < end_pos[i], only the given range is cleared and keep_system_prompt is ignored.
+
+        Note: start_pos/end_pos are only valid when keep_history == 0 and generation has been
+        paused by returning 1 from the callback.
+
+        Returns: 0 if the cache was cleared successfully, non-zero on failure.
+        """
+        # Prepare the C array arguments
+        c_start_pos = None
+        c_end_pos = None
+
+        if start_pos is not None and end_pos is not None:
+            if len(start_pos) != len(end_pos):
+                raise ValueError("start_pos and end_pos arrays must have the same length")
+
+            # Create C arrays
+            c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
+            c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)
+
+        ret = self.lib.rkllm_clear_kv_cache(
+            self.llm_handle,
+            ctypes.c_int(1 if keep_system_prompt else 0),
+            c_start_pos,
+            c_end_pos
+        )
+        if ret != 0:
+            raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
+        return ret
+
+    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
+        """Sets the chat template for the LLM."""
+        c_system = system_prompt.encode('utf-8') if system_prompt else b""
+        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
+        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""
+
+        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
+        return ret
+
+    def get_kv_cache_size(self, n_batch: int) -> list:
+        """
+        Gets the current size of the key-value cache for the given LLM handle.
+
+        Returns the total number of positions currently stored in the model's KV cache.
+
+        Parameters:
+        - n_batch: number of batches, used to size the returned array
+
+        Returns:
+        - list: the cache size for each batch
+        """
+        # Pre-allocate an array to hold the cache size of each batch
+        cache_sizes = (ctypes.c_int * n_batch)()
+
+        ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_get_kv_cache_size failed with error code {ret}")
+
+        # Convert to a Python list
+        return [cache_sizes[i] for i in range(n_batch)]
+
+    def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
+        """
+        Sets the function-calling configuration for the LLM: system prompt, tool definitions and tool-response token.
+
+        Parameters:
+        - system_prompt: system prompt defining the language model's context or behaviour
+        - tools: JSON-formatted string defining the available functions, including their names, descriptions and parameters
+        - tool_response_str: unique tag used to identify function-call results in the conversation. It acts as a
+          marker tag that lets the tokenizer recognize tool output separately from normal conversation turns.
+
+        Returns: 0 if the configuration was set successfully, non-zero on error.
+        """
+        c_system = system_prompt.encode('utf-8') if system_prompt else b""
+        c_tools = tools.encode('utf-8') if tools else b""
+        c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""
+
+        ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
+        if ret != 0:
+            raise RuntimeError(f"rkllm_set_function_tools failed with error code {ret}")
+        return ret
+
+    def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
+        """
+        Sets the cross-attention parameters for the LLM decoder.
+
+        Parameters:
+        - cross_attn_params: structure containing the encoder-side input data used for cross attention
+          (see RKLLMCrossAttnParam for details)
+
+        Returns: 0 if the parameters were set successfully, non-zero on error.
+        """
+        ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
+        if ret != 0:
+            raise RuntimeError(f"rkllm_set_cross_attn_params failed with error code {ret}")
+        return ret
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.destroy()
+
+    def __del__(self):
+        self.destroy()  # Ensure resources are freed if object is garbage collected
+
+# --- Example Usage (Illustrative) ---
+if __name__ == "__main__":
+    # This is a placeholder for how you might use it.
+    # You'll need a valid .rkllm model and librkllmrt.so in your path.
+
+    # Global list to store results from callback for demonstration
+    results_buffer = []
+
+    def my_python_callback(result_ptr, userdata_ptr, state_enum):
+        """
+        Callback invoked by the C library to handle LLM results.
+
+        Parameters:
+        - result_ptr: pointer to the LLM result
+        - userdata_ptr: user data pointer
+        - state_enum: LLM call state value
+
+        Returns:
+        - 0: continue inference
+        - 1: pause inference
+        """
+        global results_buffer
+        state = LLMCallState(state_enum)
+        result = result_ptr.contents
+
+        current_text = ""
+        if result.text:  # Check the char_p is not NULL
+            current_text = result.text.decode('utf-8', errors='ignore')
+
+        print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
+
+        # Show performance statistics
+        if result.perf.prefill_tokens > 0 or result.perf.generate_tokens > 0:
+            print(f"  Perf: prefill={result.perf.prefill_tokens}tokens/{result.perf.prefill_time_ms:.1f}ms, "
+                  f"generate={result.perf.generate_tokens}tokens/{result.perf.generate_time_ms:.1f}ms, "
+                  f"memory={result.perf.memory_usage_mb:.1f}MB")
+
+        results_buffer.append(current_text)
+
+        if state == LLMCallState.RKLLM_RUN_FINISH:
+            print("Inference finished.")
+        elif state == LLMCallState.RKLLM_RUN_ERROR:
+            print("Inference error.")
+
+        # Return 0 to continue inference, 1 to pause it
+        return 0
+
+    # --- Attempt to use the wrapper ---
+    try:
+        print("Initializing RKLLMRuntime...")
+        # Adjust library_path if librkllmrt.so is not in default search paths
+        # e.g., library_path="./path/to/librkllmrt.so"
+        rk_llm = RKLLMRuntime()
+
+        print("Creating default parameters...")
+        params = rk_llm.create_default_param()
+
+        # --- Configure parameters ---
+        # THIS IS CRITICAL: model_path must point to an actual .rkllm file
+        # For this example to run, you need a model file.
+        # Let's assume a dummy path for now, this will fail at init if not valid.
+        model_file = "dummy_model.rkllm"
+        if not os.path.exists(model_file):
+            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
+            # Create a dummy file for the example to proceed further, though init will still fail
+            # with a real library unless it's a valid model.
+            with open(model_file, "w") as f:
+                f.write("dummy content")
+
+        params.model_path = model_file.encode('utf-8')
+        params.max_context_len = 512
+        params.max_new_tokens = 128
+        params.top_k = 1  # Greedy
+        params.temperature = 0.7
+        params.repeat_penalty = 1.1
+        # ... set other params as needed
+
+        print(f"Initializing LLM with model: {params.model_path.decode()}...")
+        # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
+        try:
+            rk_llm.init(params, my_python_callback)
+            print("LLM Initialized.")
+        except RuntimeError as e:
+            print(f"Error during LLM initialization: {e}")
+            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
+            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
+            exit()
+
+        # --- Prepare input ---
+        print("Preparing input...")
+        rk_input = RKLLMInput()
+        rk_input.role = b"user"  # Set the role to user input
+        rk_input.enable_thinking = False  # Disable thinking mode (applies to Qwen3 models)
+        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
+
+        prompt_text = "Translate the following English text into Chinese: 'Hello, world!'"
+        c_prompt = prompt_text.encode('utf-8')
+        rk_input._union_data.prompt_input = c_prompt  # Access the union member directly
+
+        # --- Prepare inference parameters ---
+        print("Preparing inference parameters...")
+        infer_params = RKLLMInferParam()
+        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
+        infer_params.keep_history = 1  # True
+        # infer_params.lora_params = None  # or set up RKLLMLoraParam if using LoRA
+        # infer_params.prompt_cache_params = None  # or set up RKLLMPromptCacheParam
+
+        # --- Run inference ---
+        print(f"Running inference with prompt: '{prompt_text}'")
+        results_buffer.clear()
+        try:
+            rk_llm.run(rk_input, infer_params)  # Userdata is None by default
+            print("\n--- Full Response ---")
+            print("".join(results_buffer))
+            print("---------------------\n")
+        except RuntimeError as e:
+            print(f"Error during LLM run: {e}")
+
+        # --- Example: Set chat template (if model supports it) ---
+        # print("Setting chat template...")
+        # try:
+        #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
+        #     print("Chat template set.")
+        # except RuntimeError as e:
+        #     print(f"Error setting chat template: {e}")
+
+        # --- Example: Clear KV Cache ---
+        # print("Clearing KV cache (keeping system prompt if any)...")
+        # try:
+        #     rk_llm.clear_kv_cache(keep_system_prompt=True)
+        #     print("KV cache cleared.")
+        # except RuntimeError as e:
+        #     print(f"Error clearing KV cache: {e}")
+
+        # --- Example: Get KV cache size ---
+        # print("Getting KV cache size...")
+        # try:
+        #     cache_sizes = rk_llm.get_kv_cache_size(n_batch=1)  # assuming a batch size of 1
+        #     print(f"Current KV cache size: {cache_sizes}")
+        # except RuntimeError as e:
+        #     print(f"Error getting KV cache size: {e}")
+
+        # --- Example: Set function tools ---
+        # print("Setting function-calling tools...")
+        # try:
+        #     system_prompt = "You are a helpful assistant that can call the provided functions to help the user."
+        #     tools = '''[{
+        #         "name": "get_weather",
+        #         "description": "Get the weather for a given city",
+        #         "parameters": {
+        #             "type": "object",
+        #             "properties": {
+        #                 "city": {"type": "string", "description": "Name of the city"}
+        #             },
+        #             "required": ["city"]
+        #         }
+        #     }]'''
+        #     tool_response_str = "<tool_response>"
+        #     rk_llm.set_function_tools(system_prompt, tools, tool_response_str)
+        #     print("Function tools set.")
+        # except RuntimeError as e:
+        #     print(f"Error setting function tools: {e}")
+
+        # --- Example: Clear KV cache (with range arguments) ---
+        # print("Clearing KV cache with range arguments...")
+        # try:
+        #     # Clear cache positions 10 to 20
+        #     start_positions = [10]  # start position for batch 0
+        #     end_positions = [20]    # end position for batch 0
+        #     rk_llm.clear_kv_cache(keep_system_prompt=True, start_pos=start_positions, end_pos=end_positions)
+        #     print("Ranged KV cache clear done.")
+        # except RuntimeError as e:
+        #     print(f"Error clearing ranged KV cache: {e}")
+
+    except OSError as e:
+        print(f"OSError: {e}. Could not load the RKLLM library.")
+        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
+    except Exception as e:
+        print(f"An unexpected error occurred: {e}")
+    finally:
+        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
+            print("Destroying LLM instance...")
+            rk_llm.destroy()
+            print("LLM instance destroyed.")
+        if 'model_file' in locals() and os.path.exists(model_file) and model_file == "dummy_model.rkllm":
+            os.remove(model_file)  # Clean up dummy file
+
+    print("Example finished.")
rknn_quickconvert.py ADDED
@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+import os
+import sys
+from rknn.api import RKNN
+import argparse
+
+def quick_convert(onnx_model_path):
+    rknn = RKNN(verbose=True)
+    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, model_pruning=True)
+    ret = rknn.load_onnx(model=onnx_model_path)
+    if ret != 0:
+        print('Load model failed!')
+        exit(ret)
+    rknn.build(do_quantization=False, dataset=None, rknn_batch_size=None)
+    rknn_model_path = onnx_model_path.replace(".onnx", ".rknn")
+    ret = rknn.export_rknn(rknn_model_path)
+    if ret != 0:
+        print('Export RKNN model failed!')
+        exit(ret)
+    print(f"RKNN model exported to {rknn_model_path}")
+
+def compile_only(onnx_model_path):
+    from rknn.api.rknn_compiler import RKNNCompiler, RKNNConfig, RKNNNormalize
+    from rknn.api.rknn_platform import support_soc_npu_target
+    from pprint import pprint
+
+    pprint(support_soc_npu_target)
+
+    RKNNCompiler.build(
+        onnx_model_path,
+        RKNNConfig(
+            target="v2",
+            request_type="float16",
+            optimize_options="compress=0, conv_eltwise_activation_fuse=1, global_fuse=1, multi-core-model-mode=1, output_optimize=1, layout_match=1, enable_argb_group=0, pipeline_fuse=1, enable_flash_attention=0",
+            verbose_level=4,
+        ),
+        RKNNNormalize(
+            channel_means=[[0, 0, 0]],
+            channel_stds=[[1, 1, 1]],
+            channel_orders=[[0, 1, 2]],  # not tested !!!
+        ),
+        "out.rknn",
+    )
+
+# Usage: python rknn_quickconvert.py <onnx_model_path> [-c/--compile-only]
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("onnx_model_path", type=str, help="Path to the ONNX model file")
+    parser.add_argument("-c", "--compile-only", action="store_true", help="Compile the model only")
+    args = parser.parse_args()
+
+    onnx_model_path = args.onnx_model_path
+    if args.compile_only:
+        compile_only(onnx_model_path)
+    else:
+        quick_convert(onnx_model_path)
run_rkllm.py ADDED
@@ -0,0 +1,226 @@
+import faulthandler
+faulthandler.enable()
+import sys
+import os
+os.environ["RKLLM_LOG_LEVEL"] = "1"
+import ctypes
+import argparse
+import cv2
+import numpy as np
+import ztu_somemodelruntime_rknnlite2 as ort
+from rkllm_binding import (
+    RKLLMRuntime,
+    RKLLMParam,
+    RKLLMInput,
+    RKLLMInferParam,
+    LLMCallState,
+    RKLLMInputType,
+    RKLLMInferMode,
+    RKLLMResult
+)
+
+# Constants
+IMAGE_HEIGHT = 448
+IMAGE_WIDTH = 448
+
+def expand2square(img, background_color):
+    """
+    Expand the image into a square and fill it with the specified background color.
+    """
+    height, width, _ = img.shape
+    if width == height:
+        return img.copy()
+
+    size = max(width, height)
+    square_img = np.full((size, size, 3), background_color, dtype=np.uint8)
+
+    x_offset = (size - width) // 2
+    y_offset = (size - height) // 2
+
+    square_img[y_offset:y_offset+height, x_offset:x_offset+width] = img
+    return square_img
+
+def llm_callback(result_ptr, userdata_ptr, state_enum):
+    """
+    Callback function to handle LLM results.
+    """
+    state = LLMCallState(state_enum)
+    result = result_ptr.contents
+
+    if state == LLMCallState.RKLLM_RUN_NORMAL:
+        if result.text:
+            print(result.text.decode('utf-8', errors='ignore'), end='', flush=True)
+    elif state == LLMCallState.RKLLM_RUN_FINISH:
+        print("\n", flush=True)
+    elif state == LLMCallState.RKLLM_RUN_ERROR:
+        print("\nrun error", flush=True)
+
+    return 0
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Run RKLLM visual language model inference based on the C++ example."
+    )
+    parser.add_argument("image_path", type=str, help="Path to the input image.")
+    parser.add_argument("encoder_model_path", type=str, help="Path to the ONNX vision encoder model.")
+    parser.add_argument("llm_model_path", type=str, help="Path to the .rkllm language model.")
+    parser.add_argument("max_new_tokens", type=int, help="Maximum number of new tokens to generate.")
+    parser.add_argument("max_context_len", type=int, help="Maximum context length.")
+    # The rknn_core_num is not directly used by onnxruntime in the same way,
+    # but we keep it for API consistency with the C++ example.
+    # ONNX Runtime will manage its own threading and execution providers.
+    parser.add_argument("rknn_core_num", type=int, help="Core number for RKNN (informational for this script).")
+
+    args = parser.parse_args()
+
+    # --- 1. Initialize Image Encoder (ONNX Runtime) ---
+    print("Initializing ONNX Runtime for vision encoder...")
+    try:
+        ort_session = ort.InferenceSession(args.encoder_model_path)
+    except Exception as e:
+        print(f"Failed to load ONNX model: {e}")
+        sys.exit(1)
+    print("Vision encoder loaded successfully.")
+
+    input_name = ort_session.get_inputs()[0].name
+    output_name = ort_session.get_outputs()[0].name
+    print(f"ONNX Input: {input_name}, ONNX Output: {output_name}")
+
+    # --- 2. Initialize LLM ---
+    print("Initializing RKLLM Runtime...")
+    rk_llm = RKLLMRuntime()
+    param = rk_llm.create_default_param()
+
+    param.model_path = args.llm_model_path.encode('utf-8')
+    param.top_k = 1
+    param.max_new_tokens = args.max_new_tokens
+    param.max_context_len = args.max_context_len
+    param.skip_special_token = True
+    param.img_start = b"<|vision_start|>"
+    param.img_end = b"<|vision_end|>"
+    param.img_content = b"<|image_pad|>"
+    param.extend_param.base_domain_id = 1
+
+    try:
+        rk_llm.init(param, llm_callback)
+        print("RKLLM initialized successfully.")
+    except RuntimeError as e:
+        print(f"RKLLM init failed: {e}")
+        sys.exit(1)
+
+    # --- 3. Image Preprocessing ---
+    print("Preprocessing image...")
+    img = cv2.imread(args.image_path)
+    if img is None:
+        print(f"Failed to read image from {args.image_path}")
+        sys.exit(1)
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+    background_color = (127.5, 127.5, 127.5)  # As per C++ example
+    square_img = expand2square(img, background_color)
+    resized_img = cv2.resize(square_img, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)
+
+    # Normalize and prepare for ONNX model
+    input_tensor = resized_img.astype(np.float32)
+    # Normalize using preprocessor config values
+    input_tensor = (input_tensor / 255.0 - np.array([0.48145466, 0.4578275, 0.40821073])) / np.array([0.26862954, 0.26130258, 0.27577711])
+    # Convert to NCHW format
+    input_tensor = np.transpose(input_tensor, (2, 0, 1))  # HWC -> CHW
+    input_tensor = np.expand_dims(input_tensor, axis=0)  # Add batch dimension -> (1, 3, IMAGE_HEIGHT, IMAGE_WIDTH)
+
+    # --- 4. Run Image Encoder ---
+    print("Running vision encoder...")
+    try:
+        img_vec_output = ort_session.run([output_name], {input_name: input_tensor.astype(np.float32)})[0]
+        # The output from C++ is a flat float array. Let's flatten the ONNX output.
+        img_vec = img_vec_output.flatten().astype(np.float32)
+
+    except Exception as e:
+        print(f"Failed to run vision encoder inference: {e}")
+        rk_llm.destroy()
+        sys.exit(1)
+
+    print("Image encoded successfully.")
+
+    # --- 5. Interactive Chat Loop ---
+    rkllm_infer_params = RKLLMInferParam()
+    rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
+    rkllm_infer_params.keep_history = 0
+
+    # Set chat template
+    rk_llm.set_chat_template(
+        system_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
+        prompt_prefix="<|im_start|>user\n",
+        prompt_postfix="<|im_end|>\n<|im_start|>assistant\n"
+    )
+
+    pre_input = [
+        "Picture 1: <image> What is in the image?",
+        "Picture 1: <image> 这张图片中有什么?"
+    ]
+    print("\n********************** Enter the number of a question below to get an answer, or type your own input **********************\n")
+    for i, p in enumerate(pre_input):
+        print(f"[{i}] {p}")
+    print("\n*************************************************************************\n")
+
+    try:
+        while True:
+            print("\nuser: ", end="", flush=True)
+            input_str = sys.stdin.readline().strip()
+
+            if not input_str:
+                continue
+            if input_str == "exit":
+                break
+            if input_str == "clear":
+                try:
+                    rk_llm.clear_kv_cache(keep_system_prompt=True)
+                    print("KV cache cleared.")
+                except RuntimeError as e:
+                    print(f"Failed to clear KV cache: {e}")
+                continue
+
+            try:
+                idx = int(input_str)
+                if 0 <= idx < len(pre_input):
+                    input_str = pre_input[idx]
+                    print(input_str)
+            except (ValueError, IndexError):
+                pass  # Use the raw string if not a valid index
+
+            rkllm_input = RKLLMInput()
+            rkllm_input.role = b"user"
+
+            print("robot: ", end="", flush=True)
+
+            if "<image>" in input_str:
+                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL
+
+                # Setup multimodal input
+                rkllm_input.multimodal_input.prompt = input_str.encode('utf-8')
+                rkllm_input.multimodal_input.image_embed = img_vec.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+                rkllm_input.multimodal_input.n_image_tokens = img_vec_output.shape[0]
+                print("n_image_tokens: ", rkllm_input.multimodal_input.n_image_tokens)
+                rkllm_input.multimodal_input.n_image = 1
+                rkllm_input.multimodal_input.image_height = IMAGE_HEIGHT
+                rkllm_input.multimodal_input.image_width = IMAGE_WIDTH
+            else:
+                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
+                rkllm_input.prompt_input = input_str.encode('utf-8')
+
+            try:
+                rk_llm.run(rkllm_input, rkllm_infer_params)
+            except RuntimeError as e:
+                print(f"\nError during rkllm_run: {e}")
+
+    except KeyboardInterrupt:
+        print("\nExiting...")
+    finally:
+        print("Releasing resources...")
+        rk_llm.destroy()
+        print("RKLLM instance destroyed.")
+
+if __name__ == "__main__":
+    main()
+
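
Usage follows the positional arguments parsed in main(); with the files shipped in this repository a run could look like the following, where the image path and the token/context/core values are illustrative:

    python run_rkllm.py test.jpg vision_encoder.rknn language_model_w8a8.rkllm 256 1024 3

Note that ort here is the ztu_somemodelruntime_rknnlite2 wrapper below, so the encoder argument may be a .rknn model executed on the NPU rather than an actual ONNX file.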
ztu_somemodelruntime_rknnlite2.py ADDED
@@ -0,0 +1,569 @@
1
+ # 模块级常量和函数
2
+ from rknnlite.api import RKNNLite
3
+ import numpy as np
4
+ import os
5
+ import warnings
6
+ import logging
7
+ from typing import List, Dict, Union, Optional
8
+
9
+ try:
10
+ import onnxruntime as ort
11
+ HAS_ORT = True
12
+ except ImportError:
13
+ HAS_ORT = False
14
+ warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning)
15
+
16
+ # 配置日志
17
+ logger = logging.getLogger("somemodelruntime_rknnlite2")
18
+ logger.setLevel(logging.ERROR) # 默认只输出错误信息
19
+ if not logger.handlers:
20
+ handler = logging.StreamHandler()
21
+ handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
22
+ logger.addHandler(handler)
23
+
24
+ # ONNX Runtime日志级别到Python logging级别的映射
25
+ _LOGGING_LEVEL_MAP = {
26
+ 0: logging.DEBUG, # Verbose
27
+ 1: logging.INFO, # Info
28
+ 2: logging.WARNING, # Warning
29
+ 3: logging.ERROR, # Error
30
+ 4: logging.CRITICAL # Fatal
31
+ }
32
+
33
+ # 检查环境变量中的日志级别设置
34
+ try:
35
+ env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
36
+ if env_log_level is not None:
37
+ log_level = int(env_log_level)
38
+ if log_level in _LOGGING_LEVEL_MAP:
39
+ logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
40
+ logger.info(f"从环境变量设置日志级别: {log_level}")
41
+ else:
42
+ logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {log_level}, 应该是0-4之间的整数")
43
+ except ValueError:
44
+ logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {env_log_level}, 应该是0-4之间的整数")
45
+
46
+
47
+ def set_default_logger_severity(level: int) -> None:
48
+ """
49
+ Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
50
+
51
+ Args:
52
+ level: 日志级别(0-4)
53
+ """
54
+ if level not in _LOGGING_LEVEL_MAP:
55
+ raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
56
+ logger.setLevel(_LOGGING_LEVEL_MAP[level])
57
+
58
+ def set_default_logger_verbosity(level: int) -> None:
59
+ """
60
+ Sets the default logging verbosity level. To activate the verbose log,
61
+ you need to set the default logging severity to 0:Verbose level.
62
+
63
+ Args:
64
+ level: 日志级别(0-4)
65
+ """
66
+ set_default_logger_severity(level)
67
+
68
+ # RKNN tensor type到numpy dtype的映射
69
+ RKNN_DTYPE_MAP = {
70
+ 0: np.float32, # RKNN_TENSOR_FLOAT32
71
+ 1: np.float16, # RKNN_TENSOR_FLOAT16
72
+ 2: np.int8, # RKNN_TENSOR_INT8
73
+ 3: np.uint8, # RKNN_TENSOR_UINT8
74
+ 4: np.int16, # RKNN_TENSOR_INT16
75
+ 5: np.uint16, # RKNN_TENSOR_UINT16
76
+ 6: np.int32, # RKNN_TENSOR_INT32
77
+ 7: np.uint32, # RKNN_TENSOR_UINT32
78
+ 8: np.int64, # RKNN_TENSOR_INT64
79
+ 9: bool, # RKNN_TENSOR_BOOL
80
+ 10: np.int8, # RKNN_TENSOR_INT4 (用int8表示)
81
+ }
82
+
83
+ def get_available_providers() -> List[str]:
84
+ """
85
+ 获取可用的设备提供者列表(为保持接口兼容性的占位函数)
86
+
87
+ Returns:
88
+ list: 可用的设备提供者列表,总是返回["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
89
+ """
90
+ return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
91
+
92
+
93
+ def get_device() -> str:
94
+ """
95
+ 获取当前设备
96
+
97
+ Returns:
98
+ str: 当前设备
99
+ """
100
+ return "RKNN2"
101
+
102
+ def get_version_info() -> Dict[str, str]:
103
+ """
104
+ 获取版本信息
105
+
106
+ Returns:
107
+ dict: 包含API和驱动版本信息的字典
108
+ """
109
+ runtime = RKNNLite()
110
+ version = runtime.get_sdk_version()
111
+ return {
112
+ "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
113
+ "driver_version": version.split('\n')[3].split(': ')[1]
114
+ }
115
+
116
+ class IOTensor:
117
+ """输入/输出张量的信息封装类"""
118
+ def __init__(self, name, shape, type=None):
119
+ self.name = name.decode() if isinstance(name, bytes) else name
120
+ self.shape = shape
121
+ self.type = type
122
+
123
+ def __str__(self):
124
+ return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
125
+
126
+ class SessionOptions:
127
+ """会话选项类"""
128
+ def __init__(self):
129
+ self.enable_profiling = False # 是否使用性能分析
130
+ self.intra_op_num_threads = 1 # 设置RKNN的线程数, 对应rknn的core_mask
131
+ self.log_severity_level = -1 # 另一个设置日志级别的参数
132
+ self.log_verbosity_level = -1 # 另一个设置日志级别的参数
133
+
134
+
+ class InferenceSession:
+     """
+     RKNNLite runtime wrapper with an ONNX Runtime-like API
+     """
+
+     def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
+         processed_path = InferenceSession._process_model_path(model_path, sess_options)
+         if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
+             logger.info("Loading the model with ONNX Runtime")
+             if not HAS_ORT:
+                 raise RuntimeError("onnxruntime is not installed; cannot load ONNX models")
+             return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
+         else:
+             # Not an ONNX model: let the parent __new__ create an InferenceSession instance
+             instance = super().__new__(cls)
+             # Remember the processed path
+             instance._processed_path = processed_path
+             return instance
+
+     def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
+         """
+         Initialize the runtime and load the model
+
+         Args:
+             model_path: model file path (.rknn or .onnx)
+             sess_options: session options
+             **kwargs: additional initialization parameters
+         """
+         options = sess_options or SessionOptions()
+
+         # Only honor the log levels from SessionOptions when the environment variable is unset
+         if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
+             if options.log_severity_level != -1:
+                 set_default_logger_severity(options.log_severity_level)
+             if options.log_verbosity_level != -1:
+                 set_default_logger_verbosity(options.log_verbosity_level)
+
+         # Use the path already processed in __new__
+         model_path = getattr(self, '_processed_path', model_path)
+         if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
+             # Avoid re-loading an ONNX model
+             return
+
+         # RKNN model loading and initialization
+         self.model_path = model_path
+         if not os.path.exists(self.model_path):
+             logger.error(f"Model file does not exist: {self.model_path}")
+             raise FileNotFoundError(f"Model file does not exist: {self.model_path}")
+
+         self.runtime = RKNNLite(verbose=options.enable_profiling)
+
+         logger.debug(f"Loading model: {self.model_path}")
+         ret = self.runtime.load_rknn(self.model_path)
+         if ret != 0:
+             logger.error(f"Failed to load RKNN model: {self.model_path}")
+             raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
+         logger.debug("Model loaded successfully")
+
+
+         if options.intra_op_num_threads == 1:
+             core_mask = RKNNLite.NPU_CORE_AUTO
+         elif options.intra_op_num_threads == 2:
+             core_mask = RKNNLite.NPU_CORE_0_1
+         elif options.intra_op_num_threads == 3:
+             core_mask = RKNNLite.NPU_CORE_0_1_2
+         else:
+             raise ValueError(f"Invalid intra_op_num_threads value: {options.intra_op_num_threads}, must be 1, 2 or 3")
+
+         logger.debug("Initializing the runtime environment")
+         ret = self.runtime.init_runtime(core_mask=core_mask)
+         if ret != 0:
+             logger.error("Failed to initialize the runtime environment")
+             raise RuntimeError('Failed to initialize the runtime environment')
+         logger.debug("Runtime environment initialized successfully")
+
+         self._init_io_info()
+         self.options = options
+
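+     # intra_op_num_threads -> core_mask mapping used above:
+     #   1 -> NPU_CORE_AUTO  (single core, chosen automatically)
+     #   2 -> NPU_CORE_0_1   (cores 0 and 1)
+     #   3 -> NPU_CORE_0_1_2 (cores 0, 1 and 2)
+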
+     def get_performance_info(self) -> Dict[str, float]:
+         """
+         Get performance information
+
+         Returns:
+             dict: performance information
+         """
+         if not self.options.enable_profiling:
+             raise RuntimeError("Profiling is not enabled; set enable_profiling=True in SessionOptions")
+
+         perf = self.runtime.rknn_runtime.get_run_perf()
+         return {
+             "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
+         }
+
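+     # Profiling usage sketch (model path and input name are illustrative):
+     #
+     #   opts = SessionOptions()
+     #   opts.enable_profiling = True
+     #   sess = InferenceSession("model.rknn", sess_options=opts)
+     #   sess.run(None, {"input": x})
+     #   print(sess.get_performance_info()["run_duration"], "ms")
+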
+     def set_core_mask(self, core_mask: int) -> None:
+         """
+         Set the NPU core usage mode
+
+         Args:
+             core_mask: NPU core mask, one of the NPU_CORE_* constants
+         """
+         ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
+         if ret != 0:
+             raise RuntimeError("Failed to set the NPU core mode")
+
+     @staticmethod
+     def _process_model_path(model_path, sess_options):
+         """
+         Process the model path; supports .onnx and .rknn files
+
+         Args:
+             model_path: model file path
+         """
+         # For ONNX files, check whether a matching RKNN model should be loaded instead
+         if model_path.lower().endswith('.onnx'):
+             logger.info("Detected an ONNX model file")
+
+             # Models listed here skip the automatic RKNN substitution
+             skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
+             if skip_models:
+                 skip_list = [m.strip() for m in skip_models.split(',')]
+                 # Match on the file name only (without the directory part)
+                 model_name = os.path.basename(model_path)
+                 if model_name.lower() in [m.lower() for m in skip_list]:
+                     logger.info(f"Model {model_name} is in the skip list; falling back to ONNX Runtime")
+                     return model_path
+
+             # Build the corresponding RKNN file path
+             rknn_path = os.path.splitext(model_path)[0] + '.rknn'
+             if os.path.exists(rknn_path):
+                 logger.info(f"Found a matching RKNN model, using it instead: {rknn_path}")
+                 return rknn_path
+             else:
+                 logger.info("No matching RKNN model found; falling back to ONNX Runtime")
+                 return model_path
+
+         return model_path
+
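+     # Skip-list usage sketch: ZTU_MODELRT_RKNNL2_SKIP is a comma-separated list
+     # of ONNX file names that should never be swapped for a sibling .rknn file,
+     # e.g. in the shell (the file names below are illustrative):
+     #
+     #   export ZTU_MODELRT_RKNNL2_SKIP="decoder.onnx,tokenizer.onnx"
+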
+     def _convert_nhwc_to_nchw(self, shape):
+         """Convert an NHWC shape to NCHW"""
+         if len(shape) == 4:
+             # NHWC -> NCHW
+             n, h, w, c = shape
+             return [n, c, h, w]
+         return shape
+
+     def _init_io_info(self):
+         """Initialize the model's input/output information"""
+         runtime = self.runtime.rknn_runtime
+
+         # Query the number of inputs and outputs
+         n_input, n_output = runtime.get_in_out_num()
+
+         # Collect input information
+         self.input_tensors = []
+         for i in range(n_input):
+             attr = runtime.get_tensor_attr(i)
+             shape = [attr.dims[j] for j in range(attr.n_dims)]
+             # Convert 4D input shapes from NHWC to NCHW
+             shape = self._convert_nhwc_to_nchw(shape)
+             # Look up the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.input_tensors.append(tensor)
+
+         # Collect output information
+         self.output_tensors = []
+         for i in range(n_output):
+             attr = runtime.get_tensor_attr(i, is_output=True)
+             shape = runtime.get_output_shape(i)
+             # Look up the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.output_tensors.append(tensor)
+
+     def get_inputs(self):
+         """
+         Get the model's input information
+
+         Returns:
+             list: input IOTensor objects
+         """
+         return self.input_tensors
+
+     def get_outputs(self):
+         """
+         Get the model's output information
+
+         Returns:
+             list: output IOTensor objects
+         """
+         return self.output_tensors
+
+     def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
+         """
+         Run inference
+
+         Args:
+             output_names: list of output node names; restricts which outputs are returned
+             input_feed: input data as a dict or a list
+             data_format: input data layout, "nchw" or "nhwc"
+             **kwargs: additional runtime parameters
+
+         Returns:
+             list: model outputs; only the requested ones when output_names is given
+         """
+         if input_feed is None:
+             logger.error("input_feed must not be None")
+             raise ValueError("input_feed must not be None")
+
+         # Prepare the input data
+         if isinstance(input_feed, dict):
+             # For a dict, order the values to match the model's input order
+             inputs = []
+             for tensor in self.input_tensors:
+                 if tensor.name not in input_feed:
+                     raise ValueError(f"Missing input: {tensor.name}")
+                 inputs.append(input_feed[tensor.name])
+         elif isinstance(input_feed, (list, tuple)):
+             # For a list, the length must match
+             if len(input_feed) != len(self.input_tensors):
+                 raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
+             inputs = list(input_feed)
+         else:
+             logger.error("input_feed must be a dict or a list")
+             raise ValueError("input_feed must be a dict or a list")
+
+         # Run inference
+         try:
+             logger.debug("Starting inference")
+             all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
+
+             # Without output_names, return all outputs
+             if output_names is None:
+                 return all_outputs
+
+             # Select the requested outputs
+             output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
+             selected_outputs = []
+             for name in output_names:
+                 if name not in output_map:
+                     raise ValueError(f"Output node not found: {name}")
+                 selected_outputs.append(all_outputs[output_map[name]])
+
+             return selected_outputs
+
+         except Exception as e:
+             logger.error(f"Inference failed: {str(e)}")
+             raise RuntimeError(f"Inference failed: {str(e)}")
+
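+     # End-to-end usage sketch (model path, input name and shape are illustrative):
+     #
+     #   import numpy as np
+     #   sess = InferenceSession("model.rknn")
+     #   print([str(t) for t in sess.get_inputs()])
+     #   x = np.random.rand(1, 3, 448, 448).astype(np.float32)
+     #   outputs = sess.run(None, {sess.get_inputs()[0].name: x})
+     #   print(outputs[0].shape)
+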
+     def close(self):
+         """
+         Close the session and release resources
+         """
+         if self.runtime is not None:
+             logger.info("Releasing runtime resources")
+             self.runtime.release()
+             self.runtime = None
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
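+     # Context-manager usage sketch: the NPU runtime is released on exit
+     # ("model.rknn", "input" and data are illustrative):
+     #
+     #   with InferenceSession("model.rknn") as sess:
+     #       outputs = sess.run(None, {"input": data})
+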
+     def end_profiling(self) -> Optional[str]:
+         """
+         Stub for ending profiling
+
+         Returns:
+             Optional[str]: None
+         """
+         warnings.warn("end_profiling() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return None
+
+     def get_profiling_start_time_ns(self) -> int:
+         """
+         Stub for getting the profiling start time
+
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_profiling_start_time_ns() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return 0
+
+     def get_modelmeta(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_modelmeta() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_options(self) -> SessionOptions:
+         """
+         Get the session options
+
+         Returns:
+             SessionOptions: current session options
+         """
+         return self.options
+
+     def get_providers(self) -> List[str]:
+         """
+         Stub for getting the providers in use
+
+         Returns:
+             List[str]: ["CPUExecutionProvider"]
+         """
+         warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
+         return ["CPUExecutionProvider"]
+
+     def get_provider_options(self) -> Dict[str, Dict[str, str]]:
+         """
+         Stub for getting provider options
+
+         Returns:
+             Dict[str, Dict[str, str]]: empty dict
+         """
+         warnings.warn("get_provider_options() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_config(self) -> Dict[str, str]:
+         """
+         Stub for getting the session configuration
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_state(self) -> Dict[str, str]:
+         """
+         Stub for getting the session state
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_session_state() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def set_session_config(self, config: Dict[str, str]) -> None:
+         """
+         Stub for setting the session configuration
+
+         Args:
+             config: session configuration dict
+         """
+         warnings.warn("set_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_memory_info(self) -> Dict[str, int]:
+         """
+         Stub for getting memory usage information
+
+         Returns:
+             Dict[str, int]: empty dict
+         """
+         warnings.warn("get_memory_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def set_memory_pattern(self, enable: bool) -> None:
+         """
+         Stub for setting the memory pattern
+
+         Args:
+             enable: whether to enable the memory pattern
+         """
+         warnings.warn("set_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def disable_memory_pattern(self) -> None:
+         """
+         Stub for disabling the memory pattern
+         """
+         warnings.warn("disable_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_optimization_level(self) -> int:
+         """
+         Stub for getting the optimization level
+
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return 0
+
+     def set_optimization_level(self, level: int) -> None:
+         """
+         Stub for setting the optimization level
+
+         Args:
+             level: optimization level
+         """
+         warnings.warn("set_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_model_metadata(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata (separate interface from get_modelmeta)
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_model_metadata() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_model_path(self) -> str:
+         """
+         Get the model path
+
+         Returns:
+             str: model file path
+         """
+         return self.model_path
+
+     def get_input_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting input type information
+
+         Returns:
+             List[Dict[str, str]]: empty list
+         """
+         warnings.warn("get_input_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return []
+
+     def get_output_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting output type information
+
+         Returns:
+             List[Dict[str, str]]: empty list
+         """
+         warnings.warn("get_output_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return []