Upload 12 files
Browse files- .gitattributes +5 -0
- convert_fastvithd.py +52 -0
- convert_mm_projector.py +52 -0
- export_onnx.py +136 -0
- fastvithd.rknn +3 -0
- librkllmrt.so +3 -0
- mm_projector.rknn +3 -0
- qwen_f16.rkllm +3 -0
- rkllm-convert.py +23 -0
- rkllm_binding.py +510 -0
- run_rknn.py +274 -0
- test.jpg +3 -0
- ztu_somemodelruntime_rknnlite2.py +569 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
fastvithd.rknn filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
librkllmrt.so filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
mm_projector.rknn filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
qwen_f16.rkllm filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
test.jpg filter=lfs diff=lfs merge=lfs -text
|
convert_fastvithd.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# ztu_somemodelruntime_rknn2: fastvithd
|
| 3 |
+
|
| 4 |
+
from rknn.api import RKNN
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
# 创建RKNN实例
|
| 10 |
+
rknn = RKNN(verbose=False)
|
| 11 |
+
|
| 12 |
+
# ONNX模型路径
|
| 13 |
+
ONNX_MODEL = "fastvithd.onnx"
|
| 14 |
+
# 输出RKNN模型路径
|
| 15 |
+
RKNN_MODEL = "fastvithd.rknn"
|
| 16 |
+
|
| 17 |
+
# 配置参数
|
| 18 |
+
print("--> Config model")
|
| 19 |
+
ret = rknn.config(target_platform="rk3588",
|
| 20 |
+
dynamic_input=None)
|
| 21 |
+
if ret != 0:
|
| 22 |
+
print('Config model failed!')
|
| 23 |
+
exit(ret)
|
| 24 |
+
|
| 25 |
+
# 加载ONNX模型
|
| 26 |
+
print("--> Loading model")
|
| 27 |
+
ret = rknn.load_onnx(model=ONNX_MODEL,
|
| 28 |
+
inputs=['pixel_values'],
|
| 29 |
+
input_size_list=[[1, 3, 1024, 1024]])
|
| 30 |
+
if ret != 0:
|
| 31 |
+
print('Load model failed!')
|
| 32 |
+
exit(ret)
|
| 33 |
+
|
| 34 |
+
# 构建模型
|
| 35 |
+
print("--> Building model")
|
| 36 |
+
ret = rknn.build(do_quantization=False)
|
| 37 |
+
if ret != 0:
|
| 38 |
+
print('Build model failed!')
|
| 39 |
+
exit(ret)
|
| 40 |
+
|
| 41 |
+
# 导出RKNN模型
|
| 42 |
+
print("--> Export RKNN model")
|
| 43 |
+
ret = rknn.export_rknn(RKNN_MODEL)
|
| 44 |
+
if ret != 0:
|
| 45 |
+
print('Export RKNN model failed!')
|
| 46 |
+
exit(ret)
|
| 47 |
+
|
| 48 |
+
print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
|
| 49 |
+
rknn.release()
|
| 50 |
+
|
| 51 |
+
if __name__ == '__main__':
|
| 52 |
+
main()
|
convert_mm_projector.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# ztu_somemodelruntime_rknn2: mm_projector
|
| 3 |
+
|
| 4 |
+
from rknn.api import RKNN
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
# 创建RKNN实例
|
| 10 |
+
rknn = RKNN(verbose=False)
|
| 11 |
+
|
| 12 |
+
# ONNX模型路径
|
| 13 |
+
ONNX_MODEL = "mm_projector.onnx"
|
| 14 |
+
# 输出RKNN模型路径
|
| 15 |
+
RKNN_MODEL = "mm_projector.rknn"
|
| 16 |
+
|
| 17 |
+
# 配置参数
|
| 18 |
+
print("--> Config model")
|
| 19 |
+
ret = rknn.config(target_platform="rk3588",
|
| 20 |
+
dynamic_input=None)
|
| 21 |
+
if ret != 0:
|
| 22 |
+
print('Config model failed!')
|
| 23 |
+
exit(ret)
|
| 24 |
+
|
| 25 |
+
# 加载ONNX模型
|
| 26 |
+
print("--> Loading model")
|
| 27 |
+
ret = rknn.load_onnx(model=ONNX_MODEL,
|
| 28 |
+
inputs=['last_hidden_state'],
|
| 29 |
+
input_size_list=[[1, 256, 3072]])
|
| 30 |
+
if ret != 0:
|
| 31 |
+
print('Load model failed!')
|
| 32 |
+
exit(ret)
|
| 33 |
+
|
| 34 |
+
# 构建模型
|
| 35 |
+
print("--> Building model")
|
| 36 |
+
ret = rknn.build(do_quantization=False)
|
| 37 |
+
if ret != 0:
|
| 38 |
+
print('Build model failed!')
|
| 39 |
+
exit(ret)
|
| 40 |
+
|
| 41 |
+
# 导出RKNN模型
|
| 42 |
+
print("--> Export RKNN model")
|
| 43 |
+
ret = rknn.export_rknn(RKNN_MODEL)
|
| 44 |
+
if ret != 0:
|
| 45 |
+
print('Export RKNN model failed!')
|
| 46 |
+
exit(ret)
|
| 47 |
+
|
| 48 |
+
print(f'Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
|
| 49 |
+
rknn.release()
|
| 50 |
+
|
| 51 |
+
if __name__ == '__main__':
|
| 52 |
+
main()
|
export_onnx.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# For licensing see accompanying LICENSE file.
|
| 3 |
+
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import copy
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from llava.model.builder import load_pretrained_model
|
| 13 |
+
from llava.utils import disable_torch_init
|
| 14 |
+
from llava.mm_utils import get_model_name_from_path
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def export(args):
|
| 18 |
+
# Load model
|
| 19 |
+
disable_torch_init()
|
| 20 |
+
model_path = os.path.expanduser(args.model_path)
|
| 21 |
+
model_name = get_model_name_from_path(model_path)
|
| 22 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path,
|
| 23 |
+
args.model_base,
|
| 24 |
+
model_name,
|
| 25 |
+
device="cpu")
|
| 26 |
+
|
| 27 |
+
# Save extra metadata that is not saved during LLaVA training
|
| 28 |
+
# required by HF for auto-loading model and for mlx-vlm preprocessing
|
| 29 |
+
|
| 30 |
+
# Save image processing config
|
| 31 |
+
setattr(image_processor, "processor_class", "LlavaProcessor")
|
| 32 |
+
output_path = os.path.join(model_path, "preprocessor_config.json")
|
| 33 |
+
image_processor.to_json_file(output_path)
|
| 34 |
+
|
| 35 |
+
# Create processor config
|
| 36 |
+
processor_config = dict()
|
| 37 |
+
processor_config["image_token"] = "<image>"
|
| 38 |
+
processor_config["num_additional_image_tokens"] = 0
|
| 39 |
+
processor_config["processor_class"] = "LlavaProcessor"
|
| 40 |
+
processor_config["patch_size"] = 64
|
| 41 |
+
output_path = os.path.join(model_path, "processor_config.json")
|
| 42 |
+
json.dump(processor_config, open(output_path, "w"), indent=2)
|
| 43 |
+
|
| 44 |
+
# Modify tokenizer to include <image> special token.
|
| 45 |
+
tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
|
| 46 |
+
tokenizer_config = json.load(open(tokenizer_config_path, 'r'))
|
| 47 |
+
token_ids = list()
|
| 48 |
+
image_token_is_present = False
|
| 49 |
+
for k, v in tokenizer_config['added_tokens_decoder'].items():
|
| 50 |
+
token_ids.append(int(k))
|
| 51 |
+
if v["content"] == "<image>":
|
| 52 |
+
image_token_is_present = True
|
| 53 |
+
token_ids.pop()
|
| 54 |
+
|
| 55 |
+
# Append only if <image> token is not present
|
| 56 |
+
if not image_token_is_present:
|
| 57 |
+
tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}'] = copy.deepcopy(
|
| 58 |
+
tokenizer_config['added_tokens_decoder'][f'{token_ids[0]}'])
|
| 59 |
+
tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}']["content"] = "<image>"
|
| 60 |
+
json.dump(tokenizer_config, open(tokenizer_config_path, 'w'), indent=2)
|
| 61 |
+
|
| 62 |
+
# Modify config to contain token id for <image>
|
| 63 |
+
config_path = os.path.join(model_path, "config.json")
|
| 64 |
+
model_config = json.load(open(config_path, 'r'))
|
| 65 |
+
model_config["image_token_index"] = max(token_ids) + 1
|
| 66 |
+
json.dump(model_config, open(config_path, 'w'), indent=2)
|
| 67 |
+
|
| 68 |
+
# Export the vision encoder to ONNX
|
| 69 |
+
image_res = image_processor.to_dict()['size']['shortest_edge']
|
| 70 |
+
dummy_vision_input = torch.rand(1, 3, image_res, image_res).float() # Dummy input tensor
|
| 71 |
+
|
| 72 |
+
vision_model = model.get_vision_tower()
|
| 73 |
+
# Ensure model is on CPU, in float precision, and in evaluation mode for ONNX export
|
| 74 |
+
vision_model = vision_model.cpu().float().eval()
|
| 75 |
+
|
| 76 |
+
onnx_vision_model_path = os.path.join(model_path, "fastvithd.onnx")
|
| 77 |
+
|
| 78 |
+
print(f"Exporting vision model to {onnx_vision_model_path}...")
|
| 79 |
+
torch.onnx.export(
|
| 80 |
+
vision_model,
|
| 81 |
+
dummy_vision_input, # Pass the dummy input tensor
|
| 82 |
+
onnx_vision_model_path,
|
| 83 |
+
input_names=['pixel_values'], # ONNX图中输入节点的名称
|
| 84 |
+
output_names=['last_hidden_state'], # ONNX图中输出节点的名称
|
| 85 |
+
# dynamic_axes={
|
| 86 |
+
# 'pixel_values': {0: 'batch_size'}, # 输入'pixel_values'的第0维是动态的batch_size
|
| 87 |
+
# 'last_hidden_state': {0: 'batch_size'} # 输出'last_hidden_state'的第0维是动态的batch_size
|
| 88 |
+
# },
|
| 89 |
+
opset_version=17, # ONNX opset 版本
|
| 90 |
+
export_params=True, # 在模型文件中存储训练好的参数权重
|
| 91 |
+
do_constant_folding=True # 执行常量折叠优化
|
| 92 |
+
)
|
| 93 |
+
print(f"Vision model ONNX export complete: {onnx_vision_model_path}")
|
| 94 |
+
|
| 95 |
+
# Generate dummy input for mm_projector by passing dummy_vision_input through vision_model
|
| 96 |
+
# This ensures the mm_projector receives input with the correct shape and characteristics
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
dummy_mm_projector_input = vision_model(dummy_vision_input)
|
| 99 |
+
|
| 100 |
+
# Ensure the input is on CPU and in float32 precision for the projector
|
| 101 |
+
dummy_mm_projector_input = dummy_mm_projector_input.cpu().float()
|
| 102 |
+
|
| 103 |
+
# Export the mm_projector to ONNX
|
| 104 |
+
# model.get_model() gives the underlying base model (e.g., LlavaLlamaModel)
|
| 105 |
+
# which contains the mm_projector attribute.
|
| 106 |
+
mm_projector = model.get_model().mm_projector
|
| 107 |
+
mm_projector = mm_projector.cpu().float().eval()
|
| 108 |
+
|
| 109 |
+
onnx_mm_projector_path = os.path.join(model_path, "mm_projector.onnx")
|
| 110 |
+
|
| 111 |
+
print(f"Exporting mm_projector to {onnx_mm_projector_path}...")
|
| 112 |
+
torch.onnx.export(
|
| 113 |
+
mm_projector,
|
| 114 |
+
dummy_mm_projector_input,
|
| 115 |
+
onnx_mm_projector_path,
|
| 116 |
+
input_names=['last_hidden_state'],
|
| 117 |
+
output_names=['projected_image_features'],
|
| 118 |
+
opset_version=17,
|
| 119 |
+
export_params=True,
|
| 120 |
+
do_constant_folding=True
|
| 121 |
+
)
|
| 122 |
+
print(f"mm_projector ONNX export complete: {onnx_mm_projector_path}")
|
| 123 |
+
|
| 124 |
+
# Removed CoreML specific code and intermediate .pt file handling
|
| 125 |
+
# No need for os.remove(pt_name) as pt_name is no longer created
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
parser = argparse.ArgumentParser()
|
| 130 |
+
parser.add_argument("--model-path", type=str, required=True)
|
| 131 |
+
parser.add_argument("--model-base", type=str, default=None)
|
| 132 |
+
parser.add_argument("--conv-mode", type=str, default="qwen_2")
|
| 133 |
+
|
| 134 |
+
args = parser.parse_args()
|
| 135 |
+
|
| 136 |
+
export(args)
|
fastvithd.rknn
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d65f07758bd5fd610c76d28aae1d07a87bc65de0687f4fee5ef5c5e0f61d52a
|
| 3 |
+
size 372732105
|
librkllmrt.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f6a9c2de93cf94bb524eb071c27190ad4c83401e01b562534f265dff4cb40da2
|
| 3 |
+
size 6710712
|
mm_projector.rknn
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3f6d9b06589c9aa7b70d28ef22703579d5d3c07c53e1b6ee72be85ac4ae7ee5
|
| 3 |
+
size 14272722
|
qwen_f16.rkllm
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e13cc4849405b7338b5d21cef0f50a1776ea7c21594685e52665837cfec123c
|
| 3 |
+
size 3580141646
|
rkllm-convert.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rkllm.api import RKLLM
|
| 2 |
+
|
| 3 |
+
modelpath = '.'
|
| 4 |
+
llm = RKLLM()
|
| 5 |
+
|
| 6 |
+
ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
|
| 7 |
+
if ret != 0:
|
| 8 |
+
print('Load model failed!')
|
| 9 |
+
exit(ret)
|
| 10 |
+
|
| 11 |
+
qparams = None
|
| 12 |
+
ret = llm.build(do_quantization=False, optimization_level=1, quantized_dtype='w8a8_g128',
|
| 13 |
+
quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams, dataset='calibration_dataset.json')
|
| 14 |
+
|
| 15 |
+
if ret != 0:
|
| 16 |
+
print('Build model failed!')
|
| 17 |
+
exit(ret)
|
| 18 |
+
|
| 19 |
+
# Export rkllm model
|
| 20 |
+
ret = llm.export_rkllm("./qwen_f16.rkllm")
|
| 21 |
+
if ret != 0:
|
| 22 |
+
print('Export model failed!')
|
| 23 |
+
exit(ret)
|
rkllm_binding.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import enum
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Define constants from the header
|
| 6 |
+
CPU0 = (1 << 0) # 0x01
|
| 7 |
+
CPU1 = (1 << 1) # 0x02
|
| 8 |
+
CPU2 = (1 << 2) # 0x04
|
| 9 |
+
CPU3 = (1 << 3) # 0x08
|
| 10 |
+
CPU4 = (1 << 4) # 0x10
|
| 11 |
+
CPU5 = (1 << 5) # 0x20
|
| 12 |
+
CPU6 = (1 << 6) # 0x40
|
| 13 |
+
CPU7 = (1 << 7) # 0x80
|
| 14 |
+
|
| 15 |
+
# --- Enums ---
|
| 16 |
+
class LLMCallState(enum.IntEnum):
|
| 17 |
+
RKLLM_RUN_NORMAL = 0
|
| 18 |
+
RKLLM_RUN_WAITING = 1
|
| 19 |
+
RKLLM_RUN_FINISH = 2
|
| 20 |
+
RKLLM_RUN_ERROR = 3
|
| 21 |
+
|
| 22 |
+
class RKLLMInputType(enum.IntEnum):
|
| 23 |
+
RKLLM_INPUT_PROMPT = 0
|
| 24 |
+
RKLLM_INPUT_TOKEN = 1
|
| 25 |
+
RKLLM_INPUT_EMBED = 2
|
| 26 |
+
RKLLM_INPUT_MULTIMODAL = 3
|
| 27 |
+
|
| 28 |
+
class RKLLMInferMode(enum.IntEnum):
|
| 29 |
+
RKLLM_INFER_GENERATE = 0
|
| 30 |
+
RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
|
| 31 |
+
RKLLM_INFER_GET_LOGITS = 2
|
| 32 |
+
|
| 33 |
+
# --- Structures ---
|
| 34 |
+
class RKLLMExtendParam(ctypes.Structure):
|
| 35 |
+
_fields_ = [
|
| 36 |
+
("base_domain_id", ctypes.c_int32),
|
| 37 |
+
("embed_flash", ctypes.c_int8),
|
| 38 |
+
("enabled_cpus_num", ctypes.c_int8),
|
| 39 |
+
("enabled_cpus_mask", ctypes.c_uint32),
|
| 40 |
+
("reserved", ctypes.c_uint8 * 106)
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
class RKLLMParam(ctypes.Structure):
|
| 44 |
+
_fields_ = [
|
| 45 |
+
("model_path", ctypes.c_char_p),
|
| 46 |
+
("max_context_len", ctypes.c_int32),
|
| 47 |
+
("max_new_tokens", ctypes.c_int32),
|
| 48 |
+
("top_k", ctypes.c_int32),
|
| 49 |
+
("n_keep", ctypes.c_int32),
|
| 50 |
+
("top_p", ctypes.c_float),
|
| 51 |
+
("temperature", ctypes.c_float),
|
| 52 |
+
("repeat_penalty", ctypes.c_float),
|
| 53 |
+
("frequency_penalty", ctypes.c_float),
|
| 54 |
+
("presence_penalty", ctypes.c_float), # Note: This was missing in the provided text but is in typical LLM params
|
| 55 |
+
("mirostat", ctypes.c_int32),
|
| 56 |
+
("mirostat_tau", ctypes.c_float),
|
| 57 |
+
("mirostat_eta", ctypes.c_float),
|
| 58 |
+
("skip_special_token", ctypes.c_bool),
|
| 59 |
+
("is_async", ctypes.c_bool),
|
| 60 |
+
("img_start", ctypes.c_char_p),
|
| 61 |
+
("img_end", ctypes.c_char_p),
|
| 62 |
+
("img_content", ctypes.c_char_p), # This seems like it should be more structured for actual image data
|
| 63 |
+
("extend_param", RKLLMExtendParam)
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
class RKLLMLoraAdapter(ctypes.Structure):
|
| 67 |
+
_fields_ = [
|
| 68 |
+
("lora_adapter_path", ctypes.c_char_p),
|
| 69 |
+
("lora_adapter_name", ctypes.c_char_p),
|
| 70 |
+
("scale", ctypes.c_float)
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
class RKLLMEmbedInput(ctypes.Structure):
|
| 74 |
+
_fields_ = [
|
| 75 |
+
("embed", ctypes.POINTER(ctypes.c_float)),
|
| 76 |
+
("n_tokens", ctypes.c_size_t)
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
class RKLLMTokenInput(ctypes.Structure):
|
| 80 |
+
_fields_ = [
|
| 81 |
+
("input_ids", ctypes.POINTER(ctypes.c_int32)),
|
| 82 |
+
("n_tokens", ctypes.c_size_t)
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
class RKLLMMultiModelInput(ctypes.Structure):
|
| 86 |
+
_fields_ = [
|
| 87 |
+
("prompt", ctypes.c_char_p),
|
| 88 |
+
("image_embed", ctypes.POINTER(ctypes.c_float)),
|
| 89 |
+
("n_image_tokens", ctypes.c_size_t),
|
| 90 |
+
("n_image", ctypes.c_size_t),
|
| 91 |
+
("image_width", ctypes.c_size_t),
|
| 92 |
+
("image_height", ctypes.c_size_t)
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
class _RKLLMInputUnion(ctypes.Union):
|
| 96 |
+
_fields_ = [
|
| 97 |
+
("prompt_input", ctypes.c_char_p),
|
| 98 |
+
("embed_input", RKLLMEmbedInput),
|
| 99 |
+
("token_input", RKLLMTokenInput),
|
| 100 |
+
("multimodal_input", RKLLMMultiModelInput)
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
class RKLLMInput(ctypes.Structure):
|
| 104 |
+
_fields_ = [
|
| 105 |
+
("input_type", ctypes.c_int), # Enum will be passed as int, changed RKLLMInputType to ctypes.c_int
|
| 106 |
+
("_union_data", _RKLLMInputUnion)
|
| 107 |
+
]
|
| 108 |
+
# Properties to make accessing union members easier
|
| 109 |
+
@property
|
| 110 |
+
def prompt_input(self):
|
| 111 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
|
| 112 |
+
return self._union_data.prompt_input
|
| 113 |
+
raise AttributeError("Not a prompt input")
|
| 114 |
+
@prompt_input.setter
|
| 115 |
+
def prompt_input(self, value):
|
| 116 |
+
if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
|
| 117 |
+
self._union_data.prompt_input = value
|
| 118 |
+
else:
|
| 119 |
+
raise AttributeError("Not a prompt input")
|
| 120 |
+
|
| 121 |
+
# Similar properties can be added for embed_input, token_input, multimodal_input
|
| 122 |
+
|
| 123 |
+
class RKLLMLoraParam(ctypes.Structure): # For inference
|
| 124 |
+
_fields_ = [
|
| 125 |
+
("lora_adapter_name", ctypes.c_char_p)
|
| 126 |
+
]
|
| 127 |
+
|
| 128 |
+
class RKLLMPromptCacheParam(ctypes.Structure): # For inference
|
| 129 |
+
_fields_ = [
|
| 130 |
+
("save_prompt_cache", ctypes.c_int), # bool-like
|
| 131 |
+
("prompt_cache_path", ctypes.c_char_p)
|
| 132 |
+
]
|
| 133 |
+
|
| 134 |
+
class RKLLMInferParam(ctypes.Structure):
|
| 135 |
+
_fields_ = [
|
| 136 |
+
("mode", ctypes.c_int), # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
|
| 137 |
+
("lora_params", ctypes.POINTER(RKLLMLoraParam)),
|
| 138 |
+
("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
|
| 139 |
+
("keep_history", ctypes.c_int) # bool-like
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
class RKLLMResultLastHiddenLayer(ctypes.Structure):
|
| 143 |
+
_fields_ = [
|
| 144 |
+
("hidden_states", ctypes.POINTER(ctypes.c_float)),
|
| 145 |
+
("embd_size", ctypes.c_int),
|
| 146 |
+
("num_tokens", ctypes.c_int)
|
| 147 |
+
]
|
| 148 |
+
|
| 149 |
+
class RKLLMResultLogits(ctypes.Structure):
|
| 150 |
+
_fields_ = [
|
| 151 |
+
("logits", ctypes.POINTER(ctypes.c_float)),
|
| 152 |
+
("vocab_size", ctypes.c_int),
|
| 153 |
+
("num_tokens", ctypes.c_int)
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
class RKLLMResult(ctypes.Structure):
|
| 157 |
+
_fields_ = [
|
| 158 |
+
("text", ctypes.c_char_p),
|
| 159 |
+
("token_id", ctypes.c_int32),
|
| 160 |
+
("last_hidden_layer", RKLLMResultLastHiddenLayer),
|
| 161 |
+
("logits", RKLLMResultLogits)
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
# --- Typedefs ---
|
| 165 |
+
LLMHandle = ctypes.c_void_p
|
| 166 |
+
|
| 167 |
+
# --- Callback Function Type ---
|
| 168 |
+
LLMResultCallback = ctypes.CFUNCTYPE(
|
| 169 |
+
None, # return type: void
|
| 170 |
+
ctypes.POINTER(RKLLMResult),
|
| 171 |
+
ctypes.c_void_p, # userdata
|
| 172 |
+
ctypes.c_int # enum, will be passed as int. Changed LLMCallState to ctypes.c_int
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class RKLLMRuntime:
|
| 177 |
+
def __init__(self, library_path="./librkllmrt.so"):
|
| 178 |
+
try:
|
| 179 |
+
self.lib = ctypes.CDLL(library_path)
|
| 180 |
+
except OSError as e:
|
| 181 |
+
raise OSError(f"Failed to load RKLLM library from {library_path}. "
|
| 182 |
+
f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
|
| 183 |
+
self._setup_functions()
|
| 184 |
+
self.llm_handle = LLMHandle()
|
| 185 |
+
self._c_callback = None # To keep the callback object alive
|
| 186 |
+
|
| 187 |
+
def _setup_functions(self):
|
| 188 |
+
# RKLLMParam rkllm_createDefaultParam();
|
| 189 |
+
self.lib.rkllm_createDefaultParam.restype = RKLLMParam
|
| 190 |
+
self.lib.rkllm_createDefaultParam.argtypes = []
|
| 191 |
+
|
| 192 |
+
# int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
|
| 193 |
+
self.lib.rkllm_init.restype = ctypes.c_int
|
| 194 |
+
self.lib.rkllm_init.argtypes = [
|
| 195 |
+
ctypes.POINTER(LLMHandle),
|
| 196 |
+
ctypes.POINTER(RKLLMParam),
|
| 197 |
+
LLMResultCallback
|
| 198 |
+
]
|
| 199 |
+
|
| 200 |
+
# int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
|
| 201 |
+
self.lib.rkllm_load_lora.restype = ctypes.c_int
|
| 202 |
+
self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]
|
| 203 |
+
|
| 204 |
+
# int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
|
| 205 |
+
self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
|
| 206 |
+
self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]
|
| 207 |
+
|
| 208 |
+
# int rkllm_release_prompt_cache(LLMHandle handle);
|
| 209 |
+
self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
|
| 210 |
+
self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]
|
| 211 |
+
|
| 212 |
+
# int rkllm_destroy(LLMHandle handle);
|
| 213 |
+
self.lib.rkllm_destroy.restype = ctypes.c_int
|
| 214 |
+
self.lib.rkllm_destroy.argtypes = [LLMHandle]
|
| 215 |
+
|
| 216 |
+
# int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
| 217 |
+
self.lib.rkllm_run.restype = ctypes.c_int
|
| 218 |
+
self.lib.rkllm_run.argtypes = [
|
| 219 |
+
LLMHandle,
|
| 220 |
+
ctypes.POINTER(RKLLMInput),
|
| 221 |
+
ctypes.POINTER(RKLLMInferParam),
|
| 222 |
+
ctypes.c_void_p # userdata
|
| 223 |
+
]
|
| 224 |
+
|
| 225 |
+
# int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
|
| 226 |
+
# Assuming async also takes userdata for the callback context
|
| 227 |
+
self.lib.rkllm_run_async.restype = ctypes.c_int
|
| 228 |
+
self.lib.rkllm_run_async.argtypes = [
|
| 229 |
+
LLMHandle,
|
| 230 |
+
ctypes.POINTER(RKLLMInput),
|
| 231 |
+
ctypes.POINTER(RKLLMInferParam),
|
| 232 |
+
ctypes.c_void_p # userdata
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
# int rkllm_abort(LLMHandle handle);
|
| 236 |
+
self.lib.rkllm_abort.restype = ctypes.c_int
|
| 237 |
+
self.lib.rkllm_abort.argtypes = [LLMHandle]
|
| 238 |
+
|
| 239 |
+
# int rkllm_is_running(LLMHandle handle);
|
| 240 |
+
self.lib.rkllm_is_running.restype = ctypes.c_int # 0 if running, non-zero otherwise
|
| 241 |
+
self.lib.rkllm_is_running.argtypes = [LLMHandle]
|
| 242 |
+
|
| 243 |
+
# int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt);
|
| 244 |
+
self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
|
| 245 |
+
self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int]
|
| 246 |
+
|
| 247 |
+
# int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
|
| 248 |
+
self.lib.rkllm_set_chat_template.restype = ctypes.c_int
|
| 249 |
+
self.lib.rkllm_set_chat_template.argtypes = [
|
| 250 |
+
LLMHandle,
|
| 251 |
+
ctypes.c_char_p,
|
| 252 |
+
ctypes.c_char_p,
|
| 253 |
+
ctypes.c_char_p
|
| 254 |
+
]
|
| 255 |
+
|
| 256 |
+
def create_default_param(self) -> RKLLMParam:
|
| 257 |
+
"""Creates a default RKLLMParam structure."""
|
| 258 |
+
return self.lib.rkllm_createDefaultParam()
|
| 259 |
+
|
| 260 |
+
def init(self, param: RKLLMParam, callback_func) -> int:
|
| 261 |
+
"""
|
| 262 |
+
Initializes the LLM.
|
| 263 |
+
:param param: RKLLMParam structure.
|
| 264 |
+
:param callback_func: A Python function that matches the signature:
|
| 265 |
+
def my_callback(result_ptr, userdata_ptr, state_enum):
|
| 266 |
+
result = result_ptr.contents # RKLLMResult
|
| 267 |
+
# Process result
|
| 268 |
+
# userdata can be retrieved if passed during run, or ignored
|
| 269 |
+
# state = LLMCallState(state_enum)
|
| 270 |
+
:return: 0 for success, non-zero for failure.
|
| 271 |
+
"""
|
| 272 |
+
if not callable(callback_func):
|
| 273 |
+
raise ValueError("callback_func must be a callable Python function.")
|
| 274 |
+
|
| 275 |
+
# Keep a reference to the ctypes callback object to prevent it from being garbage collected
|
| 276 |
+
self._c_callback = LLMResultCallback(callback_func)
|
| 277 |
+
|
| 278 |
+
ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
|
| 279 |
+
if ret != 0:
|
| 280 |
+
raise RuntimeError(f"rkllm_init failed with error code {ret}")
|
| 281 |
+
return ret
|
| 282 |
+
|
| 283 |
+
def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
|
| 284 |
+
"""Loads a Lora adapter."""
|
| 285 |
+
ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
|
| 286 |
+
if ret != 0:
|
| 287 |
+
raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
|
| 288 |
+
return ret
|
| 289 |
+
|
| 290 |
+
def load_prompt_cache(self, prompt_cache_path: str) -> int:
|
| 291 |
+
"""Loads a prompt cache from a file."""
|
| 292 |
+
c_path = prompt_cache_path.encode('utf-8')
|
| 293 |
+
ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
|
| 294 |
+
if ret != 0:
|
| 295 |
+
raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
|
| 296 |
+
return ret
|
| 297 |
+
|
| 298 |
+
def release_prompt_cache(self) -> int:
|
| 299 |
+
"""Releases the prompt cache from memory."""
|
| 300 |
+
ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
|
| 301 |
+
if ret != 0:
|
| 302 |
+
raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
|
| 303 |
+
return ret
|
| 304 |
+
|
| 305 |
+
def destroy(self) -> int:
|
| 306 |
+
"""Destroys the LLM instance and releases resources."""
|
| 307 |
+
if self.llm_handle and self.llm_handle.value: # Check if handle is not NULL
|
| 308 |
+
ret = self.lib.rkllm_destroy(self.llm_handle)
|
| 309 |
+
self.llm_handle = LLMHandle() # Reset handle
|
| 310 |
+
if ret != 0:
|
| 311 |
+
# Don't raise here as it might be called in __del__
|
| 312 |
+
print(f"Warning: rkllm_destroy failed with error code {ret}")
|
| 313 |
+
return ret
|
| 314 |
+
return 0 # Already destroyed or not initialized
|
| 315 |
+
|
| 316 |
+
def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
|
| 317 |
+
"""Runs an LLM inference task synchronously."""
|
| 318 |
+
# userdata can be a ctypes.py_object if you want to pass Python objects,
|
| 319 |
+
# then cast to c_void_p. Or simply None.
|
| 320 |
+
c_userdata = ctypes.cast(ctypes.py_object(userdata), ctypes.c_void_p) if userdata is not None else None
|
| 321 |
+
ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
|
| 322 |
+
if ret != 0:
|
| 323 |
+
raise RuntimeError(f"rkllm_run failed with error code {ret}")
|
| 324 |
+
return ret
|
| 325 |
+
|
| 326 |
+
def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
|
| 327 |
+
"""Runs an LLM inference task asynchronously."""
|
| 328 |
+
c_userdata = ctypes.cast(ctypes.py_object(userdata), ctypes.c_void_p) if userdata is not None else None
|
| 329 |
+
ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
|
| 330 |
+
if ret != 0:
|
| 331 |
+
raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
|
| 332 |
+
return ret
|
| 333 |
+
|
| 334 |
+
def abort(self) -> int:
|
| 335 |
+
"""Aborts an ongoing LLM task."""
|
| 336 |
+
ret = self.lib.rkllm_abort(self.llm_handle)
|
| 337 |
+
if ret != 0:
|
| 338 |
+
raise RuntimeError(f"rkllm_abort failed with error code {ret}")
|
| 339 |
+
return ret
|
| 340 |
+
|
| 341 |
+
def is_running(self) -> bool:
|
| 342 |
+
"""Checks if an LLM task is currently running. Returns True if running."""
|
| 343 |
+
# The C API returns 0 if running, non-zero otherwise.
|
| 344 |
+
# This is a bit counter-intuitive for a boolean "is_running".
|
| 345 |
+
return self.lib.rkllm_is_running(self.llm_handle) == 0
|
| 346 |
+
|
| 347 |
+
def clear_kv_cache(self, keep_system_prompt: bool) -> int:
|
| 348 |
+
"""Clears the key-value cache."""
|
| 349 |
+
ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0))
|
| 350 |
+
if ret != 0:
|
| 351 |
+
raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
|
| 352 |
+
return ret
|
| 353 |
+
|
| 354 |
+
def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
|
| 355 |
+
"""Sets the chat template for the LLM."""
|
| 356 |
+
c_system = system_prompt.encode('utf-8') if system_prompt else None
|
| 357 |
+
c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else None
|
| 358 |
+
c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else None
|
| 359 |
+
|
| 360 |
+
ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
|
| 361 |
+
if ret != 0:
|
| 362 |
+
raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
|
| 363 |
+
return ret
|
| 364 |
+
|
| 365 |
+
def __enter__(self):
|
| 366 |
+
return self
|
| 367 |
+
|
| 368 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 369 |
+
self.destroy()
|
| 370 |
+
|
| 371 |
+
def __del__(self):
|
| 372 |
+
self.destroy() # Ensure resources are freed if object is garbage collected
|
| 373 |
+
|
| 374 |
+
# --- Example Usage (Illustrative) ---
|
| 375 |
+
if __name__ == "__main__":
|
| 376 |
+
# This is a placeholder for how you might use it.
|
| 377 |
+
# You'll need a valid .rkllm model and librkllmrt.so in your path.
|
| 378 |
+
|
| 379 |
+
# Global list to store results from callback for demonstration
|
| 380 |
+
results_buffer = []
|
| 381 |
+
|
| 382 |
+
def my_python_callback(result_ptr, userdata_ptr, state_enum):
|
| 383 |
+
"""
|
| 384 |
+
Callback function to be called by the C library.
|
| 385 |
+
"""
|
| 386 |
+
global results_buffer
|
| 387 |
+
state = LLMCallState(state_enum)
|
| 388 |
+
result = result_ptr.contents
|
| 389 |
+
|
| 390 |
+
current_text = ""
|
| 391 |
+
if result.text: # Check if the char_p is not NULL
|
| 392 |
+
current_text = result.text.decode('utf-8', errors='ignore')
|
| 393 |
+
|
| 394 |
+
print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
|
| 395 |
+
results_buffer.append(current_text)
|
| 396 |
+
|
| 397 |
+
if state == LLMCallState.RKLLM_RUN_FINISH:
|
| 398 |
+
print("Inference finished.")
|
| 399 |
+
elif state == LLMCallState.RKLLM_RUN_ERROR:
|
| 400 |
+
print("Inference error.")
|
| 401 |
+
|
| 402 |
+
# Example: Accessing logits if available (and if mode was set to get logits)
|
| 403 |
+
# if result.logits.logits and result.logits.vocab_size > 0:
|
| 404 |
+
# print(f" Logits (first 5 of vocab_size {result.logits.vocab_size}):")
|
| 405 |
+
# for i in range(min(5, result.logits.vocab_size)):
|
| 406 |
+
# print(f" {result.logits.logits[i]:.4f}", end=" ")
|
| 407 |
+
# print()
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# --- Attempt to use the wrapper ---
|
| 411 |
+
try:
|
| 412 |
+
print("Initializing RKLLMRuntime...")
|
| 413 |
+
# Adjust library_path if librkllmrt.so is not in default search paths
|
| 414 |
+
# e.g., library_path="./path/to/librkllmrt.so"
|
| 415 |
+
rk_llm = RKLLMRuntime()
|
| 416 |
+
|
| 417 |
+
print("Creating default parameters...")
|
| 418 |
+
params = rk_llm.create_default_param()
|
| 419 |
+
|
| 420 |
+
# --- Configure parameters ---
|
| 421 |
+
# THIS IS CRITICAL: model_path must point to an actual .rkllm file
|
| 422 |
+
# For this example to run, you need a model file.
|
| 423 |
+
# Let's assume a dummy path for now, this will fail at init if not valid.
|
| 424 |
+
model_file = "dummy_model.rkllm"
|
| 425 |
+
if not os.path.exists(model_file):
|
| 426 |
+
print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
|
| 427 |
+
# Create a dummy file for the example to proceed further, though init will still fail
|
| 428 |
+
# with a real library unless it's a valid model.
|
| 429 |
+
with open(model_file, "w") as f:
|
| 430 |
+
f.write("dummy content")
|
| 431 |
+
|
| 432 |
+
params.model_path = model_file.encode('utf-8')
|
| 433 |
+
params.max_context_len = 512
|
| 434 |
+
params.max_new_tokens = 128
|
| 435 |
+
params.top_k = 1 # Greedy
|
| 436 |
+
params.temperature = 0.7
|
| 437 |
+
params.repeat_penalty = 1.1
|
| 438 |
+
# ... set other params as needed
|
| 439 |
+
|
| 440 |
+
print(f"Initializing LLM with model: {params.model_path.decode()}...")
|
| 441 |
+
# This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
|
| 442 |
+
try:
|
| 443 |
+
rk_llm.init(params, my_python_callback)
|
| 444 |
+
print("LLM Initialized.")
|
| 445 |
+
except RuntimeError as e:
|
| 446 |
+
print(f"Error during LLM initialization: {e}")
|
| 447 |
+
print("This is expected if 'dummy_model.rkllm' is not a valid model.")
|
| 448 |
+
print("Replace 'dummy_model.rkllm' with a real model path to test further.")
|
| 449 |
+
exit()
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# --- Prepare input ---
|
| 453 |
+
print("Preparing input...")
|
| 454 |
+
rk_input = RKLLMInput()
|
| 455 |
+
rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
|
| 456 |
+
|
| 457 |
+
prompt_text = "Translate the following English text to French: 'Hello, world!'"
|
| 458 |
+
c_prompt = prompt_text.encode('utf-8')
|
| 459 |
+
rk_input._union_data.prompt_input = c_prompt # Accessing union member directly
|
| 460 |
+
|
| 461 |
+
# --- Prepare inference parameters ---
|
| 462 |
+
print("Preparing inference parameters...")
|
| 463 |
+
infer_params = RKLLMInferParam()
|
| 464 |
+
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
|
| 465 |
+
infer_params.keep_history = 1 # True
|
| 466 |
+
# infer_params.lora_params = None # or set up RKLLMLoraParam if using LoRA
|
| 467 |
+
# infer_params.prompt_cache_params = None # or set up RKLLMPromptCacheParam
|
| 468 |
+
|
| 469 |
+
# --- Run inference ---
|
| 470 |
+
print(f"Running inference with prompt: '{prompt_text}'")
|
| 471 |
+
results_buffer.clear()
|
| 472 |
+
try:
|
| 473 |
+
rk_llm.run(rk_input, infer_params) # Userdata is None by default
|
| 474 |
+
print("\n--- Full Response ---")
|
| 475 |
+
print("".join(results_buffer))
|
| 476 |
+
print("---------------------\n")
|
| 477 |
+
except RuntimeError as e:
|
| 478 |
+
print(f"Error during LLM run: {e}")
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# --- Example: Set chat template (if model supports it) ---
|
| 482 |
+
# print("Setting chat template...")
|
| 483 |
+
# try:
|
| 484 |
+
# rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
|
| 485 |
+
# print("Chat template set.")
|
| 486 |
+
# except RuntimeError as e:
|
| 487 |
+
# print(f"Error setting chat template: {e}")
|
| 488 |
+
|
| 489 |
+
# --- Example: Clear KV Cache ---
|
| 490 |
+
# print("Clearing KV cache (keeping system prompt if any)...")
|
| 491 |
+
# try:
|
| 492 |
+
# rk_llm.clear_kv_cache(keep_system_prompt=True)
|
| 493 |
+
# print("KV cache cleared.")
|
| 494 |
+
# except RuntimeError as e:
|
| 495 |
+
# print(f"Error clearing KV cache: {e}")
|
| 496 |
+
|
| 497 |
+
except OSError as e:
|
| 498 |
+
print(f"OSError: {e}. Could not load the RKLLM library.")
|
| 499 |
+
print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
|
| 500 |
+
except Exception as e:
|
| 501 |
+
print(f"An unexpected error occurred: {e}")
|
| 502 |
+
finally:
|
| 503 |
+
if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
|
| 504 |
+
print("Destroying LLM instance...")
|
| 505 |
+
rk_llm.destroy()
|
| 506 |
+
print("LLM instance destroyed.")
|
| 507 |
+
if os.path.exists(model_file) and model_file == "dummy_model.rkllm":
|
| 508 |
+
os.remove(model_file) # Clean up dummy file
|
| 509 |
+
|
| 510 |
+
print("Example finished.")
|
run_rknn.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import faulthandler
|
| 2 |
+
faulthandler.enable()
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import numpy as np
|
| 6 |
+
from rkllm_binding import *
|
| 7 |
+
import ztu_somemodelruntime_rknnlite2 as ort
|
| 8 |
+
import signal
|
| 9 |
+
import cv2
|
| 10 |
+
import ctypes
|
| 11 |
+
|
| 12 |
+
# --- Configuration ---
|
| 13 |
+
# These paths should point to the directory containing all model files
|
| 14 |
+
# or be absolute paths.
|
| 15 |
+
MODEL_DIR = "." # Assuming models are in the current directory or provide a specific path
|
| 16 |
+
LLM_MODEL_NAME = "qwen_f16.rkllm"
|
| 17 |
+
VISION_ENCODER_ONNX_NAME = "fastvithd.onnx"
|
| 18 |
+
MM_PROJECTOR_ONNX_NAME = "mm_projector.onnx"
|
| 19 |
+
PREPROCESSOR_CONFIG_NAME = "preprocessor_config.json" # Generated by export_onnx.py
|
| 20 |
+
|
| 21 |
+
LLM_MODEL_PATH = os.path.join(MODEL_DIR, LLM_MODEL_NAME)
|
| 22 |
+
VISION_ENCODER_PATH = os.path.join(MODEL_DIR, VISION_ENCODER_ONNX_NAME)
|
| 23 |
+
MM_PROJECTOR_PATH = os.path.join(MODEL_DIR, MM_PROJECTOR_ONNX_NAME)
|
| 24 |
+
PREPROCESSOR_CONFIG_PATH = os.path.join(MODEL_DIR, PREPROCESSOR_CONFIG_NAME)
|
| 25 |
+
|
| 26 |
+
IMAGE_PATH = "test.jpg" # Replace with your test image
|
| 27 |
+
# user_prompt = "Describe this image in detail."
|
| 28 |
+
user_prompt = "仔细描述一下这张图片。"
|
| 29 |
+
|
| 30 |
+
# Global RKLLMRuntime instance
|
| 31 |
+
rk_runtime = None
|
| 32 |
+
|
| 33 |
+
# Exit on Ctrl-C
|
| 34 |
+
def signal_handler(signal, frame):
|
| 35 |
+
print("Ctrl-C pressed, exiting...")
|
| 36 |
+
global rk_runtime
|
| 37 |
+
if rk_runtime:
|
| 38 |
+
try:
|
| 39 |
+
print("Attempting to abort RKLLM task...")
|
| 40 |
+
rk_runtime.abort()
|
| 41 |
+
print("RKLLM task aborted.")
|
| 42 |
+
except RuntimeError as e:
|
| 43 |
+
print(f"Note: RKLLM abort failed or task was not running: {e}")
|
| 44 |
+
except Exception as e:
|
| 45 |
+
print(f"Unexpected error during RKLLM abort in signal handler: {e}")
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
print("Attempting to destroy RKLLM instance...")
|
| 49 |
+
rk_runtime.destroy()
|
| 50 |
+
print("RKLLM instance destroyed via signal handler.")
|
| 51 |
+
except RuntimeError as e:
|
| 52 |
+
print(f"Error during RKLLM destroy in signal handler: {e}")
|
| 53 |
+
except Exception as e: # Catch any other unexpected errors
|
| 54 |
+
print(f"Unexpected error during RKLLM destroy in signal handler: {e}")
|
| 55 |
+
exit(0)
|
| 56 |
+
|
| 57 |
+
signal.signal(signal.SIGINT, signal_handler)
|
| 58 |
+
|
| 59 |
+
# Set RKLLM log level if desired
|
| 60 |
+
os.environ["RKLLM_LOG_LEVEL"] = "1"
|
| 61 |
+
|
| 62 |
+
inference_count = 0
|
| 63 |
+
inference_start_time = 0
|
| 64 |
+
first_token_received = False
|
| 65 |
+
|
| 66 |
+
def result_callback(result_ptr, userdata, state_enum):
|
| 67 |
+
global inference_start_time, inference_count, first_token_received
|
| 68 |
+
state = LLMCallState(state_enum) # Convert int to enum
|
| 69 |
+
if result_ptr is None:
|
| 70 |
+
return
|
| 71 |
+
result = result_ptr.contents # Dereference the pointer
|
| 72 |
+
|
| 73 |
+
if state == LLMCallState.RKLLM_RUN_NORMAL:
|
| 74 |
+
if not first_token_received:
|
| 75 |
+
first_token_time = time.time()
|
| 76 |
+
print(f"\nTime to first token: {first_token_time - inference_start_time:.2f} seconds")
|
| 77 |
+
first_token_received = True
|
| 78 |
+
|
| 79 |
+
current_text = ""
|
| 80 |
+
if result.text: # Check if char_p is not NULL
|
| 81 |
+
current_text = result.text.decode('utf-8', errors='ignore')
|
| 82 |
+
print(current_text, end="", flush=True)
|
| 83 |
+
inference_count += 1
|
| 84 |
+
elif state == LLMCallState.RKLLM_RUN_FINISH:
|
| 85 |
+
print("\n\n(finished)")
|
| 86 |
+
elif state == LLMCallState.RKLLM_RUN_ERROR:
|
| 87 |
+
print("\nError occurred during LLM call")
|
| 88 |
+
# Add other states if needed, e.g., RKLLM_RUN_WAITING
|
| 89 |
+
|
| 90 |
+
def load_and_preprocess_image(image_path, config_path):
|
| 91 |
+
img_size = 1024
|
| 92 |
+
image_mean = [0.0, 0.0, 0.0]
|
| 93 |
+
image_std = [1.0, 1.0, 1.0]
|
| 94 |
+
|
| 95 |
+
print(f"Target image size from config: {img_size}x{img_size}")
|
| 96 |
+
print(f"Using image_mean: {image_mean}, image_std: {image_std}")
|
| 97 |
+
|
| 98 |
+
img = cv2.imread(image_path)
|
| 99 |
+
if img is None:
|
| 100 |
+
raise FileNotFoundError(f"Image not found: {image_path}")
|
| 101 |
+
|
| 102 |
+
# 计算缩放比例,保持宽高比
|
| 103 |
+
h, w = img.shape[:2]
|
| 104 |
+
scale = min(img_size / w, img_size / h)
|
| 105 |
+
new_w, new_h = int(w * scale), int(h * scale)
|
| 106 |
+
|
| 107 |
+
# 保持比例缩放
|
| 108 |
+
img_resized = cv2.resize(img, (new_w, new_h))
|
| 109 |
+
|
| 110 |
+
# 创建目标大小的黑色背景
|
| 111 |
+
img_padded = np.zeros((img_size, img_size, 3), dtype=np.uint8)
|
| 112 |
+
|
| 113 |
+
# 将缩放后的图像放在中心位置
|
| 114 |
+
y_offset = (img_size - new_h) // 2
|
| 115 |
+
x_offset = (img_size - new_w) // 2
|
| 116 |
+
img_padded[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = img_resized
|
| 117 |
+
|
| 118 |
+
img_rgb = cv2.cvtColor(img_padded, cv2.COLOR_BGR2RGB)
|
| 119 |
+
img_fp32 = img_rgb.astype(np.float32)
|
| 120 |
+
|
| 121 |
+
# Normalize
|
| 122 |
+
img_normalized = (img_fp32 / 255.0 - image_mean) / image_std
|
| 123 |
+
|
| 124 |
+
# Transpose to NCHW format
|
| 125 |
+
img_nchw = img_normalized.transpose(2, 0, 1) # HWC to CHW
|
| 126 |
+
img_batch = img_nchw[np.newaxis, :, :, :] # Add batch dimension -> NCHW
|
| 127 |
+
|
| 128 |
+
return img_batch.astype(np.float32), img_size
|
| 129 |
+
def main():
|
| 130 |
+
global rk_runtime, inference_start_time, inference_count, first_token_received, user_prompt
|
| 131 |
+
|
| 132 |
+
# --- 1. Initialize ONNX Runtime for Vision Models ---
|
| 133 |
+
print("Loading ONNX vision encoder model...")
|
| 134 |
+
vision_session = ort.InferenceSession(VISION_ENCODER_PATH)
|
| 135 |
+
vision_input_name = vision_session.get_inputs()[0].name
|
| 136 |
+
vision_output_name = vision_session.get_outputs()[0].name
|
| 137 |
+
print(f"ONNX vision encoder loaded. Input: '{vision_input_name}', Output: '{vision_output_name}'")
|
| 138 |
+
|
| 139 |
+
print("Loading ONNX mm_projector model...")
|
| 140 |
+
mm_projector_session = ort.InferenceSession(MM_PROJECTOR_PATH)
|
| 141 |
+
mm_projector_input_name = mm_projector_session.get_inputs()[0].name
|
| 142 |
+
mm_projector_output_name = mm_projector_session.get_outputs()[0].name
|
| 143 |
+
print(f"ONNX mm_projector loaded. Input: '{mm_projector_input_name}', Output: '{mm_projector_output_name}'")
|
| 144 |
+
|
| 145 |
+
# --- 2. Initialize RKLLM ---
|
| 146 |
+
print("Initializing RKLLM...")
|
| 147 |
+
rk_runtime = RKLLMRuntime()
|
| 148 |
+
|
| 149 |
+
param = rk_runtime.create_default_param()
|
| 150 |
+
param.model_path = LLM_MODEL_PATH.encode('utf-8')
|
| 151 |
+
param.img_start = "<image>".encode('utf-8')
|
| 152 |
+
param.img_end = "".encode('utf-8')
|
| 153 |
+
param.img_content = "<unk>".encode('utf-8')
|
| 154 |
+
|
| 155 |
+
extend_param = RKLLMExtendParam()
|
| 156 |
+
extend_param.base_domain_id = 1
|
| 157 |
+
extend_param.embed_flash = 1
|
| 158 |
+
extend_param.enabled_cpus_num = 8
|
| 159 |
+
extend_param.enabled_cpus_mask = 0xffffffff
|
| 160 |
+
param.extend_param = extend_param
|
| 161 |
+
|
| 162 |
+
model_size_llm = os.path.getsize(LLM_MODEL_PATH)
|
| 163 |
+
print(f"Start loading language model (size: {model_size_llm / 1024 / 1024:.2f} MB)")
|
| 164 |
+
start_time_llm_load = time.time()
|
| 165 |
+
|
| 166 |
+
try:
|
| 167 |
+
rk_runtime.init(param, result_callback)
|
| 168 |
+
except RuntimeError as e:
|
| 169 |
+
print(f"RKLLM init failed: {e}")
|
| 170 |
+
if rk_runtime:
|
| 171 |
+
try:
|
| 172 |
+
rk_runtime.destroy()
|
| 173 |
+
except Exception as e_destroy:
|
| 174 |
+
print(f"Error destroying RKLLM after init failure: {e_destroy}")
|
| 175 |
+
return
|
| 176 |
+
|
| 177 |
+
end_time_llm_load = time.time()
|
| 178 |
+
print(f"Language model loaded in {end_time_llm_load - start_time_llm_load:.2f} seconds")
|
| 179 |
+
|
| 180 |
+
# --- 3. Load and Preprocess Image ---
|
| 181 |
+
print(f"Loading and preprocessing image: {IMAGE_PATH}")
|
| 182 |
+
preprocessed_image, original_img_dim = load_and_preprocess_image(IMAGE_PATH, PREPROCESSOR_CONFIG_PATH)
|
| 183 |
+
print(f"Input image shape for ONNX vision model: {preprocessed_image.shape}")
|
| 184 |
+
|
| 185 |
+
# --- 4. Vision Encoder Inference (ONNX) ---
|
| 186 |
+
start_time_vision = time.time()
|
| 187 |
+
vision_outputs = vision_session.run([vision_output_name], {vision_input_name: preprocessed_image})
|
| 188 |
+
image_features_from_vision = vision_outputs[0]
|
| 189 |
+
end_time_vision = time.time()
|
| 190 |
+
print(f"ONNX Vision encoder inference time: {end_time_vision - start_time_vision:.2f} seconds")
|
| 191 |
+
print(f"Vision encoder output shape: {image_features_from_vision.shape}")
|
| 192 |
+
|
| 193 |
+
# --- 5. MM Projector Inference (ONNX) ---
|
| 194 |
+
start_time_projector = time.time()
|
| 195 |
+
projector_outputs = mm_projector_session.run([mm_projector_output_name], {mm_projector_input_name: image_features_from_vision})
|
| 196 |
+
projected_image_embeddings_np = projector_outputs[0]
|
| 197 |
+
end_time_projector = time.time()
|
| 198 |
+
print(f"ONNX MM projector inference time: {end_time_projector - start_time_projector:.2f} seconds")
|
| 199 |
+
print(f"Projected image embeddings shape: {projected_image_embeddings_np.shape}")
|
| 200 |
+
|
| 201 |
+
# Ensure C-contiguous and float32 for ctypes
|
| 202 |
+
projected_image_embeddings_np = np.ascontiguousarray(projected_image_embeddings_np, dtype=np.float32)
|
| 203 |
+
|
| 204 |
+
# --- 6. Prepare Prompt and RKLLMInput ---
|
| 205 |
+
# The prompt should contain the <image> placeholder where the image features will be inserted.
|
| 206 |
+
# prompt = f"""<|im_start|>system
|
| 207 |
+
# You are a helpful assistant.<|im_end|>
|
| 208 |
+
# <|im_start|>user
|
| 209 |
+
# {param.img_start.decode()}
|
| 210 |
+
# {user_prompt}<|im_end|>
|
| 211 |
+
# <|im_start|>assistant
|
| 212 |
+
# """
|
| 213 |
+
|
| 214 |
+
# RKLLM now loads its own chat template, so we don't need to include that.
|
| 215 |
+
prompt = f"""{param.img_start.decode()}
|
| 216 |
+
{user_prompt}"""
|
| 217 |
+
|
| 218 |
+
print(f"\nUsing prompt:\n{prompt}")
|
| 219 |
+
|
| 220 |
+
rkllm_input = RKLLMInput()
|
| 221 |
+
rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL
|
| 222 |
+
|
| 223 |
+
multimodal_payload = RKLLMMultiModelInput()
|
| 224 |
+
multimodal_payload.prompt = prompt.encode('utf-8')
|
| 225 |
+
|
| 226 |
+
# projected_image_embeddings_np has shape (1, num_tokens, hidden_dim)
|
| 227 |
+
num_image_tokens = projected_image_embeddings_np.shape[1]
|
| 228 |
+
# The C API expects a flat pointer to the embedding data.
|
| 229 |
+
embedding_data_flat = projected_image_embeddings_np.flatten()
|
| 230 |
+
|
| 231 |
+
multimodal_payload.image_embed = embedding_data_flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
| 232 |
+
multimodal_payload.n_image_tokens = num_image_tokens
|
| 233 |
+
multimodal_payload.n_image = 1 # Number of images processed
|
| 234 |
+
multimodal_payload.image_width = original_img_dim # Width of the (resized before processing) image
|
| 235 |
+
multimodal_payload.image_height = original_img_dim # Height of the (resized before processing) image
|
| 236 |
+
|
| 237 |
+
rkllm_input._union_data.multimodal_input = multimodal_payload
|
| 238 |
+
|
| 239 |
+
# --- 7. Create Inference Parameters ---
|
| 240 |
+
infer_param = RKLLMInferParam()
|
| 241 |
+
infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE.value # Ensure this is an int for C API
|
| 242 |
+
# infer_param.keep_history = 1 # Or 0, default is usually 0 (false) in create_default_param or C struct.
|
| 243 |
+
# Check rkllm.h or binding for default if not setting explicitly.
|
| 244 |
+
# RKLLMInferParam from binding has keep_history as c_int.
|
| 245 |
+
|
| 246 |
+
# --- 8. Run RKLLM Inference ---
|
| 247 |
+
print("Starting RKLLM inference...")
|
| 248 |
+
inference_start_time = time.time()
|
| 249 |
+
inference_count = 0
|
| 250 |
+
first_token_received = False
|
| 251 |
+
|
| 252 |
+
try:
|
| 253 |
+
# The RKLLMRuntime.run method takes input and infer_param objects directly.
|
| 254 |
+
rk_runtime.run(rkllm_input, infer_param, None) # Userdata is None
|
| 255 |
+
except RuntimeError as e:
|
| 256 |
+
print(f"RKLLM run failed: {e}")
|
| 257 |
+
|
| 258 |
+
# --- 9. Clean up ---
|
| 259 |
+
# Normal cleanup if not interrupted by Ctrl-C.
|
| 260 |
+
# The signal handler also attempts to destroy the instance.
|
| 261 |
+
if rk_runtime and rk_runtime.llm_handle and rk_runtime.llm_handle.value:
|
| 262 |
+
try:
|
| 263 |
+
rk_runtime.destroy()
|
| 264 |
+
print("RKLLM instance destroyed at script end.")
|
| 265 |
+
except RuntimeError as e:
|
| 266 |
+
print(f"Error during RKLLM destroy at script end: {e}")
|
| 267 |
+
except Exception as e:
|
| 268 |
+
print(f"Unexpected error during RKLLM destroy at script end: {e}")
|
| 269 |
+
|
| 270 |
+
print("Script finished.")
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
# rk_runtime (global) will be initialized inside main()
|
| 274 |
+
main()
|
test.jpg
ADDED
|
Git LFS Details
|
ztu_somemodelruntime_rknnlite2.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 模块级常量和函数
|
| 2 |
+
from rknnlite.api import RKNNLite
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import warnings
|
| 6 |
+
import logging
|
| 7 |
+
from typing import List, Dict, Union, Optional
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import onnxruntime as ort
|
| 11 |
+
HAS_ORT = True
|
| 12 |
+
except ImportError:
|
| 13 |
+
HAS_ORT = False
|
| 14 |
+
warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning)
|
| 15 |
+
|
| 16 |
+
# 配置日志
|
| 17 |
+
logger = logging.getLogger("somemodelruntime_rknnlite2")
|
| 18 |
+
logger.setLevel(logging.ERROR) # 默认只输出错误信息
|
| 19 |
+
if not logger.handlers:
|
| 20 |
+
handler = logging.StreamHandler()
|
| 21 |
+
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
| 22 |
+
logger.addHandler(handler)
|
| 23 |
+
|
| 24 |
+
# ONNX Runtime日志级别到Python logging级别的映射
|
| 25 |
+
_LOGGING_LEVEL_MAP = {
|
| 26 |
+
0: logging.DEBUG, # Verbose
|
| 27 |
+
1: logging.INFO, # Info
|
| 28 |
+
2: logging.WARNING, # Warning
|
| 29 |
+
3: logging.ERROR, # Error
|
| 30 |
+
4: logging.CRITICAL # Fatal
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# 检查环境变量中的日志级别设置
|
| 34 |
+
try:
|
| 35 |
+
env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
|
| 36 |
+
if env_log_level is not None:
|
| 37 |
+
log_level = int(env_log_level)
|
| 38 |
+
if log_level in _LOGGING_LEVEL_MAP:
|
| 39 |
+
logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
|
| 40 |
+
logger.info(f"从环境变量设置日志级别: {log_level}")
|
| 41 |
+
else:
|
| 42 |
+
logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {log_level}, 应该是0-4之间的整数")
|
| 43 |
+
except ValueError:
|
| 44 |
+
logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {env_log_level}, 应该是0-4之间的整数")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def set_default_logger_severity(level: int) -> None:
|
| 48 |
+
"""
|
| 49 |
+
Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
level: 日志级别(0-4)
|
| 53 |
+
"""
|
| 54 |
+
if level not in _LOGGING_LEVEL_MAP:
|
| 55 |
+
raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
|
| 56 |
+
logger.setLevel(_LOGGING_LEVEL_MAP[level])
|
| 57 |
+
|
| 58 |
+
def set_default_logger_verbosity(level: int) -> None:
|
| 59 |
+
"""
|
| 60 |
+
Sets the default logging verbosity level. To activate the verbose log,
|
| 61 |
+
you need to set the default logging severity to 0:Verbose level.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
level: 日志级别(0-4)
|
| 65 |
+
"""
|
| 66 |
+
set_default_logger_severity(level)
|
| 67 |
+
|
| 68 |
+
# RKNN tensor type到numpy dtype的映射
|
| 69 |
+
RKNN_DTYPE_MAP = {
|
| 70 |
+
0: np.float32, # RKNN_TENSOR_FLOAT32
|
| 71 |
+
1: np.float16, # RKNN_TENSOR_FLOAT16
|
| 72 |
+
2: np.int8, # RKNN_TENSOR_INT8
|
| 73 |
+
3: np.uint8, # RKNN_TENSOR_UINT8
|
| 74 |
+
4: np.int16, # RKNN_TENSOR_INT16
|
| 75 |
+
5: np.uint16, # RKNN_TENSOR_UINT16
|
| 76 |
+
6: np.int32, # RKNN_TENSOR_INT32
|
| 77 |
+
7: np.uint32, # RKNN_TENSOR_UINT32
|
| 78 |
+
8: np.int64, # RKNN_TENSOR_INT64
|
| 79 |
+
9: bool, # RKNN_TENSOR_BOOL
|
| 80 |
+
10: np.int8, # RKNN_TENSOR_INT4 (用int8表示)
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
def get_available_providers() -> List[str]:
|
| 84 |
+
"""
|
| 85 |
+
获取可用的设备提供者列表(为保持接口兼容性的占位函数)
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
list: 可用的设备提供者列表,总是返回["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
|
| 89 |
+
"""
|
| 90 |
+
return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_device() -> str:
|
| 94 |
+
"""
|
| 95 |
+
获取当前设备
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
str: 当前设备
|
| 99 |
+
"""
|
| 100 |
+
return "RKNN2"
|
| 101 |
+
|
| 102 |
+
def get_version_info() -> Dict[str, str]:
|
| 103 |
+
"""
|
| 104 |
+
获取版本信息
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
dict: 包含API和驱动版本信息的字典
|
| 108 |
+
"""
|
| 109 |
+
runtime = RKNNLite()
|
| 110 |
+
version = runtime.get_sdk_version()
|
| 111 |
+
return {
|
| 112 |
+
"api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
|
| 113 |
+
"driver_version": version.split('\n')[3].split(': ')[1]
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
class IOTensor:
|
| 117 |
+
"""输入/输出张量的信息封装类"""
|
| 118 |
+
def __init__(self, name, shape, type=None):
|
| 119 |
+
self.name = name.decode() if isinstance(name, bytes) else name
|
| 120 |
+
self.shape = shape
|
| 121 |
+
self.type = type
|
| 122 |
+
|
| 123 |
+
def __str__(self):
|
| 124 |
+
return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
|
| 125 |
+
|
| 126 |
+
class SessionOptions:
|
| 127 |
+
"""会话选项类"""
|
| 128 |
+
def __init__(self):
|
| 129 |
+
self.enable_profiling = False # 是否使用性能分析
|
| 130 |
+
self.intra_op_num_threads = 1 # 设置RKNN的线程数, 对应rknn的core_mask
|
| 131 |
+
self.log_severity_level = -1 # 另一个设置日志级别的参数
|
| 132 |
+
self.log_verbosity_level = -1 # 另一个设置日志级别的参数
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class InferenceSession:
|
| 136 |
+
"""
|
| 137 |
+
RKNNLite运行时封装类,API风格类似ONNX Runtime
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
|
| 141 |
+
processed_path = InferenceSession._process_model_path(model_path, sess_options)
|
| 142 |
+
if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
|
| 143 |
+
logger.info("使用ONNX Runtime加载模型")
|
| 144 |
+
if not HAS_ORT:
|
| 145 |
+
raise RuntimeError("未安装onnxruntime,无法加载ONNX模型")
|
| 146 |
+
return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
|
| 147 |
+
else:
|
| 148 |
+
# 如果不是 ONNX 模型,则调用父类的 __new__ 创建 InferenceSession 实例
|
| 149 |
+
instance = super().__new__(cls)
|
| 150 |
+
# 保存处理后的路径
|
| 151 |
+
instance._processed_path = processed_path
|
| 152 |
+
return instance
|
| 153 |
+
|
| 154 |
+
def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
|
| 155 |
+
"""
|
| 156 |
+
初始化运行时并加载模型
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
model_path: 模型文件路径(.rknn或.onnx)
|
| 160 |
+
sess_options: 会话选项
|
| 161 |
+
**kwargs: 其他初始化参数
|
| 162 |
+
"""
|
| 163 |
+
options = sess_options or SessionOptions()
|
| 164 |
+
|
| 165 |
+
# 只在未设置环境变量时使用SessionOptions中的日志级别
|
| 166 |
+
if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
|
| 167 |
+
if options.log_severity_level != -1:
|
| 168 |
+
set_default_logger_severity(options.log_severity_level)
|
| 169 |
+
if options.log_verbosity_level != -1:
|
| 170 |
+
set_default_logger_verbosity(options.log_verbosity_level)
|
| 171 |
+
|
| 172 |
+
# 使用__new__中处理好的路径
|
| 173 |
+
model_path = getattr(self, '_processed_path', model_path)
|
| 174 |
+
if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
|
| 175 |
+
# 避免重复加载 ONNX 模型
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
# ... 现有的 RKNN 模型加载和初始化代码 ...
|
| 179 |
+
self.model_path = model_path
|
| 180 |
+
if not os.path.exists(self.model_path):
|
| 181 |
+
logger.error(f"模型文件不存在: {self.model_path}")
|
| 182 |
+
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
|
| 183 |
+
|
| 184 |
+
self.runtime = RKNNLite(verbose=options.enable_profiling)
|
| 185 |
+
|
| 186 |
+
logger.debug(f"正在加载模型: {self.model_path}")
|
| 187 |
+
ret = self.runtime.load_rknn(self.model_path)
|
| 188 |
+
if ret != 0:
|
| 189 |
+
logger.error(f"加载RKNN模型失败: {self.model_path}")
|
| 190 |
+
raise RuntimeError(f'加载RKNN模型失败: {self.model_path}')
|
| 191 |
+
logger.debug("模型加载成功")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
if options.intra_op_num_threads == 1:
|
| 195 |
+
core_mask = RKNNLite.NPU_CORE_AUTO
|
| 196 |
+
elif options.intra_op_num_threads == 2:
|
| 197 |
+
core_mask = RKNNLite.NPU_CORE_0_1
|
| 198 |
+
elif options.intra_op_num_threads == 3:
|
| 199 |
+
core_mask = RKNNLite.NPU_CORE_0_1_2
|
| 200 |
+
else:
|
| 201 |
+
raise ValueError(f"intra_op_num_threads的值无效: {options.intra_op_num_threads}, 只能是1,2或3")
|
| 202 |
+
|
| 203 |
+
logger.debug("正在初始化运行时环境")
|
| 204 |
+
ret = self.runtime.init_runtime(core_mask=core_mask)
|
| 205 |
+
if ret != 0:
|
| 206 |
+
logger.error("初始化运行时环境失败")
|
| 207 |
+
raise RuntimeError('初始化运行时环境失败')
|
| 208 |
+
logger.debug("运行时环境初始化成功")
|
| 209 |
+
|
| 210 |
+
self._init_io_info()
|
| 211 |
+
self.options = options
|
| 212 |
+
|
| 213 |
+
def get_performance_info(self) -> Dict[str, float]:
|
| 214 |
+
"""
|
| 215 |
+
获取性能信息
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
dict: 包含性能信息的字典
|
| 219 |
+
"""
|
| 220 |
+
if not self.options.perf_debug:
|
| 221 |
+
raise RuntimeError("性能分析未启用,请在SessionOptions中设置perf_debug=True")
|
| 222 |
+
|
| 223 |
+
perf = self.runtime.rknn_runtime.get_run_perf()
|
| 224 |
+
return {
|
| 225 |
+
"run_duration": perf.run_duration / 1000.0 # 转换为毫秒
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
def set_core_mask(self, core_mask: int) -> None:
|
| 229 |
+
"""
|
| 230 |
+
设置NPU核心使用模式
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
core_mask: NPU核心掩码,使用NPU_CORE_*常量
|
| 234 |
+
"""
|
| 235 |
+
ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
|
| 236 |
+
if ret != 0:
|
| 237 |
+
raise RuntimeError("设置NPU核心模式失败")
|
| 238 |
+
|
| 239 |
+
@staticmethod
|
| 240 |
+
def _process_model_path(model_path, sess_options):
|
| 241 |
+
"""
|
| 242 |
+
处理模型路径,支持.onnx和.rknn文件
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
model_path: 模型文件路径
|
| 246 |
+
"""
|
| 247 |
+
# 如果是ONNX文件,检查是否需要自动加载RKNN
|
| 248 |
+
if model_path.lower().endswith('.onnx'):
|
| 249 |
+
logger.info("检测到ONNX模型文件")
|
| 250 |
+
|
| 251 |
+
# 获取需要跳过自动加载的模型列表
|
| 252 |
+
skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
|
| 253 |
+
if skip_models:
|
| 254 |
+
skip_list = [m.strip() for m in skip_models.split(',')]
|
| 255 |
+
# 获取模型文件名(不含路径)用于匹配
|
| 256 |
+
model_name = os.path.basename(model_path)
|
| 257 |
+
if model_name.lower() in [m.lower() for m in skip_list]:
|
| 258 |
+
logger.info(f"模型{model_name}在跳过列表中,将使用ONNX Runtime")
|
| 259 |
+
return model_path
|
| 260 |
+
|
| 261 |
+
# 构造RKNN文件路径
|
| 262 |
+
rknn_path = os.path.splitext(model_path)[0] + '.rknn'
|
| 263 |
+
if os.path.exists(rknn_path):
|
| 264 |
+
logger.info(f"找到对应的RKNN模型,将使用RKNN: {rknn_path}")
|
| 265 |
+
return rknn_path
|
| 266 |
+
else:
|
| 267 |
+
logger.info("未找到对应的RKNN模型,将使用ONNX Runtime")
|
| 268 |
+
return model_path
|
| 269 |
+
|
| 270 |
+
return model_path
|
| 271 |
+
|
| 272 |
+
def _convert_nhwc_to_nchw(self, shape):
|
| 273 |
+
"""将NHWC格式的shape转换为NCHW格式"""
|
| 274 |
+
if len(shape) == 4:
|
| 275 |
+
# NHWC -> NCHW
|
| 276 |
+
n, h, w, c = shape
|
| 277 |
+
return [n, c, h, w]
|
| 278 |
+
return shape
|
| 279 |
+
|
| 280 |
+
def _init_io_info(self):
|
| 281 |
+
"""初始化模型的输入输出信息"""
|
| 282 |
+
runtime = self.runtime.rknn_runtime
|
| 283 |
+
|
| 284 |
+
# 获取输入输出数量
|
| 285 |
+
n_input, n_output = runtime.get_in_out_num()
|
| 286 |
+
|
| 287 |
+
# 获取输入信息
|
| 288 |
+
self.input_tensors = []
|
| 289 |
+
for i in range(n_input):
|
| 290 |
+
attr = runtime.get_tensor_attr(i)
|
| 291 |
+
shape = [attr.dims[j] for j in range(attr.n_dims)]
|
| 292 |
+
# 对四维输入进行NHWC到NCHW的转换
|
| 293 |
+
shape = self._convert_nhwc_to_nchw(shape)
|
| 294 |
+
# 获取dtype
|
| 295 |
+
dtype = RKNN_DTYPE_MAP.get(attr.type, None)
|
| 296 |
+
tensor = IOTensor(attr.name, shape, dtype)
|
| 297 |
+
self.input_tensors.append(tensor)
|
| 298 |
+
|
| 299 |
+
# 获取输出信息
|
| 300 |
+
self.output_tensors = []
|
| 301 |
+
for i in range(n_output):
|
| 302 |
+
attr = runtime.get_tensor_attr(i, is_output=True)
|
| 303 |
+
shape = runtime.get_output_shape(i)
|
| 304 |
+
# 获取dtype
|
| 305 |
+
dtype = RKNN_DTYPE_MAP.get(attr.type, None)
|
| 306 |
+
tensor = IOTensor(attr.name, shape, dtype)
|
| 307 |
+
self.output_tensors.append(tensor)
|
| 308 |
+
|
| 309 |
+
def get_inputs(self):
|
| 310 |
+
"""
|
| 311 |
+
获取模型输入信息
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
list: 包含输入信息的列表
|
| 315 |
+
"""
|
| 316 |
+
return self.input_tensors
|
| 317 |
+
|
| 318 |
+
def get_outputs(self):
|
| 319 |
+
"""
|
| 320 |
+
获取模型输出信息
|
| 321 |
+
|
| 322 |
+
Returns:
|
| 323 |
+
list: 包含输出信息的列表
|
| 324 |
+
"""
|
| 325 |
+
return self.output_tensors
|
| 326 |
+
|
| 327 |
+
def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
|
| 328 |
+
"""
|
| 329 |
+
执行模型推理
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
output_names: 输出节点名称列表,指定需要返回哪些输出
|
| 333 |
+
input_feed: 输入数据字典或列表
|
| 334 |
+
data_format: 输入数据格式,"nchw"或"nhwc"
|
| 335 |
+
**kwargs: 其他运行时参数
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
list: 模型输出结果列表,如果指定了output_names则只返回指定的输出
|
| 339 |
+
"""
|
| 340 |
+
if input_feed is None:
|
| 341 |
+
logger.error("input_feed不能为None")
|
| 342 |
+
raise ValueError("input_feed不能为None")
|
| 343 |
+
|
| 344 |
+
# 准备输入数据
|
| 345 |
+
if isinstance(input_feed, dict):
|
| 346 |
+
# 如果是字典,按照模型输入顺序排列
|
| 347 |
+
inputs = []
|
| 348 |
+
input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)}
|
| 349 |
+
for tensor in self.input_tensors:
|
| 350 |
+
if tensor.name not in input_feed:
|
| 351 |
+
raise ValueError(f"缺少输入: {tensor.name}")
|
| 352 |
+
inputs.append(input_feed[tensor.name])
|
| 353 |
+
elif isinstance(input_feed, (list, tuple)):
|
| 354 |
+
# 如果是列表,确保长度匹配
|
| 355 |
+
if len(input_feed) != len(self.input_tensors):
|
| 356 |
+
raise ValueError(f"输入数量不匹配: 期望{len(self.input_tensors)}, 实际{len(input_feed)}")
|
| 357 |
+
inputs = list(input_feed)
|
| 358 |
+
else:
|
| 359 |
+
logger.error("input_feed必须是字典或列表类型")
|
| 360 |
+
raise ValueError("input_feed必须是字典或列表类型")
|
| 361 |
+
|
| 362 |
+
# 执行推理
|
| 363 |
+
try:
|
| 364 |
+
logger.debug("开始执行推理")
|
| 365 |
+
all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
|
| 366 |
+
|
| 367 |
+
# 如果没有指定output_names,返回所有输出
|
| 368 |
+
if output_names is None:
|
| 369 |
+
return all_outputs
|
| 370 |
+
|
| 371 |
+
# 获取指定的输出
|
| 372 |
+
output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
|
| 373 |
+
selected_outputs = []
|
| 374 |
+
for name in output_names:
|
| 375 |
+
if name not in output_map:
|
| 376 |
+
raise ValueError(f"未找到输出节点: {name}")
|
| 377 |
+
selected_outputs.append(all_outputs[output_map[name]])
|
| 378 |
+
|
| 379 |
+
return selected_outputs
|
| 380 |
+
|
| 381 |
+
except Exception as e:
|
| 382 |
+
logger.error(f"推理执行失败: {str(e)}")
|
| 383 |
+
raise RuntimeError(f"推理执行失败: {str(e)}")
|
| 384 |
+
|
| 385 |
+
def close(self):
|
| 386 |
+
"""
|
| 387 |
+
关闭会话,释放资源
|
| 388 |
+
"""
|
| 389 |
+
if self.runtime is not None:
|
| 390 |
+
logger.info("正在释放运行时资源")
|
| 391 |
+
self.runtime.release()
|
| 392 |
+
self.runtime = None
|
| 393 |
+
|
| 394 |
+
def __enter__(self):
|
| 395 |
+
return self
|
| 396 |
+
|
| 397 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 398 |
+
self.close()
|
| 399 |
+
|
| 400 |
+
def end_profiling(self) -> Optional[str]:
|
| 401 |
+
"""
|
| 402 |
+
结束性能分析的存根方法
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Optional[str]: None
|
| 406 |
+
"""
|
| 407 |
+
warnings.warn("end_profiling()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 408 |
+
return None
|
| 409 |
+
|
| 410 |
+
def get_profiling_start_time_ns(self) -> int:
|
| 411 |
+
"""
|
| 412 |
+
获取性能分析开始时间的存根方法
|
| 413 |
+
|
| 414 |
+
Returns:
|
| 415 |
+
int: 0
|
| 416 |
+
"""
|
| 417 |
+
warnings.warn("get_profiling_start_time_ns()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 418 |
+
return 0
|
| 419 |
+
|
| 420 |
+
def get_modelmeta(self) -> Dict[str, str]:
|
| 421 |
+
"""
|
| 422 |
+
获取模型元数据的存根方法
|
| 423 |
+
|
| 424 |
+
Returns:
|
| 425 |
+
Dict[str, str]: 空字典
|
| 426 |
+
"""
|
| 427 |
+
warnings.warn("get_modelmeta()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 428 |
+
return {}
|
| 429 |
+
|
| 430 |
+
def get_session_options(self) -> SessionOptions:
|
| 431 |
+
"""
|
| 432 |
+
获取会话选项
|
| 433 |
+
|
| 434 |
+
Returns:
|
| 435 |
+
SessionOptions: 当前会话选项
|
| 436 |
+
"""
|
| 437 |
+
return self.options
|
| 438 |
+
|
| 439 |
+
def get_providers(self) -> List[str]:
|
| 440 |
+
"""
|
| 441 |
+
获取当前使用的providers的存根方法
|
| 442 |
+
|
| 443 |
+
Returns:
|
| 444 |
+
List[str]: ["CPUExecutionProvider"]
|
| 445 |
+
"""
|
| 446 |
+
warnings.warn("get_providers()是存根方法,始终返回CPUExecutionProvider", RuntimeWarning, stacklevel=2)
|
| 447 |
+
return ["CPUExecutionProvider"]
|
| 448 |
+
|
| 449 |
+
def get_provider_options(self) -> Dict[str, Dict[str, str]]:
|
| 450 |
+
"""
|
| 451 |
+
获取provider选项的存根方法
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
Dict[str, Dict[str, str]]: 空字典
|
| 455 |
+
"""
|
| 456 |
+
warnings.warn("get_provider_options()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 457 |
+
return {}
|
| 458 |
+
|
| 459 |
+
def get_session_config(self) -> Dict[str, str]:
|
| 460 |
+
"""
|
| 461 |
+
获取会话配置的存根方法
|
| 462 |
+
|
| 463 |
+
Returns:
|
| 464 |
+
Dict[str, str]: 空字典
|
| 465 |
+
"""
|
| 466 |
+
warnings.warn("get_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 467 |
+
return {}
|
| 468 |
+
|
| 469 |
+
def get_session_state(self) -> Dict[str, str]:
|
| 470 |
+
"""
|
| 471 |
+
获取会话状态的存根方法
|
| 472 |
+
|
| 473 |
+
Returns:
|
| 474 |
+
Dict[str, str]: 空字典
|
| 475 |
+
"""
|
| 476 |
+
warnings.warn("get_session_state()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 477 |
+
return {}
|
| 478 |
+
|
| 479 |
+
def set_session_config(self, config: Dict[str, str]) -> None:
|
| 480 |
+
"""
|
| 481 |
+
设置会话配置的存根方法
|
| 482 |
+
|
| 483 |
+
Args:
|
| 484 |
+
config: 会话配置字典
|
| 485 |
+
"""
|
| 486 |
+
warnings.warn("set_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 487 |
+
|
| 488 |
+
def get_memory_info(self) -> Dict[str, int]:
|
| 489 |
+
"""
|
| 490 |
+
获取内存使用信息的存根方法
|
| 491 |
+
|
| 492 |
+
Returns:
|
| 493 |
+
Dict[str, int]: 空字典
|
| 494 |
+
"""
|
| 495 |
+
warnings.warn("get_memory_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 496 |
+
return {}
|
| 497 |
+
|
| 498 |
+
def set_memory_pattern(self, enable: bool) -> None:
|
| 499 |
+
"""
|
| 500 |
+
设置内存模式的存根方法
|
| 501 |
+
|
| 502 |
+
Args:
|
| 503 |
+
enable: 是否启用内存模式
|
| 504 |
+
"""
|
| 505 |
+
warnings.warn("set_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 506 |
+
|
| 507 |
+
def disable_memory_pattern(self) -> None:
|
| 508 |
+
"""
|
| 509 |
+
禁用内存模式的存根方法
|
| 510 |
+
"""
|
| 511 |
+
warnings.warn("disable_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 512 |
+
|
| 513 |
+
def get_optimization_level(self) -> int:
|
| 514 |
+
"""
|
| 515 |
+
获取优化级别的存根方法
|
| 516 |
+
|
| 517 |
+
Returns:
|
| 518 |
+
int: 0
|
| 519 |
+
"""
|
| 520 |
+
warnings.warn("get_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 521 |
+
return 0
|
| 522 |
+
|
| 523 |
+
def set_optimization_level(self, level: int) -> None:
|
| 524 |
+
"""
|
| 525 |
+
设置优化级别的存根方法
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
level: 优化级别
|
| 529 |
+
"""
|
| 530 |
+
warnings.warn("set_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 531 |
+
|
| 532 |
+
def get_model_metadata(self) -> Dict[str, str]:
|
| 533 |
+
"""
|
| 534 |
+
获取模型元数据的存根方法(与get_modelmeta不同的接口)
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
Dict[str, str]: 空字典
|
| 538 |
+
"""
|
| 539 |
+
warnings.warn("get_model_metadata()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 540 |
+
return {}
|
| 541 |
+
|
| 542 |
+
def get_model_path(self) -> str:
|
| 543 |
+
"""
|
| 544 |
+
获取模型路径
|
| 545 |
+
|
| 546 |
+
Returns:
|
| 547 |
+
str: 模型文件路径
|
| 548 |
+
"""
|
| 549 |
+
return self.model_path
|
| 550 |
+
|
| 551 |
+
def get_input_type_info(self) -> List[Dict[str, str]]:
|
| 552 |
+
"""
|
| 553 |
+
获取输入类型信息的存根方法
|
| 554 |
+
|
| 555 |
+
Returns:
|
| 556 |
+
List[Dict[str, str]]: 空列表
|
| 557 |
+
"""
|
| 558 |
+
warnings.warn("get_input_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 559 |
+
return []
|
| 560 |
+
|
| 561 |
+
def get_output_type_info(self) -> List[Dict[str, str]]:
|
| 562 |
+
"""
|
| 563 |
+
获取输出类型信息的存根方法
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
List[Dict[str, str]]: 空列表
|
| 567 |
+
"""
|
| 568 |
+
warnings.warn("get_output_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 569 |
+
return []
|