Upload inference code
- .gitattributes +1 -0
- export_vision_onnx.py +89 -0
- librkllmrt.so +3 -0
- rkllm-convert.py +23 -0
- rkllm_binding.py +873 -0
- rknn_quickconvert.py +58 -0
- run_rkllm.py +226 -0
- ztu_somemodelruntime_rknnlite2.py +569 -0
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 language_model_w8a8.rkllm filter=lfs diff=lfs merge=lfs -text
 vision_encoder.rknn filter=lfs diff=lfs merge=lfs -text
+librkllmrt.so filter=lfs diff=lfs merge=lfs -text
export_vision_onnx.py
ADDED
@@ -0,0 +1,89 @@
import numpy as np
import os
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer
import torch.nn.functional as F
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--step', type=int, help='export step', required=True)
parser.add_argument('--path', type=str, default='.', help='model path', required=False)
parser.add_argument('--batch', type=int, default=1, help='batch size', required=False)
parser.add_argument('--height', type=int, default=448, help='image height', required=False)
parser.add_argument('--width', type=int, default=448, help='image width', required=False)
parser.add_argument('--savepath', type=str, default='vision_encoder.onnx', help='save path', required=False)
args = parser.parse_args()

step = args.step
# Load the local model
path = args.path
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    path,
    torch_dtype=torch.float32,  # Note the dtype here: rknn currently only supports float32, so it must be set explicitly. If you constrained the dtype when loading the weights, you also need to set the "use_flash_attn" field in config.json to false yourself.
    low_cpu_mem_usage=True,
    trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

N = args.batch  # batch size
channel = 3  # 3 for RGB
H = args.height  # image height, must be divisible by (merge_size * patch_size)
W = args.width  # image width, must be divisible by (merge_size * patch_size)
merge_size = 2
temporal_patch_size = 2
patch_size = 14
grid_t = N // temporal_patch_size if N % temporal_patch_size == 0 else N // temporal_patch_size + 1
grid_h = H // patch_size
grid_w = W // patch_size

def export_onnx(image):
    if N == 1:
        images = image.repeat(temporal_patch_size, 1, 1, 1)
    elif N % temporal_patch_size != 0:
        repeat_time = temporal_patch_size - N % temporal_patch_size
        repeat_image = image[-1:, ...].repeat(repeat_time, 1, 1, 1)
        images = torch.cat((image, repeat_image), dim=0)
    else:
        images = image  # N is already a multiple of temporal_patch_size
    patches = images.reshape(grid_t, temporal_patch_size, channel, grid_h//merge_size, merge_size, patch_size, grid_w//merge_size, merge_size, patch_size)
    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size)
    model.visual.forward = forward_new(model.visual)
    if step == 1:
        feature = model.visual(flatten_patches, torch.tensor([grid_t, grid_h, grid_w]).unsqueeze(0))
    else:
        feature = model.visual(flatten_patches)
    return feature

def forward_new(self):
    def tmp(hidden_states, grid_thw=None):
        hidden_states = self.patch_embed(hidden_states)
        if grid_thw is not None:
            rotary_pos_emb = self.rot_pos_emb(grid_thw)
            cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
                dim=0, dtype=torch.int32
            )
            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
            np.save("./rotary_pos_emb.npy", rotary_pos_emb.cpu().detach().numpy())
            np.save("./cu_seqlens.npy", cu_seqlens.cpu().detach().numpy())
        else:
            rotary_pos_emb = torch.from_numpy(np.load("./rotary_pos_emb.npy")).to(dtype=hidden_states.dtype, device=hidden_states.device)
            cu_seqlens = torch.from_numpy(np.load("./cu_seqlens.npy")).to(dtype=torch.int32, device=hidden_states.device)

        for blk in self.blocks:
            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        return self.merger(hidden_states)
    return tmp

# Export the ONNX model for the vision part; e.g. an input of 2x3x392x392 -> (28x28)x(3x2x14x14)
# pixel_values = torch.randn(784, 1176, device="cuda", dtype=torch.float32)
pixel_values = torch.randn(N, channel, H, W, device="cpu", dtype=torch.float32)
model.forward = export_onnx
model = model.to(torch.float32).eval()
if step == 1:
    print("==========================================================")
    print("Generating the rotary_pos_emb and cu_seqlens done.")
    feature = model(pixel_values)
else:
    print("==========================================================")
    print(f"Exporting the vision part of {path} to onnx format.")
    # os.makedirs(os.path.dirname(args.savepath), exist_ok=True)
    torch.onnx.export(model, pixel_values, args.savepath, opset_version=17)
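Note that step 1 must be run before step 2: the traced forward in step 2 reloads the rotary_pos_emb.npy and cu_seqlens.npy files that step 1 writes, so they get baked into the ONNX graph as constants. A minimal sketch for sanity-checking the exported graph, assuming onnxruntime is installed and the default --savepath was used:

# Minimal sanity check for the exported encoder (assumes onnxruntime is installed).
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("vision_encoder.onnx")
inp = sess.get_inputs()[0]
dummy = np.random.randn(1, 3, 448, 448).astype(np.float32)  # default --batch/--height/--width
(feature,) = sess.run(None, {inp.name: dummy})
print(feature.shape)  # expected to be the merger output: roughly (grid_h * grid_w / merge_size**2, hidden)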
librkllmrt.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7e6f87f07bbb08058cad4871cc74e8069a054fe4f6259b43c29a4738b0affdd
size 7461896
rkllm-convert.py
ADDED
@@ -0,0 +1,23 @@
from rkllm.api import RKLLM

modelpath = '.'
llm = RKLLM()

ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
if ret != 0:
    print('Load model failed!')
    exit(ret)

qparams = None
ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)

if ret != 0:
    print('Build model failed!')
    exit(ret)

# Export rkllm model
ret = llm.export_rkllm("./language_model.rkllm")
if ret != 0:
    print('Export model failed!')
    exit(ret)
rkllm_binding.py
ADDED
@@ -0,0 +1,873 @@
import ctypes
import enum
import os

# Define constants from the header
CPU0 = (1 << 0)  # 0x01
CPU1 = (1 << 1)  # 0x02
CPU2 = (1 << 2)  # 0x04
CPU3 = (1 << 3)  # 0x08
CPU4 = (1 << 4)  # 0x10
CPU5 = (1 << 5)  # 0x20
CPU6 = (1 << 6)  # 0x40
CPU7 = (1 << 7)  # 0x80

# --- Enums ---
class LLMCallState(enum.IntEnum):
    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3

class RKLLMInputType(enum.IntEnum):
    RKLLM_INPUT_PROMPT = 0
    RKLLM_INPUT_TOKEN = 1
    RKLLM_INPUT_EMBED = 2
    RKLLM_INPUT_MULTIMODAL = 3

class RKLLMInferMode(enum.IntEnum):
    RKLLM_INFER_GENERATE = 0
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
    RKLLM_INFER_GET_LOGITS = 2

# --- Structures ---
class RKLLMExtendParam(ctypes.Structure):
    base_domain_id: ctypes.c_int32
    embed_flash: ctypes.c_int8
    enabled_cpus_num: ctypes.c_int8
    enabled_cpus_mask: ctypes.c_uint32
    n_batch: ctypes.c_uint8
    use_cross_attn: ctypes.c_int8
    reserved: ctypes.c_uint8 * 104

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),  # base domain ID
        ("embed_flash", ctypes.c_int8),  # whether to query word embeddings from flash (1 enabled, 0 disabled)
        ("enabled_cpus_num", ctypes.c_int8),  # number of CPUs enabled for inference
        ("enabled_cpus_mask", ctypes.c_uint32),  # bitmask indicating which CPUs are enabled
        ("n_batch", ctypes.c_uint8),  # number of input samples processed concurrently in one forward pass; set > 1 to enable batched inference, default 1
        ("use_cross_attn", ctypes.c_int8),  # whether to enable cross-attention (non-zero enabled, 0 disabled)
        ("reserved", ctypes.c_uint8 * 104)  # reserved field
    ]

class RKLLMParam(ctypes.Structure):
    model_path: ctypes.c_char_p
    max_context_len: ctypes.c_int32
    max_new_tokens: ctypes.c_int32
    top_k: ctypes.c_int32
    n_keep: ctypes.c_int32
    top_p: ctypes.c_float
    temperature: ctypes.c_float
    repeat_penalty: ctypes.c_float
    frequency_penalty: ctypes.c_float
    presence_penalty: ctypes.c_float
    mirostat: ctypes.c_int32
    mirostat_tau: ctypes.c_float
    mirostat_eta: ctypes.c_float
    skip_special_token: ctypes.c_bool
    is_async: ctypes.c_bool
    img_start: ctypes.c_char_p
    img_end: ctypes.c_char_p
    img_content: ctypes.c_char_p
    extend_param: RKLLMExtendParam

    _fields_ = [
        ("model_path", ctypes.c_char_p),  # path to the model file
        ("max_context_len", ctypes.c_int32),  # maximum number of tokens in the context window
        ("max_new_tokens", ctypes.c_int32),  # maximum number of new tokens to generate
        ("top_k", ctypes.c_int32),  # top-K sampling parameter
        ("n_keep", ctypes.c_int32),  # number of KV cache entries to keep when the context window shifts
        ("top_p", ctypes.c_float),  # top-P (nucleus) sampling parameter
        ("temperature", ctypes.c_float),  # sampling temperature, affects the randomness of token selection
        ("repeat_penalty", ctypes.c_float),  # penalty for repeated tokens
        ("frequency_penalty", ctypes.c_float),  # penalty for frequent tokens
        ("presence_penalty", ctypes.c_float),  # penalty for tokens already present in the input
        ("mirostat", ctypes.c_int32),  # Mirostat sampling strategy flag (0 means disabled)
        ("mirostat_tau", ctypes.c_float),  # Mirostat sampling Tau parameter
        ("mirostat_eta", ctypes.c_float),  # Mirostat sampling Eta parameter
        ("skip_special_token", ctypes.c_bool),  # whether to skip special tokens
        ("is_async", ctypes.c_bool),  # whether inference is asynchronous
        ("img_start", ctypes.c_char_p),  # start marker of the image in multimodal input
        ("img_end", ctypes.c_char_p),  # end marker of the image in multimodal input
        ("img_content", ctypes.c_char_p),  # image content pointer
        ("extend_param", RKLLMExtendParam)  # extended parameters
    ]

class RKLLMLoraAdapter(ctypes.Structure):
    lora_adapter_path: ctypes.c_char_p
    lora_adapter_name: ctypes.c_char_p
    scale: ctypes.c_float

    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]

class RKLLMEmbedInput(ctypes.Structure):
    embed: ctypes.POINTER(ctypes.c_float)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMTokenInput(ctypes.Structure):
    input_ids: ctypes.POINTER(ctypes.c_int32)
    n_tokens: ctypes.c_size_t

    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMMultiModelInput(ctypes.Structure):
    prompt: ctypes.c_char_p
    image_embed: ctypes.POINTER(ctypes.c_float)
    n_image_tokens: ctypes.c_size_t
    n_image: ctypes.c_size_t
    image_width: ctypes.c_size_t
    image_height: ctypes.c_size_t

    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t)
    ]

class RKLLMCrossAttnParam(ctypes.Structure):
    """
    Cross-attention parameter structure.

    Used when performing cross-attention in the decoder.
    It provides the encoder outputs (key/value caches), position indices and attention mask.

    - encoder_k_cache must be stored in contiguous memory with layout:
      [num_layers][num_tokens][num_kv_heads][head_dim]
    - encoder_v_cache must be stored in contiguous memory with layout:
      [num_layers][num_kv_heads][head_dim][num_tokens]
    """
    encoder_k_cache: ctypes.POINTER(ctypes.c_float)
    encoder_v_cache: ctypes.POINTER(ctypes.c_float)
    encoder_mask: ctypes.POINTER(ctypes.c_float)
    encoder_pos: ctypes.POINTER(ctypes.c_int32)
    num_tokens: ctypes.c_int

    _fields_ = [
        ("encoder_k_cache", ctypes.POINTER(ctypes.c_float)),  # pointer to the encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim)
        ("encoder_v_cache", ctypes.POINTER(ctypes.c_float)),  # pointer to the encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens)
        ("encoder_mask", ctypes.POINTER(ctypes.c_float)),  # pointer to the encoder attention mask (array of size num_tokens)
        ("encoder_pos", ctypes.POINTER(ctypes.c_int32)),  # pointer to the encoder token positions (array of size num_tokens)
        ("num_tokens", ctypes.c_int)  # number of tokens in the encoder sequence
    ]

class RKLLMPerfStat(ctypes.Structure):
    """
    Performance statistics structure.

    Holds performance statistics for the prefill and generation stages.
    """
    prefill_time_ms: ctypes.c_float
    prefill_tokens: ctypes.c_int
    generate_time_ms: ctypes.c_float
    generate_tokens: ctypes.c_int
    memory_usage_mb: ctypes.c_float

    _fields_ = [
        ("prefill_time_ms", ctypes.c_float),  # total time of the prefill stage (milliseconds)
        ("prefill_tokens", ctypes.c_int),  # number of tokens processed during prefill
        ("generate_time_ms", ctypes.c_float),  # total time of the generation stage (milliseconds)
        ("generate_tokens", ctypes.c_int),  # number of tokens processed during generation
        ("memory_usage_mb", ctypes.c_float)  # VmHWM resident memory usage during inference (MB)
    ]

class _RKLLMInputUnion(ctypes.Union):
    prompt_input: ctypes.c_char_p
    embed_input: RKLLMEmbedInput
    token_input: RKLLMTokenInput
    multimodal_input: RKLLMMultiModelInput

    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]

class RKLLMInput(ctypes.Structure):
    """
    LLM input structure.

    Represents the different kinds of LLM input through a union.
    """
    role: ctypes.c_char_p
    enable_thinking: ctypes.c_bool
    input_type: ctypes.c_int
    _union_data: _RKLLMInputUnion

    _fields_ = [
        ("role", ctypes.c_char_p),  # message role: "user" (user input), "tool" (function result)
        ("enable_thinking", ctypes.c_bool),  # controls whether "thinking mode" is enabled for Qwen3 models
        ("input_type", ctypes.c_int),  # enum value specifying the input type (prompt, token, embed, multimodal)
        ("_union_data", _RKLLMInputUnion)  # union payload
    ]
    # Properties to make accessing union members easier
    @property
    def prompt_input(self) -> bytes:  # Assuming c_char_p maps to bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            return self._union_data.prompt_input
        raise AttributeError("Not a prompt input")
    @prompt_input.setter
    def prompt_input(self, value: bytes):  # Assuming c_char_p maps to bytes
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            self._union_data.prompt_input = value
        else:
            raise AttributeError("Not a prompt input")
    @property
    def embed_input(self) -> RKLLMEmbedInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            return self._union_data.embed_input
        raise AttributeError("Not an embed input")
    @embed_input.setter
    def embed_input(self, value: RKLLMEmbedInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
            self._union_data.embed_input = value
        else:
            raise AttributeError("Not an embed input")

    @property
    def token_input(self) -> RKLLMTokenInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            return self._union_data.token_input
        raise AttributeError("Not a token input")
    @token_input.setter
    def token_input(self, value: RKLLMTokenInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
            self._union_data.token_input = value
        else:
            raise AttributeError("Not a token input")

    @property
    def multimodal_input(self) -> RKLLMMultiModelInput:
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            return self._union_data.multimodal_input
        raise AttributeError("Not a multimodal input")
    @multimodal_input.setter
    def multimodal_input(self, value: RKLLMMultiModelInput):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            self._union_data.multimodal_input = value
        else:
            raise AttributeError("Not a multimodal input")
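The prompt/token/embed/multimodal payloads share one union, so input_type must be tagged before a union member is assigned (the setters above enforce this). A minimal sketch of filling a token input; the token IDs below are purely illustrative:

# Minimal sketch: build a token input (IDs are illustrative only).
ids = [9707, 11, 1879, 0]
arr = (ctypes.c_int32 * len(ids))(*ids)  # keep `arr` alive while the input is in use
tok = RKLLMInput()
tok.role = b"user"
tok.input_type = RKLLMInputType.RKLLM_INPUT_TOKEN  # set the tag first
tok.token_input = RKLLMTokenInput(
    input_ids=ctypes.cast(arr, ctypes.POINTER(ctypes.c_int32)),
    n_tokens=len(ids))  # the setter checks that the tag matches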
class RKLLMLoraParam(ctypes.Structure):  # For inference
    lora_adapter_name: ctypes.c_char_p

    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]

class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
    save_prompt_cache: ctypes.c_int  # bool-like
    prompt_cache_path: ctypes.c_char_p

    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),  # bool-like
        ("prompt_cache_path", ctypes.c_char_p)
    ]

class RKLLMInferParam(ctypes.Structure):
    mode: ctypes.c_int
    lora_params: ctypes.POINTER(RKLLMLoraParam)
    prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam)
    keep_history: ctypes.c_int  # bool-like

    _fields_ = [
        ("mode", ctypes.c_int),  # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int)  # bool-like
    ]

class RKLLMResultLastHiddenLayer(ctypes.Structure):
    hidden_states: ctypes.POINTER(ctypes.c_float)
    embd_size: ctypes.c_int
    num_tokens: ctypes.c_int

    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResultLogits(ctypes.Structure):
    logits: ctypes.POINTER(ctypes.c_float)
    vocab_size: ctypes.c_int
    num_tokens: ctypes.c_int

    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResult(ctypes.Structure):
    """
    LLM inference result structure.

    Represents the result of an LLM inference call, containing the generated text,
    token ID, hidden states, logits and performance statistics.
    """
    text: ctypes.c_char_p
    token_id: ctypes.c_int32
    last_hidden_layer: RKLLMResultLastHiddenLayer
    logits: RKLLMResultLogits
    perf: RKLLMPerfStat

    _fields_ = [
        ("text", ctypes.c_char_p),  # generated text
        ("token_id", ctypes.c_int32),  # generated token ID
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),  # hidden states of the last layer (if requested)
        ("logits", RKLLMResultLogits),  # logits output by the model
        ("perf", RKLLMPerfStat)  # performance statistics (prefill and generation)
    ]

# --- Typedefs ---
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
LLMResultCallback = ctypes.CFUNCTYPE(
    ctypes.c_int,  # return type: int, indicating how to proceed
    ctypes.POINTER(RKLLMResult),  # pointer to the LLM result
    ctypes.c_void_p,  # user data pointer
    ctypes.c_int  # LLM call state (LLMCallState enum value)
)
"""
Callback function type definition.

Callback used to handle LLM results.

Parameters:
- result: pointer to the LLM result
- userdata: user data pointer for the callback
- state: LLM call state (e.g. finished, error)

Return value:
- 0: continue inference normally
- 1: pause inference. If the user wants to modify or intervene in the result
  (e.g. edit the output, inject a new prompt), return 1 to pause the current
  inference. Later, call rkllm_run with the updated content to resume.
"""
class RKLLMRuntime:
    def __init__(self, library_path="./librkllmrt.so"):
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
        self._setup_functions()
        self.llm_handle = LLMHandle()
        self._c_callback = None  # To keep the callback object alive

    def _setup_functions(self):
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        # Assuming async also takes userdata for the callback context
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle);
        self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [
            LLMHandle,
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_int),  # start_pos
            ctypes.POINTER(ctypes.c_int)  # end_pos
        ]

        # int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes);
        self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int
        self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p
        ]

        # int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str);
        self.lib.rkllm_set_function_tools.restype = ctypes.c_int
        self.lib.rkllm_set_function_tools.argtypes = [
            LLMHandle,
            ctypes.c_char_p,  # system_prompt
            ctypes.c_char_p,  # tools
            ctypes.c_char_p  # tool_response_str
        ]

        # int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params);
        self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int
        self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)]

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.
        :param param: RKLLMParam structure.
        :param callback_func: A Python function that matches the signature:
                              def my_callback(result_ptr, userdata_ptr, state_enum):
                                  result = result_ptr.contents  # RKLLMResult
                                  # Process result
                                  # userdata can be retrieved if passed during run, or ignored
                                  # state = LLMCallState(state_enum)
        :return: 0 for success, non-zero for failure.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        # Keep a reference to the ctypes callback object to prevent it from being garbage collected
        self._c_callback = LLMResultCallback(callback_func)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a Lora adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources."""
        if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
            ret = self.lib.rkllm_destroy(self.llm_handle)
            self.llm_handle = LLMHandle()  # Reset handle
            if ret != 0:
                # Don't raise here as it might be called in __del__
                print(f"Warning: rkllm_destroy failed with error code {ret}")
            return ret
        return 0  # Already destroyed or not initialized

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously."""
        # userdata can be a ctypes.py_object if you want to pass Python objects,
        # then cast to c_void_p. Or simply None.
        if userdata is not None:
            # Store the userdata object to keep it alive during the call
            self._userdata_ref = userdata
            c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
        else:
            c_userdata = None
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously."""
        if userdata is not None:
            # Store the userdata object to keep it alive during the call
            self._userdata_ref = userdata
            c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p)
        else:
            c_userdata = None
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The C API returns 0 if running, non-zero otherwise.
        # This is a bit counter-intuitive for a boolean "is_running".
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int:
        """
        Clears the key-value cache.

        This function clears part or all of the KV cache.

        Parameters:
        - keep_system_prompt: whether to keep the system prompt in the cache (True keeps it, False clears it).
          If a specific range [start_pos, end_pos) is given, this flag is ignored.
        - start_pos: array of start positions (inclusive) of the KV cache ranges to clear, one per batch.
        - end_pos: array of end positions (exclusive) of the KV cache ranges to clear, one per batch.
          If both start_pos and end_pos are None, the whole cache is cleared and keep_system_prompt takes effect.
          If start_pos[i] < end_pos[i], only the given range is cleared and keep_system_prompt is ignored.

        Note: start_pos and end_pos are only valid when keep_history == 0 and generation has been
        paused by returning 1 from the callback.

        Returns: 0 if the cache was cleared successfully, non-zero on failure.
        """
        # Prepare the C array arguments
        c_start_pos = None
        c_end_pos = None

        if start_pos is not None and end_pos is not None:
            if len(start_pos) != len(end_pos):
                raise ValueError("start_pos and end_pos arrays must have the same length")

            # Create the C arrays
            c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos)
            c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos)

        ret = self.lib.rkllm_clear_kv_cache(
            self.llm_handle,
            ctypes.c_int(1 if keep_system_prompt else 0),
            c_start_pos,
            c_end_pos
        )
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM."""
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b""
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b""

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def get_kv_cache_size(self, n_batch: int) -> list:
        """
        Gets the current key-value cache size for the given LLM handle.

        This function returns the total number of positions currently stored in the
        model's KV cache.

        Parameters:
        - n_batch: number of batches, which determines the size of the returned array.

        Returns:
        - list: cache sizes, one per batch.
        """
        # Pre-allocate an array to hold the cache size of each batch
        cache_sizes = (ctypes.c_int * n_batch)()

        ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes)
        if ret != 0:
            raise RuntimeError(f"rkllm_get_kv_cache_size failed with error code {ret}")

        # Convert to a Python list
        return [cache_sizes[i] for i in range(n_batch)]

    def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int:
        """
        Sets the function-calling configuration for the LLM, including the system prompt,
        tool definitions and the tool response token.

        Parameters:
        - system_prompt: system prompt defining the language model's context or behavior.
        - tools: JSON-formatted string defining the available functions, including their
          names, descriptions and parameters.
        - tool_response_str: unique tag used to identify function-call results in the
          conversation. It acts as a marker tag that lets the tokenizer recognize tool
          output separately from normal conversation turns.

        Returns: 0 if the configuration was set successfully, non-zero on error.
        """
        c_system = system_prompt.encode('utf-8') if system_prompt else b""
        c_tools = tools.encode('utf-8') if tools else b""
        c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b""

        ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_function_tools failed with error code {ret}")
        return ret

    def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int:
        """
        Sets the cross-attention parameters for the LLM decoder.

        Parameters:
        - cross_attn_params: structure containing the encoder-side input data used for
          cross-attention (see RKLLMCrossAttnParam for details).

        Returns: 0 if the parameters were set successfully, non-zero on error.
        """
        ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params))
        if ret != 0:
            raise RuntimeError(f"rkllm_set_cross_attn_params failed with error code {ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        self.destroy()  # Ensure resources are freed if object is garbage collected
# --- Example Usage (Illustrative) ---
if __name__ == "__main__":
    # This is a placeholder for how you might use it.
    # You'll need a valid .rkllm model and librkllmrt.so in your path.

    # Global list to store results from callback for demonstration
    results_buffer = []

    def my_python_callback(result_ptr, userdata_ptr, state_enum):
        """
        Callback invoked by the C library to handle LLM results.

        Parameters:
        - result_ptr: pointer to the LLM result
        - userdata_ptr: user data pointer
        - state_enum: LLM call state enum value

        Returns:
        - 0: continue inference
        - 1: pause inference
        """
        global results_buffer
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text:  # Check that the char* is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')

        print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")

        # Show performance statistics
        if result.perf.prefill_tokens > 0 or result.perf.generate_tokens > 0:
            print(f"  Perf: prefill={result.perf.prefill_tokens}tokens/{result.perf.prefill_time_ms:.1f}ms, "
                  f"generate={result.perf.generate_tokens}tokens/{result.perf.generate_time_ms:.1f}ms, "
                  f"memory={result.perf.memory_usage_mb:.1f}MB")

        results_buffer.append(current_text)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            print("Inference finished.")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("Inference error.")

        # Return 0 to continue inference, 1 to pause it
        return 0

    # --- Attempt to use the wrapper ---
    try:
        print("Initializing RKLLMRuntime...")
        # Adjust library_path if librkllmrt.so is not in default search paths
        # e.g., library_path="./path/to/librkllmrt.so"
        rk_llm = RKLLMRuntime()

        print("Creating default parameters...")
        params = rk_llm.create_default_param()

        # --- Configure parameters ---
        # THIS IS CRITICAL: model_path must point to an actual .rkllm file
        # For this example to run, you need a model file.
        # Let's assume a dummy path for now, this will fail at init if not valid.
        model_file = "dummy_model.rkllm"
        if not os.path.exists(model_file):
            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
            # Create a dummy file for the example to proceed further, though init will still fail
            # with a real library unless it's a valid model.
            with open(model_file, "w") as f:
                f.write("dummy content")

        params.model_path = model_file.encode('utf-8')
        params.max_context_len = 512
        params.max_new_tokens = 128
        params.top_k = 1  # Greedy
        params.temperature = 0.7
        params.repeat_penalty = 1.1
        # ... set other params as needed

        print(f"Initializing LLM with model: {params.model_path.decode()}...")
        # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
        try:
            rk_llm.init(params, my_python_callback)
            print("LLM Initialized.")
        except RuntimeError as e:
            print(f"Error during LLM initialization: {e}")
            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
            exit()


        # --- Prepare input ---
        print("Preparing input...")
        rk_input = RKLLMInput()
        rk_input.role = b"user"  # Set the role to user input
        rk_input.enable_thinking = False  # Disable thinking mode (applies to Qwen3 models)
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

        prompt_text = "Translate the following English text into Chinese: 'Hello, world!'"
        c_prompt = prompt_text.encode('utf-8')
        rk_input._union_data.prompt_input = c_prompt  # Access the union member directly

        # --- Prepare inference parameters ---
        print("Preparing inference parameters...")
        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        infer_params.keep_history = 1  # True
        # infer_params.lora_params = None # or set up RKLLMLoraParam if using LoRA
        # infer_params.prompt_cache_params = None # or set up RKLLMPromptCacheParam

        # --- Run inference ---
        print(f"Running inference with prompt: '{prompt_text}'")
        results_buffer.clear()
        try:
            rk_llm.run(rk_input, infer_params)  # Userdata is None by default
            print("\n--- Full Response ---")
            print("".join(results_buffer))
            print("---------------------\n")
        except RuntimeError as e:
            print(f"Error during LLM run: {e}")


        # --- Example: Set chat template (if model supports it) ---
        # print("Setting chat template...")
        # try:
        #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
        #     print("Chat template set.")
        # except RuntimeError as e:
        #     print(f"Error setting chat template: {e}")

        # --- Example: Clear KV Cache ---
        # print("Clearing KV cache (keeping system prompt if any)...")
        # try:
        #     rk_llm.clear_kv_cache(keep_system_prompt=True)
        #     print("KV cache cleared.")
        # except RuntimeError as e:
        #     print(f"Error clearing KV cache: {e}")

        # --- Example: get KV cache size ---
        # print("Getting KV cache size...")
        # try:
        #     cache_sizes = rk_llm.get_kv_cache_size(n_batch=1)  # assume batch size 1
        #     print(f"Current KV cache size: {cache_sizes}")
        # except RuntimeError as e:
        #     print(f"Error getting KV cache size: {e}")

        # --- Example: set function tools ---
        # print("Setting function-calling tools...")
        # try:
        #     system_prompt = "You are a helpful assistant that can call the provided functions to help the user."
        #     tools = '''[{
        #         "name": "get_weather",
        #         "description": "Get the weather information for a given city",
        #         "parameters": {
        #             "type": "object",
        #             "properties": {
        #                 "city": {"type": "string", "description": "City name"}
        #             },
        #             "required": ["city"]
        #         }
        #     }]'''
        #     tool_response_str = "<tool_response>"
        #     rk_llm.set_function_tools(system_prompt, tools, tool_response_str)
        #     print("Function tools set successfully.")
        # except RuntimeError as e:
        #     print(f"Error setting function tools: {e}")

        # --- Example: clear KV cache (with range arguments) ---
        # print("Clearing KV cache with range arguments...")
        # try:
        #     # Clear cache positions 10 to 20
        #     start_positions = [10]  # start position for batch 0
        #     end_positions = [20]  # end position for batch 0
        #     rk_llm.clear_kv_cache(keep_system_prompt=True, start_pos=start_positions, end_pos=end_positions)
        #     print("Ranged KV cache clear complete.")
        # except RuntimeError as e:
        #     print(f"Error clearing ranged KV cache: {e}")

    except OSError as e:
        print(f"OSError: {e}. Could not load the RKLLM library.")
        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
            print("Destroying LLM instance...")
            rk_llm.destroy()
            print("LLM instance destroyed.")
        if os.path.exists(model_file) and model_file == "dummy_model.rkllm":
            os.remove(model_file)  # Clean up dummy file

    print("Example finished.")
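Since RKLLMRuntime defines __enter__/__exit__, it can also be used as a context manager so destroy() runs automatically. A minimal sketch reusing the callback from the example above; the model path points at the .rkllm file tracked in this repo:

# Minimal sketch: context-manager usage; destroy() is called on exit.
with RKLLMRuntime() as rt:
    p = rt.create_default_param()
    p.model_path = b"./language_model_w8a8.rkllm"
    rt.init(p, my_python_callback)
    # ... build an RKLLMInput / RKLLMInferParam and call rt.run(...) as shown above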
rknn_quickconvert.py
ADDED
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# coding: utf-8

import os
import sys
from rknn.api import RKNN
import argparse

def quick_convert(onnx_model_path):
    rknn = RKNN(verbose=True)
    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, model_pruning=True)
    ret = rknn.load_onnx(model=onnx_model_path)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    rknn.build(do_quantization=False, dataset=None, rknn_batch_size=None)
    rknn_model_path = onnx_model_path.replace(".onnx", ".rknn")
    ret = rknn.export_rknn(rknn_model_path)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print(f"RKNN model exported to {rknn_model_path}")

def compile_only(onnx_model_path):
    from rknn.api.rknn_compiler import RKNNCompiler, RKNNConfig, RKNNNormalize
    from rknn.api.rknn_platform import support_soc_npu_target
    from pprint import pprint

    pprint(support_soc_npu_target)

    RKNNCompiler.build(
        onnx_model_path,
        RKNNConfig(
            target="v2",
            request_type="float16",
            optimize_options="compress=0, conv_eltwise_activation_fuse=1, global_fuse=1, multi-core-model-mode=1, output_optimize=1, layout_match=1, enable_argb_group=0, pipeline_fuse=1, enable_flash_attention=0",
            verbose_level=4,
        ),
        RKNNNormalize(
            channel_means=[[0, 0, 0]],
            channel_stds=[[1, 1, 1]],
            channel_orders=[[0, 1, 2]],  # not tested !!!
        ),
        "out.rknn",
    )

# "Usage: python rknn_quickconvert.py <onnx_model_path> [-c/--compile-only]"
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("onnx_model_path", type=str, help="Path to the ONNX model file")
    parser.add_argument("-c", "--compile-only", action="store_true", help="Compile the model only")
    args = parser.parse_args()

    onnx_model_path = args.onnx_model_path
    if args.compile_only:
        compile_only(onnx_model_path)
    else:
        quick_convert(onnx_model_path)
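The resulting .rknn file can be consumed through the bundled ORT-style wrapper, mirroring how run_rkllm.py loads the encoder. A minimal sketch, assuming the 448x448 export and the API surface run_rkllm.py already exercises:

# Minimal sketch: load the converted encoder with the bundled ORT-style wrapper.
import numpy as np
import ztu_somemodelruntime_rknnlite2 as ort

sess = ort.InferenceSession("vision_encoder.rknn")
in_name = sess.get_inputs()[0].name
out_name = sess.get_outputs()[0].name
dummy = np.zeros((1, 3, 448, 448), dtype=np.float32)
feature = sess.run([out_name], {in_name: dummy})[0]
print(feature.shape)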
run_rkllm.py
ADDED
@@ -0,0 +1,226 @@
import faulthandler
faulthandler.enable()
import sys
import os
os.environ["RKLLM_LOG_LEVEL"] = "1"
import ctypes
import argparse
import cv2
import numpy as np
import ztu_somemodelruntime_rknnlite2 as ort
from rkllm_binding import (
    RKLLMRuntime,
    RKLLMParam,
    RKLLMInput,
    RKLLMInferParam,
    LLMCallState,
    RKLLMInputType,
    RKLLMInferMode,
    RKLLMResult
)

# Constants
IMAGE_HEIGHT = 448
IMAGE_WIDTH = 448

def expand2square(img, background_color):
    """
    Expand the image into a square and fill it with the specified background color.
    """
    height, width, _ = img.shape
    if width == height:
        return img.copy()

    size = max(width, height)
    square_img = np.full((size, size, 3), background_color, dtype=np.uint8)

    x_offset = (size - width) // 2
    y_offset = (size - height) // 2

    square_img[y_offset:y_offset+height, x_offset:x_offset+width] = img
    return square_img

def llm_callback(result_ptr, userdata_ptr, state_enum):
    """
    Callback function to handle LLM results.
    """
    state = LLMCallState(state_enum)
    result = result_ptr.contents

    if state == LLMCallState.RKLLM_RUN_NORMAL:
        if result.text:
            print(result.text.decode('utf-8', errors='ignore'), end='', flush=True)
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n", flush=True)
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nrun error", flush=True)

    return 0

def main():
    parser = argparse.ArgumentParser(
        description="Run RKLLM visual language model inference based on the C++ example."
    )
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    parser.add_argument("encoder_model_path", type=str, help="Path to the ONNX vision encoder model.")
    parser.add_argument("llm_model_path", type=str, help="Path to the .rkllm language model.")
    parser.add_argument("max_new_tokens", type=int, help="Maximum number of new tokens to generate.")
    parser.add_argument("max_context_len", type=int, help="Maximum context length.")
    # The rknn_core_num is not directly used by onnxruntime in the same way,
    # but we keep it for API consistency with the C++ example.
    # ONNX Runtime will manage its own threading and execution providers.
    parser.add_argument("rknn_core_num", type=int, help="Core number for RKNN (informational for this script).")

    args = parser.parse_args()

    # --- 1. Initialize Image Encoder (ONNX Runtime) ---
    print("Initializing ONNX Runtime for vision encoder...")
    try:
        ort_session = ort.InferenceSession(args.encoder_model_path)
    except Exception as e:
        print(f"Failed to load ONNX model: {e}")
        sys.exit(1)
    print("Vision encoder loaded successfully.")

    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    print(f"ONNX Input: {input_name}, ONNX Output: {output_name}")

    # --- 2. Initialize LLM ---
    print("Initializing RKLLM Runtime...")
    rk_llm = RKLLMRuntime()
    param = rk_llm.create_default_param()

    param.model_path = args.llm_model_path.encode('utf-8')
    param.top_k = 1
    param.max_new_tokens = args.max_new_tokens
    param.max_context_len = args.max_context_len
    param.skip_special_token = True
    param.img_start = b"<|vision_start|>"
    param.img_end = b"<|vision_end|>"
    param.img_content = b"<|image_pad|>"
    param.extend_param.base_domain_id = 1

    try:
        rk_llm.init(param, llm_callback)
        print("RKLLM initialized successfully.")
    except RuntimeError as e:
        print(f"RKLLM init failed: {e}")
        sys.exit(1)

    # --- 3. Image Preprocessing ---
    print("Preprocessing image...")
    img = cv2.imread(args.image_path)
    if img is None:
        print(f"Failed to read image from {args.image_path}")
        sys.exit(1)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    background_color = (127.5, 127.5, 127.5)  # As per C++ example
    square_img = expand2square(img, background_color)
    resized_img = cv2.resize(square_img, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)

    # Normalize and prepare for ONNX model
    input_tensor = resized_img.astype(np.float32)
    # Normalize using preprocessor config values
    input_tensor = (input_tensor / 255.0 - np.array([0.48145466, 0.4578275, 0.40821073])) / np.array([0.26862954, 0.26130258, 0.27577711])
    # Convert to NCHW format
    input_tensor = np.transpose(input_tensor, (2, 0, 1))  # HWC -> CHW
    input_tensor = np.expand_dims(input_tensor, axis=0)  # Add batch dimension -> (1, 3, 448, 448)

    # --- 4. Run Image Encoder ---
    print("Running vision encoder...")
    try:
        img_vec_output = ort_session.run([output_name], {input_name: input_tensor.astype(np.float32)})[0]
        # The output from C++ is a flat float array. Let's flatten the ONNX output.
        img_vec = img_vec_output.flatten().astype(np.float32)

    except Exception as e:
        print(f"Failed to run vision encoder inference: {e}")
        rk_llm.destroy()
        sys.exit(1)

    print("Image encoded successfully.")

    # --- 5. Interactive Chat Loop ---
    rkllm_infer_params = RKLLMInferParam()
    rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    rkllm_infer_params.keep_history = 0

    # Set chat template
    rk_llm.set_chat_template(
        system_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
        prompt_prefix="<|im_start|>user\n",
        prompt_postfix="<|im_end|>\n<|im_start|>assistant\n"
    )

    pre_input = [
        "Picture 1: <image> What is in the image?",
        "Picture 1: <image> 这张图片中有什么?"  # the same question, in Chinese
    ]
    print("\n********************** Enter the number of a preset question below, or type your own prompt **********************\n")
    for i, p in enumerate(pre_input):
        print(f"[{i}] {p}")
    print("\n*************************************************************************\n")

    try:
        while True:
            print("\nuser: ", end="", flush=True)
            input_str = sys.stdin.readline().strip()

            if not input_str:
                continue
            if input_str == "exit":
                break
            if input_str == "clear":
                try:
                    rk_llm.clear_kv_cache(keep_system_prompt=True)
                    print("KV cache cleared.")
                except RuntimeError as e:
                    print(f"Failed to clear KV cache: {e}")
                continue

            try:
                idx = int(input_str)
                if 0 <= idx < len(pre_input):
                    input_str = pre_input[idx]
                    print(input_str)
            except (ValueError, IndexError):
                pass  # Use the raw string if not a valid index

            rkllm_input = RKLLMInput()
            rkllm_input.role = b"user"

            print("robot: ", end="", flush=True)

            if "<image>" in input_str:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL

                # Setup multimodal input
                rkllm_input.multimodal_input.prompt = input_str.encode('utf-8')
                rkllm_input.multimodal_input.image_embed = img_vec.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
                rkllm_input.multimodal_input.n_image_tokens = img_vec_output.shape[0]
                print("n_image_tokens: ", rkllm_input.multimodal_input.n_image_tokens)
                rkllm_input.multimodal_input.n_image = 1
                rkllm_input.multimodal_input.image_height = IMAGE_HEIGHT
                rkllm_input.multimodal_input.image_width = IMAGE_WIDTH
            else:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
                rkllm_input.prompt_input = input_str.encode('utf-8')

            try:
                rk_llm.run(rkllm_input, rkllm_infer_params)
            except RuntimeError as e:
                print(f"\nError during rkllm_run: {e}")

    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing resources...")
        rk_llm.destroy()
        print("RKLLM instance destroyed.")

if __name__ == "__main__":
    main()
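A typical invocation of the script above, as a sketch (the image and model file names are placeholders for your own files, and the trailing core-count argument is informational only, as noted in the argparse help):

    python run_rkllm.py demo.jpg vision_encoder.rknn language_model.rkllm 512 1024 3

In the chat loop, entering a listed index reuses that preset prompt, "clear" resets the KV cache while keeping the system prompt, and "exit" quits.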
ztu_somemodelruntime_rknnlite2.py
ADDED
@@ -0,0 +1,569 @@
# Module-level constants and functions
from rknnlite.api import RKNNLite
import numpy as np
import os
import warnings
import logging
from typing import List, Dict, Union, Optional

try:
    import onnxruntime as ort
    HAS_ORT = True
except ImportError:
    HAS_ORT = False
    warnings.warn("onnxruntime is not installed; only the RKNN backend is available", ImportWarning)

# Configure logging
logger = logging.getLogger("somemodelruntime_rknnlite2")
logger.setLevel(logging.ERROR)  # only log errors by default
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)

# Mapping from ONNX Runtime log levels to Python logging levels
_LOGGING_LEVEL_MAP = {
    0: logging.DEBUG,    # Verbose
    1: logging.INFO,     # Info
    2: logging.WARNING,  # Warning
    3: logging.ERROR,    # Error
    4: logging.CRITICAL  # Fatal
}

# Check the environment variable for a log level override
try:
    env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
    if env_log_level is not None:
        log_level = int(env_log_level)
        if log_level in _LOGGING_LEVEL_MAP:
            logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
            logger.info(f"Log level set from environment variable: {log_level}")
        else:
            logger.warning(f"Invalid value for ZTU_MODELRT_RKNNL2_LOG_LEVEL: {log_level}, expected an integer between 0 and 4")
except ValueError:
    logger.warning(f"Invalid value for ZTU_MODELRT_RKNNL2_LOG_LEVEL: {env_log_level}, expected an integer between 0 and 4")


def set_default_logger_severity(level: int) -> None:
    """
    Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal

    Args:
        level: log level (0-4)
    """
    if level not in _LOGGING_LEVEL_MAP:
        raise ValueError(f"Invalid log level: {level}, expected an integer between 0 and 4")
    logger.setLevel(_LOGGING_LEVEL_MAP[level])

def set_default_logger_verbosity(level: int) -> None:
    """
    Sets the default logging verbosity level. To activate the verbose log,
    you need to set the default logging severity to 0:Verbose level.

    Args:
        level: log level (0-4)
    """
    set_default_logger_severity(level)

# Mapping from RKNN tensor types to numpy dtypes
RKNN_DTYPE_MAP = {
    0: np.float32,   # RKNN_TENSOR_FLOAT32
    1: np.float16,   # RKNN_TENSOR_FLOAT16
    2: np.int8,      # RKNN_TENSOR_INT8
    3: np.uint8,     # RKNN_TENSOR_UINT8
    4: np.int16,     # RKNN_TENSOR_INT16
    5: np.uint16,    # RKNN_TENSOR_UINT16
    6: np.int32,     # RKNN_TENSOR_INT32
    7: np.uint32,    # RKNN_TENSOR_UINT32
    8: np.int64,     # RKNN_TENSOR_INT64
    9: bool,         # RKNN_TENSOR_BOOL
    10: np.int8,     # RKNN_TENSOR_INT4 (represented as int8)
}

def get_available_providers() -> List[str]:
    """
    Get the list of available execution providers (placeholder kept for interface compatibility).

    Returns:
        list: always ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
    """
    return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]


def get_device() -> str:
    """
    Get the current device.

    Returns:
        str: the current device
    """
    return "RKNN2"

def get_version_info() -> Dict[str, str]:
    """
    Get version information.

    Returns:
        dict: API and driver version information
    """
    runtime = RKNNLite()
    version = runtime.get_sdk_version()
    return {
        "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
        "driver_version": version.split('\n')[3].split(': ')[1]
    }

class IOTensor:
    """Wrapper for input/output tensor information"""
    def __init__(self, name, shape, type=None):
        self.name = name.decode() if isinstance(name, bytes) else name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"

class SessionOptions:
    """Session options"""
    def __init__(self):
        self.enable_profiling = False    # whether to enable performance profiling
        self.intra_op_num_threads = 1    # number of NPU cores to use, mapped to RKNN's core_mask
        self.log_severity_level = -1     # alternative way to set the log level
        self.log_verbosity_level = -1    # alternative way to set the log level


class InferenceSession:
    """
    RKNNLite runtime wrapper with an API styled after ONNX Runtime
    """

    def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        processed_path = InferenceSession._process_model_path(model_path, sess_options)
        if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
            logger.info("Loading the model with ONNX Runtime")
            if not HAS_ORT:
                raise RuntimeError("onnxruntime is not installed; cannot load ONNX models")
            return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
        else:
            # Not an ONNX model: create an InferenceSession instance via the parent __new__
            instance = super().__new__(cls)
            # Keep the processed path for __init__
            instance._processed_path = processed_path
            return instance

    def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        """
        Initialize the runtime and load the model.

        Args:
            model_path: path to the model file (.rknn or .onnx)
            sess_options: session options
            **kwargs: additional initialization arguments
        """
        options = sess_options or SessionOptions()

        # Only apply the log level from SessionOptions when the environment variable is unset
        if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
            if options.log_severity_level != -1:
                set_default_logger_severity(options.log_severity_level)
            if options.log_verbosity_level != -1:
                set_default_logger_verbosity(options.log_verbosity_level)

        # Use the path processed in __new__
        model_path = getattr(self, '_processed_path', model_path)
        if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
            # Avoid loading the ONNX model twice
            return

        # RKNN model loading and initialization
        self.model_path = model_path
        if not os.path.exists(self.model_path):
            logger.error(f"Model file does not exist: {self.model_path}")
            raise FileNotFoundError(f"Model file does not exist: {self.model_path}")

        self.runtime = RKNNLite(verbose=options.enable_profiling)

        logger.debug(f"Loading model: {self.model_path}")
        ret = self.runtime.load_rknn(self.model_path)
        if ret != 0:
            logger.error(f"Failed to load RKNN model: {self.model_path}")
            raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
        logger.debug("Model loaded successfully")


        if options.intra_op_num_threads == 1:
            core_mask = RKNNLite.NPU_CORE_AUTO
        elif options.intra_op_num_threads == 2:
            core_mask = RKNNLite.NPU_CORE_0_1
        elif options.intra_op_num_threads == 3:
            core_mask = RKNNLite.NPU_CORE_0_1_2
        else:
            raise ValueError(f"Invalid value for intra_op_num_threads: {options.intra_op_num_threads}, must be 1, 2 or 3")

        logger.debug("Initializing the runtime environment")
        ret = self.runtime.init_runtime(core_mask=core_mask)
        if ret != 0:
            logger.error("Failed to initialize the runtime environment")
            raise RuntimeError('Failed to initialize the runtime environment')
        logger.debug("Runtime environment initialized successfully")

        self._init_io_info()
        self.options = options

    def get_performance_info(self) -> Dict[str, float]:
        """
        Get performance information.

        Returns:
            dict: performance information
        """
        if not self.options.enable_profiling:
            raise RuntimeError("Profiling is not enabled; set enable_profiling=True in SessionOptions")

        perf = self.runtime.rknn_runtime.get_run_perf()
        return {
            "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
        }

    def set_core_mask(self, core_mask: int) -> None:
        """
        Set the NPU core usage mode.

        Args:
            core_mask: NPU core mask, use the NPU_CORE_* constants
        """
        ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
        if ret != 0:
            raise RuntimeError("Failed to set the NPU core mode")

    @staticmethod
    def _process_model_path(model_path, sess_options):
        """
        Process the model path; both .onnx and .rknn files are supported.

        Args:
            model_path: path to the model file
        """
        # For ONNX files, check whether a matching RKNN model should be loaded instead
        if model_path.lower().endswith('.onnx'):
            logger.info("Detected an ONNX model file")

            # Get the list of models that should skip the automatic RKNN substitution
            skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
            if skip_models:
                skip_list = [m.strip() for m in skip_models.split(',')]
                # Match on the file name only (without the directory)
                model_name = os.path.basename(model_path)
                if model_name.lower() in [m.lower() for m in skip_list]:
                    logger.info(f"Model {model_name} is in the skip list; ONNX Runtime will be used")
                    return model_path

            # Build the RKNN file path
            rknn_path = os.path.splitext(model_path)[0] + '.rknn'
            if os.path.exists(rknn_path):
                logger.info(f"Found a matching RKNN model, using it instead: {rknn_path}")
                return rknn_path
            else:
                logger.info("No matching RKNN model found; ONNX Runtime will be used")
                return model_path

        return model_path

    def _convert_nhwc_to_nchw(self, shape):
        """Convert an NHWC shape to NCHW"""
        if len(shape) == 4:
            # NHWC -> NCHW
            n, h, w, c = shape
            return [n, c, h, w]
        return shape

    def _init_io_info(self):
        """Initialize the model's input/output information"""
        runtime = self.runtime.rknn_runtime

        # Get the number of inputs and outputs
        n_input, n_output = runtime.get_in_out_num()

        # Collect input information
        self.input_tensors = []
        for i in range(n_input):
            attr = runtime.get_tensor_attr(i)
            shape = [attr.dims[j] for j in range(attr.n_dims)]
            # Convert 4-D inputs from NHWC to NCHW
            shape = self._convert_nhwc_to_nchw(shape)
            # Look up the dtype
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.input_tensors.append(tensor)

        # Collect output information
        self.output_tensors = []
        for i in range(n_output):
            attr = runtime.get_tensor_attr(i, is_output=True)
            shape = runtime.get_output_shape(i)
            # Look up the dtype
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.output_tensors.append(tensor)

    def get_inputs(self):
        """
        Get the model's input information.

        Returns:
            list: input tensor descriptions
        """
        return self.input_tensors

    def get_outputs(self):
        """
        Get the model's output information.

        Returns:
            list: output tensor descriptions
        """
        return self.output_tensors

    def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
        """
        Run model inference.

        Args:
            output_names: list of output node names specifying which outputs to return
            input_feed: input data as a dict or a list
            data_format: input data format, "nchw" or "nhwc"
            **kwargs: additional runtime arguments

        Returns:
            list: model outputs; if output_names is given, only those outputs are returned
        """
        if input_feed is None:
            logger.error("input_feed must not be None")
            raise ValueError("input_feed must not be None")

        # Prepare the input data
        if isinstance(input_feed, dict):
            # For a dict, order the inputs to match the model's input order
            inputs = []
            for tensor in self.input_tensors:
                if tensor.name not in input_feed:
                    raise ValueError(f"Missing input: {tensor.name}")
                inputs.append(input_feed[tensor.name])
        elif isinstance(input_feed, (list, tuple)):
            # For a list, make sure the length matches
            if len(input_feed) != len(self.input_tensors):
                raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
            inputs = list(input_feed)
        else:
            logger.error("input_feed must be a dict or a list")
            raise ValueError("input_feed must be a dict or a list")

        # Run inference
        try:
            logger.debug("Starting inference")
            all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)

            # Return all outputs when output_names is not specified
            if output_names is None:
                return all_outputs

            # Select the requested outputs
            output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
            selected_outputs = []
            for name in output_names:
                if name not in output_map:
                    raise ValueError(f"Output node not found: {name}")
                selected_outputs.append(all_outputs[output_map[name]])

            return selected_outputs

        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise RuntimeError(f"Inference failed: {str(e)}")

    def close(self):
        """
        Close the session and release resources.
        """
        if self.runtime is not None:
            logger.info("Releasing runtime resources")
            self.runtime.release()
            self.runtime = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def end_profiling(self) -> Optional[str]:
        """
        Stub for ending profiling.

        Returns:
            Optional[str]: None
        """
        warnings.warn("end_profiling() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return None

    def get_profiling_start_time_ns(self) -> int:
        """
        Stub for getting the profiling start time.

        Returns:
            int: 0
        """
        warnings.warn("get_profiling_start_time_ns() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return 0

    def get_modelmeta(self) -> Dict[str, str]:
        """
        Stub for getting model metadata.

        Returns:
            Dict[str, str]: an empty dict
        """
        warnings.warn("get_modelmeta() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_options(self) -> SessionOptions:
        """
        Get the session options.

        Returns:
            SessionOptions: the current session options
        """
        return self.options

    def get_providers(self) -> List[str]:
        """
        Stub for getting the providers in use.

        Returns:
            List[str]: ["CPUExecutionProvider"]
        """
        warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
        return ["CPUExecutionProvider"]

    def get_provider_options(self) -> Dict[str, Dict[str, str]]:
        """
        Stub for getting provider options.

        Returns:
            Dict[str, Dict[str, str]]: an empty dict
        """
        warnings.warn("get_provider_options() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_config(self) -> Dict[str, str]:
        """
        Stub for getting the session config.

        Returns:
            Dict[str, str]: an empty dict
        """
        warnings.warn("get_session_config() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_state(self) -> Dict[str, str]:
        """
        Stub for getting the session state.

        Returns:
            Dict[str, str]: an empty dict
        """
        warnings.warn("get_session_state() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def set_session_config(self, config: Dict[str, str]) -> None:
        """
        Stub for setting the session config.

        Args:
            config: session config dict
        """
        warnings.warn("set_session_config() is a stub and has no effect", RuntimeWarning, stacklevel=2)

    def get_memory_info(self) -> Dict[str, int]:
        """
        Stub for getting memory usage information.

        Returns:
            Dict[str, int]: an empty dict
        """
        warnings.warn("get_memory_info() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def set_memory_pattern(self, enable: bool) -> None:
        """
        Stub for setting the memory pattern.

        Args:
            enable: whether to enable the memory pattern
        """
        warnings.warn("set_memory_pattern() is a stub and has no effect", RuntimeWarning, stacklevel=2)

    def disable_memory_pattern(self) -> None:
        """
        Stub for disabling the memory pattern.
        """
        warnings.warn("disable_memory_pattern() is a stub and has no effect", RuntimeWarning, stacklevel=2)

    def get_optimization_level(self) -> int:
        """
        Stub for getting the optimization level.

        Returns:
            int: 0
        """
        warnings.warn("get_optimization_level() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return 0

    def set_optimization_level(self, level: int) -> None:
        """
        Stub for setting the optimization level.

        Args:
            level: optimization level
        """
        warnings.warn("set_optimization_level() is a stub and has no effect", RuntimeWarning, stacklevel=2)

    def get_model_metadata(self) -> Dict[str, str]:
        """
        Stub for getting model metadata (a separate interface from get_modelmeta).

        Returns:
            Dict[str, str]: an empty dict
        """
        warnings.warn("get_model_metadata() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def get_model_path(self) -> str:
        """
        Get the model path.

        Returns:
            str: path to the model file
        """
        return self.model_path

    def get_input_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for getting input type information.

        Returns:
            List[Dict[str, str]]: an empty list
        """
        warnings.warn("get_input_type_info() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return []

    def get_output_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for getting output type information.

        Returns:
            List[Dict[str, str]]: an empty list
        """
        warnings.warn("get_output_type_info() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return []
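For reference, a minimal usage sketch of the wrapper above, mirroring the onnxruntime session API (the model path and the all-zeros input are placeholders, and it assumes a compiled model.rknn sits next to the .onnx path so the RKNN backend is selected):

    import numpy as np
    import ztu_somemodelruntime_rknnlite2 as ort

    sess = ort.InferenceSession("model.onnx")   # transparently loads model.rknn if it exists
    inp = sess.get_inputs()[0]                  # IOTensor with name/shape/type; shape is already NCHW
    x = np.zeros(inp.shape, dtype=np.float32)
    outputs = sess.run(None, {inp.name: x})     # input_feed may be a dict or a list; None returns all outputs
    print([o.shape for o in outputs])
    sess.close()                                # releases the NPU runtime (RKNN path)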