Upload 12 files

- .gitattributes +5 -0
- convert_fastvithd.py +52 -0
- convert_mm_projector.py +52 -0
- export_onnx.py +136 -0
- fastvithd.rknn +3 -0
- librkllmrt.so +3 -0
- mm_projector.rknn +3 -0
- qwen_f16.rkllm +3 -0
- rkllm-convert.py +23 -0
- rkllm_binding.py +510 -0
- run_rknn.py +274 -0
- test.jpg +3 -0
- ztu_somemodelruntime_rknnlite2.py +569 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+fastvithd.rknn filter=lfs diff=lfs merge=lfs -text
+librkllmrt.so filter=lfs diff=lfs merge=lfs -text
+mm_projector.rknn filter=lfs diff=lfs merge=lfs -text
+qwen_f16.rkllm filter=lfs diff=lfs merge=lfs -text
+test.jpg filter=lfs diff=lfs merge=lfs -text
convert_fastvithd.py
ADDED
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# ztu_somemodelruntime_rknn2: fastvithd

from rknn.api import RKNN
import os
import numpy as np

def main():
    # Create the RKNN instance
    rknn = RKNN(verbose=False)

    # Path to the input ONNX model
    ONNX_MODEL = "fastvithd.onnx"
    # Path of the exported RKNN model
    RKNN_MODEL = "fastvithd.rknn"

    # Configure the model
    print("--> Config model")
    ret = rknn.config(target_platform="rk3588",
                      dynamic_input=None)
    if ret != 0:
        print('Config model failed!')
        exit(ret)

    # Load the ONNX model
    print("--> Loading model")
    ret = rknn.load_onnx(model=ONNX_MODEL,
                         inputs=['pixel_values'],
                         input_size_list=[[1, 3, 1024, 1024]])
    if ret != 0:
        print('Load model failed!')
        exit(ret)

    # Build the model
    print("--> Building model")
    ret = rknn.build(do_quantization=False)
    if ret != 0:
        print('Build model failed!')
        exit(ret)

    # Export the RKNN model
    print("--> Export RKNN model")
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)

    print('Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
    rknn.release()

if __name__ == '__main__':
    main()
convert_mm_projector.py
ADDED
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# ztu_somemodelruntime_rknn2: mm_projector

from rknn.api import RKNN
import os
import numpy as np

def main():
    # Create the RKNN instance
    rknn = RKNN(verbose=False)

    # Path to the input ONNX model
    ONNX_MODEL = "mm_projector.onnx"
    # Path of the exported RKNN model
    RKNN_MODEL = "mm_projector.rknn"

    # Configure the model
    print("--> Config model")
    ret = rknn.config(target_platform="rk3588",
                      dynamic_input=None)
    if ret != 0:
        print('Config model failed!')
        exit(ret)

    # Load the ONNX model
    print("--> Loading model")
    ret = rknn.load_onnx(model=ONNX_MODEL,
                         inputs=['last_hidden_state'],
                         input_size_list=[[1, 256, 3072]])
    if ret != 0:
        print('Load model failed!')
        exit(ret)

    # Build the model
    print("--> Building model")
    ret = rknn.build(do_quantization=False)
    if ret != 0:
        print('Build model failed!')
        exit(ret)

    # Export the RKNN model
    print("--> Export RKNN model")
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)

    print('Done! The converted RKNN model has been saved to: ' + RKNN_MODEL)
    rknn.release()

if __name__ == '__main__':
    main()
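The two conversion scripts above differ only in the ONNX file name, input node name, and fixed input shape. A minimal sketch of a shared helper both could call; the helper name and signature are illustrative and not part of the upload, and it uses only the rknn.api calls already shown in the scripts.

# Hypothetical helper consolidating the two conversion scripts (illustrative only).
from rknn.api import RKNN

def convert_onnx_to_rknn(onnx_path, rknn_path, input_name, input_size):
    rknn = RKNN(verbose=False)
    if rknn.config(target_platform="rk3588", dynamic_input=None) != 0:
        raise RuntimeError("Config failed")
    if rknn.load_onnx(model=onnx_path, inputs=[input_name], input_size_list=[input_size]) != 0:
        raise RuntimeError("Load failed")
    if rknn.build(do_quantization=False) != 0:
        raise RuntimeError("Build failed")
    if rknn.export_rknn(rknn_path) != 0:
        raise RuntimeError("Export failed")
    rknn.release()

# Usage matching the two scripts:
# convert_onnx_to_rknn("fastvithd.onnx", "fastvithd.rknn", "pixel_values", [1, 3, 1024, 1024])
# convert_onnx_to_rknn("mm_projector.onnx", "mm_projector.rknn", "last_hidden_state", [1, 256, 3072])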
export_onnx.py
ADDED
@@ -0,0 +1,136 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import os
import json
import copy
import argparse

import torch

from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path


def export(args):
    # Load model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path,
                                                                           args.model_base,
                                                                           model_name,
                                                                           device="cpu")

    # Save extra metadata that is not saved during LLaVA training,
    # required by HF for auto-loading the model and for mlx-vlm preprocessing

    # Save image processing config
    setattr(image_processor, "processor_class", "LlavaProcessor")
    output_path = os.path.join(model_path, "preprocessor_config.json")
    image_processor.to_json_file(output_path)

    # Create processor config
    processor_config = dict()
    processor_config["image_token"] = "<image>"
    processor_config["num_additional_image_tokens"] = 0
    processor_config["processor_class"] = "LlavaProcessor"
    processor_config["patch_size"] = 64
    output_path = os.path.join(model_path, "processor_config.json")
    json.dump(processor_config, open(output_path, "w"), indent=2)

    # Modify tokenizer to include the <image> special token.
    tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
    tokenizer_config = json.load(open(tokenizer_config_path, 'r'))
    token_ids = list()
    image_token_is_present = False
    for k, v in tokenizer_config['added_tokens_decoder'].items():
        token_ids.append(int(k))
        if v["content"] == "<image>":
            image_token_is_present = True
            token_ids.pop()

    # Append only if the <image> token is not present
    if not image_token_is_present:
        tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}'] = copy.deepcopy(
            tokenizer_config['added_tokens_decoder'][f'{token_ids[0]}'])
        tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}']["content"] = "<image>"
        json.dump(tokenizer_config, open(tokenizer_config_path, 'w'), indent=2)

    # Modify config to contain the token id for <image>
    config_path = os.path.join(model_path, "config.json")
    model_config = json.load(open(config_path, 'r'))
    model_config["image_token_index"] = max(token_ids) + 1
    json.dump(model_config, open(config_path, 'w'), indent=2)

    # Export the vision encoder to ONNX
    image_res = image_processor.to_dict()['size']['shortest_edge']
    dummy_vision_input = torch.rand(1, 3, image_res, image_res).float()  # Dummy input tensor

    vision_model = model.get_vision_tower()
    # Ensure the model is on CPU, in float precision, and in evaluation mode for ONNX export
    vision_model = vision_model.cpu().float().eval()

    onnx_vision_model_path = os.path.join(model_path, "fastvithd.onnx")

    print(f"Exporting vision model to {onnx_vision_model_path}...")
    torch.onnx.export(
        vision_model,
        dummy_vision_input,  # Pass the dummy input tensor
        onnx_vision_model_path,
        input_names=['pixel_values'],        # name of the input node in the ONNX graph
        output_names=['last_hidden_state'],  # name of the output node in the ONNX graph
        # dynamic_axes={
        #     'pixel_values': {0: 'batch_size'},      # dim 0 of 'pixel_values' is a dynamic batch_size
        #     'last_hidden_state': {0: 'batch_size'}  # dim 0 of 'last_hidden_state' is a dynamic batch_size
        # },
        opset_version=17,         # ONNX opset version
        export_params=True,       # store the trained parameter weights inside the model file
        do_constant_folding=True  # apply constant-folding optimization
    )
    print(f"Vision model ONNX export complete: {onnx_vision_model_path}")

    # Generate dummy input for mm_projector by passing dummy_vision_input through vision_model.
    # This ensures the mm_projector receives input with the correct shape and characteristics.
    with torch.no_grad():
        dummy_mm_projector_input = vision_model(dummy_vision_input)

    # Ensure the input is on CPU and in float32 precision for the projector
    dummy_mm_projector_input = dummy_mm_projector_input.cpu().float()

    # Export the mm_projector to ONNX.
    # model.get_model() gives the underlying base model (e.g., LlavaLlamaModel)
    # which contains the mm_projector attribute.
    mm_projector = model.get_model().mm_projector
    mm_projector = mm_projector.cpu().float().eval()

    onnx_mm_projector_path = os.path.join(model_path, "mm_projector.onnx")

    print(f"Exporting mm_projector to {onnx_mm_projector_path}...")
    torch.onnx.export(
        mm_projector,
        dummy_mm_projector_input,
        onnx_mm_projector_path,
        input_names=['last_hidden_state'],
        output_names=['projected_image_features'],
        opset_version=17,
        export_params=True,
        do_constant_folding=True
    )
    print(f"mm_projector ONNX export complete: {onnx_mm_projector_path}")

    # Removed CoreML-specific code and intermediate .pt file handling.
    # No need for os.remove(pt_name) as pt_name is no longer created.


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="qwen_2")

    args = parser.parse_args()

    export(args)
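Before converting the exported graphs to RKNN, it can help to sanity-check them with onnxruntime on the host. A minimal sketch under the assumption that onnxruntime is installed and the ONNX files sit in the working directory; the expected (1, 256, 3072) feature shape is taken from the input_size_list used in convert_mm_projector.py, not measured here.

# Sketch: sanity-check the exported ONNX graphs with onnxruntime before RKNN conversion.
import numpy as np
import onnxruntime as ort

vision = ort.InferenceSession("fastvithd.onnx")
pixels = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
feats = vision.run(["last_hidden_state"], {"pixel_values": pixels})[0]
print(feats.shape)  # expected to match the projector input, e.g. (1, 256, 3072)

projector = ort.InferenceSession("mm_projector.onnx")
proj = projector.run(["projected_image_features"], {"last_hidden_state": feats})[0]
print(proj.shape)   # (1, num_image_tokens, llm_hidden_dim)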
fastvithd.rknn
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d65f07758bd5fd610c76d28aae1d07a87bc65de0687f4fee5ef5c5e0f61d52a
size 372732105

librkllmrt.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6a9c2de93cf94bb524eb071c27190ad4c83401e01b562534f265dff4cb40da2
size 6710712

mm_projector.rknn
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d3f6d9b06589c9aa7b70d28ef22703579d5d3c07c53e1b6ee72be85ac4ae7ee5
size 14272722

qwen_f16.rkllm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e13cc4849405b7338b5d21cef0f50a1776ea7c21594685e52665837cfec123c
size 3580141646
rkllm-convert.py
ADDED
@@ -0,0 +1,23 @@
from rkllm.api import RKLLM

modelpath = '.'
llm = RKLLM()

ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
if ret != 0:
    print('Load model failed!')
    exit(ret)

qparams = None
ret = llm.build(do_quantization=False, optimization_level=1, quantized_dtype='w8a8_g128',
                quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams, dataset='calibration_dataset.json')

if ret != 0:
    print('Build model failed!')
    exit(ret)

# Export rkllm model
ret = llm.export_rkllm("./qwen_f16.rkllm")
if ret != 0:
    print('Export model failed!')
    exit(ret)
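The build above keeps FP16 weights (do_quantization=False), which is why the exported file is named qwen_f16.rkllm; the quantization-related arguments are passed but unused. A hedged variant of the same call with quantization switched on is sketched below, reusing the llm object and the parameter values already present in the script; whether w8a8 quantization preserves output quality for this model is untested here.

# Hedged variant: same RKLLM build call with quantization enabled (untested).
ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8_g128',
                quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3,
                extra_qparams=None, dataset='calibration_dataset.json')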
rkllm_binding.py
ADDED
@@ -0,0 +1,510 @@
import ctypes
import enum
import os

# Define constants from the header
CPU0 = (1 << 0)  # 0x01
CPU1 = (1 << 1)  # 0x02
CPU2 = (1 << 2)  # 0x04
CPU3 = (1 << 3)  # 0x08
CPU4 = (1 << 4)  # 0x10
CPU5 = (1 << 5)  # 0x20
CPU6 = (1 << 6)  # 0x40
CPU7 = (1 << 7)  # 0x80

# --- Enums ---
class LLMCallState(enum.IntEnum):
    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3

class RKLLMInputType(enum.IntEnum):
    RKLLM_INPUT_PROMPT = 0
    RKLLM_INPUT_TOKEN = 1
    RKLLM_INPUT_EMBED = 2
    RKLLM_INPUT_MULTIMODAL = 3

class RKLLMInferMode(enum.IntEnum):
    RKLLM_INFER_GENERATE = 0
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
    RKLLM_INFER_GET_LOGITS = 2

# --- Structures ---
class RKLLMExtendParam(ctypes.Structure):
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("reserved", ctypes.c_uint8 * 106)
    ]

class RKLLMParam(ctypes.Structure):
    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        ("top_k", ctypes.c_int32),
        ("n_keep", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),  # Note: This was missing in the provided text but is in typical LLM params
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),  # This seems like it should be more structured for actual image data
        ("extend_param", RKLLMExtendParam)
    ]

class RKLLMLoraAdapter(ctypes.Structure):
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]

class RKLLMEmbedInput(ctypes.Structure):
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMTokenInput(ctypes.Structure):
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMMultiModelInput(ctypes.Structure):
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t)
    ]

class _RKLLMInputUnion(ctypes.Union):
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]

class RKLLMInput(ctypes.Structure):
    _fields_ = [
        ("input_type", ctypes.c_int),  # Enum will be passed as int, changed RKLLMInputType to ctypes.c_int
        ("_union_data", _RKLLMInputUnion)
    ]
    # Properties to make accessing union members easier
    @property
    def prompt_input(self):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            return self._union_data.prompt_input
        raise AttributeError("Not a prompt input")
    @prompt_input.setter
    def prompt_input(self, value):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            self._union_data.prompt_input = value
        else:
            raise AttributeError("Not a prompt input")

    # Similar properties can be added for embed_input, token_input, multimodal_input

class RKLLMLoraParam(ctypes.Structure):  # For inference
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]

class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),  # bool-like
        ("prompt_cache_path", ctypes.c_char_p)
    ]

class RKLLMInferParam(ctypes.Structure):
    _fields_ = [
        ("mode", ctypes.c_int),  # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int)  # bool-like
    ]

class RKLLMResultLastHiddenLayer(ctypes.Structure):
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResultLogits(ctypes.Structure):
    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResult(ctypes.Structure):
    _fields_ = [
        ("text", ctypes.c_char_p),
        ("token_id", ctypes.c_int32),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
        ("logits", RKLLMResultLogits)
    ]

# --- Typedefs ---
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
LLMResultCallback = ctypes.CFUNCTYPE(
    None,  # return type: void
    ctypes.POINTER(RKLLMResult),
    ctypes.c_void_p,  # userdata
    ctypes.c_int  # enum, will be passed as int. Changed LLMCallState to ctypes.c_int
)


class RKLLMRuntime:
    def __init__(self, library_path="./librkllmrt.so"):
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
        self._setup_functions()
        self.llm_handle = LLMHandle()
        self._c_callback = None  # To keep the callback object alive

    def _setup_functions(self):
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        # Assuming async also takes userdata for the callback context
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle);
        self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p
        ]

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.
        :param param: RKLLMParam structure.
        :param callback_func: A Python function that matches the signature:
                              def my_callback(result_ptr, userdata_ptr, state_enum):
                                  result = result_ptr.contents  # RKLLMResult
                                  # Process result
                                  # userdata can be retrieved if passed during run, or ignored
                                  # state = LLMCallState(state_enum)
        :return: 0 for success, non-zero for failure.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        # Keep a reference to the ctypes callback object to prevent it from being garbage collected
        self._c_callback = LLMResultCallback(callback_func)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a Lora adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources."""
        if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
            ret = self.lib.rkllm_destroy(self.llm_handle)
            self.llm_handle = LLMHandle()  # Reset handle
            if ret != 0:
                # Don't raise here as it might be called in __del__
                print(f"Warning: rkllm_destroy failed with error code {ret}")
            return ret
        return 0  # Already destroyed or not initialized

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously."""
        # userdata can be a ctypes.py_object if you want to pass Python objects,
        # then cast to c_void_p. Or simply None.
        c_userdata = ctypes.cast(ctypes.py_object(userdata), ctypes.c_void_p) if userdata is not None else None
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously."""
        c_userdata = ctypes.cast(ctypes.py_object(userdata), ctypes.c_void_p) if userdata is not None else None
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The C API returns 0 if running, non-zero otherwise.
        # This is a bit counter-intuitive for a boolean "is_running".
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool) -> int:
        """Clears the key-value cache."""
        ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0))
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM."""
        c_system = system_prompt.encode('utf-8') if system_prompt else None
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else None
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else None

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        self.destroy()  # Ensure resources are freed if object is garbage collected

# --- Example Usage (Illustrative) ---
if __name__ == "__main__":
    # This is a placeholder for how you might use it.
    # You'll need a valid .rkllm model and librkllmrt.so in your path.

    # Global list to store results from callback for demonstration
    results_buffer = []

    def my_python_callback(result_ptr, userdata_ptr, state_enum):
        """
        Callback function to be called by the C library.
        """
        global results_buffer
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text:  # Check if the char_p is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')

        print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
        results_buffer.append(current_text)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            print("Inference finished.")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("Inference error.")

        # Example: Accessing logits if available (and if mode was set to get logits)
        # if result.logits.logits and result.logits.vocab_size > 0:
        #     print(f"  Logits (first 5 of vocab_size {result.logits.vocab_size}):")
        #     for i in range(min(5, result.logits.vocab_size)):
        #         print(f"    {result.logits.logits[i]:.4f}", end=" ")
        #     print()


    # --- Attempt to use the wrapper ---
    try:
        print("Initializing RKLLMRuntime...")
        # Adjust library_path if librkllmrt.so is not in default search paths
        # e.g., library_path="./path/to/librkllmrt.so"
        rk_llm = RKLLMRuntime()

        print("Creating default parameters...")
        params = rk_llm.create_default_param()

        # --- Configure parameters ---
        # THIS IS CRITICAL: model_path must point to an actual .rkllm file.
        # For this example to run, you need a model file.
        # Let's assume a dummy path for now, this will fail at init if not valid.
        model_file = "dummy_model.rkllm"
        if not os.path.exists(model_file):
            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
            # Create a dummy file for the example to proceed further, though init will still fail
            # with a real library unless it's a valid model.
            with open(model_file, "w") as f:
                f.write("dummy content")

        params.model_path = model_file.encode('utf-8')
        params.max_context_len = 512
        params.max_new_tokens = 128
        params.top_k = 1  # Greedy
        params.temperature = 0.7
        params.repeat_penalty = 1.1
        # ... set other params as needed

        print(f"Initializing LLM with model: {params.model_path.decode()}...")
        # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
        try:
            rk_llm.init(params, my_python_callback)
            print("LLM Initialized.")
        except RuntimeError as e:
            print(f"Error during LLM initialization: {e}")
            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
            exit()


        # --- Prepare input ---
        print("Preparing input...")
        rk_input = RKLLMInput()
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

        prompt_text = "Translate the following English text to French: 'Hello, world!'"
        c_prompt = prompt_text.encode('utf-8')
        rk_input._union_data.prompt_input = c_prompt  # Accessing union member directly

        # --- Prepare inference parameters ---
        print("Preparing inference parameters...")
        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        infer_params.keep_history = 1  # True
        # infer_params.lora_params = None  # or set up RKLLMLoraParam if using LoRA
        # infer_params.prompt_cache_params = None  # or set up RKLLMPromptCacheParam

        # --- Run inference ---
        print(f"Running inference with prompt: '{prompt_text}'")
        results_buffer.clear()
        try:
            rk_llm.run(rk_input, infer_params)  # Userdata is None by default
            print("\n--- Full Response ---")
            print("".join(results_buffer))
            print("---------------------\n")
        except RuntimeError as e:
            print(f"Error during LLM run: {e}")


        # --- Example: Set chat template (if model supports it) ---
        # print("Setting chat template...")
        # try:
        #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
        #     print("Chat template set.")
        # except RuntimeError as e:
        #     print(f"Error setting chat template: {e}")

        # --- Example: Clear KV Cache ---
        # print("Clearing KV cache (keeping system prompt if any)...")
        # try:
        #     rk_llm.clear_kv_cache(keep_system_prompt=True)
        #     print("KV cache cleared.")
        # except RuntimeError as e:
        #     print(f"Error clearing KV cache: {e}")

    except OSError as e:
        print(f"OSError: {e}. Could not load the RKLLM library.")
        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
            print("Destroying LLM instance...")
            rk_llm.destroy()
            print("LLM instance destroyed.")
        # Guard with locals() so an early failure (e.g., library load) does not raise NameError here
        if 'model_file' in locals() and os.path.exists(model_file) and model_file == "dummy_model.rkllm":
            os.remove(model_file)  # Clean up dummy file

    print("Example finished.")
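The binding notes that accessor properties similar to prompt_input could be added for the remaining union members. A minimal sketch for the multimodal member, following the same pattern; the subclass name is illustrative and not part of the uploaded file.

# Sketch: accessor property for the multimodal member of the input union,
# mirroring the prompt_input property defined on RKLLMInput above.
from rkllm_binding import RKLLMInput, RKLLMInputType

class RKLLMInputWithAccessors(RKLLMInput):
    @property
    def multimodal_input(self):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            return self._union_data.multimodal_input
        raise AttributeError("Not a multimodal input")

    @multimodal_input.setter
    def multimodal_input(self, value):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL:
            self._union_data.multimodal_input = value
        else:
            raise AttributeError("Not a multimodal input")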
run_rknn.py
ADDED
@@ -0,0 +1,274 @@
import faulthandler
faulthandler.enable()
import os
import time
import numpy as np
from rkllm_binding import *
import ztu_somemodelruntime_rknnlite2 as ort
import signal
import cv2
import ctypes

# --- Configuration ---
# These paths should point to the directory containing all model files
# or be absolute paths.
MODEL_DIR = "."  # Assuming models are in the current directory or provide a specific path
LLM_MODEL_NAME = "qwen_f16.rkllm"
VISION_ENCODER_ONNX_NAME = "fastvithd.onnx"
MM_PROJECTOR_ONNX_NAME = "mm_projector.onnx"
PREPROCESSOR_CONFIG_NAME = "preprocessor_config.json"  # Generated by export_onnx.py

LLM_MODEL_PATH = os.path.join(MODEL_DIR, LLM_MODEL_NAME)
VISION_ENCODER_PATH = os.path.join(MODEL_DIR, VISION_ENCODER_ONNX_NAME)
MM_PROJECTOR_PATH = os.path.join(MODEL_DIR, MM_PROJECTOR_ONNX_NAME)
PREPROCESSOR_CONFIG_PATH = os.path.join(MODEL_DIR, PREPROCESSOR_CONFIG_NAME)

IMAGE_PATH = "test.jpg"  # Replace with your test image
# user_prompt = "Describe this image in detail."
user_prompt = "仔细描述一下这张图片。"

# Global RKLLMRuntime instance
rk_runtime = None

# Exit on Ctrl-C
def signal_handler(signal, frame):
    print("Ctrl-C pressed, exiting...")
    global rk_runtime
    if rk_runtime:
        try:
            print("Attempting to abort RKLLM task...")
            rk_runtime.abort()
            print("RKLLM task aborted.")
        except RuntimeError as e:
            print(f"Note: RKLLM abort failed or task was not running: {e}")
        except Exception as e:
            print(f"Unexpected error during RKLLM abort in signal handler: {e}")

        try:
            print("Attempting to destroy RKLLM instance...")
            rk_runtime.destroy()
            print("RKLLM instance destroyed via signal handler.")
        except RuntimeError as e:
            print(f"Error during RKLLM destroy in signal handler: {e}")
        except Exception as e:  # Catch any other unexpected errors
            print(f"Unexpected error during RKLLM destroy in signal handler: {e}")
    exit(0)

signal.signal(signal.SIGINT, signal_handler)

# Set RKLLM log level if desired
os.environ["RKLLM_LOG_LEVEL"] = "1"

inference_count = 0
inference_start_time = 0
first_token_received = False

def result_callback(result_ptr, userdata, state_enum):
    global inference_start_time, inference_count, first_token_received
    state = LLMCallState(state_enum)  # Convert int to enum
    if result_ptr is None:
        return
    result = result_ptr.contents  # Dereference the pointer

    if state == LLMCallState.RKLLM_RUN_NORMAL:
        if not first_token_received:
            first_token_time = time.time()
            print(f"\nTime to first token: {first_token_time - inference_start_time:.2f} seconds")
            first_token_received = True

        current_text = ""
        if result.text:  # Check if char_p is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')
        print(current_text, end="", flush=True)
        inference_count += 1
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n\n(finished)")
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nError occurred during LLM call")
    # Add other states if needed, e.g., RKLLM_RUN_WAITING

def load_and_preprocess_image(image_path, config_path):
    img_size = 1024
    image_mean = [0.0, 0.0, 0.0]
    image_std = [1.0, 1.0, 1.0]

    print(f"Target image size from config: {img_size}x{img_size}")
    print(f"Using image_mean: {image_mean}, image_std: {image_std}")

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Compute the scaling factor, preserving the aspect ratio
    h, w = img.shape[:2]
    scale = min(img_size / w, img_size / h)
    new_w, new_h = int(w * scale), int(h * scale)

    # Resize while keeping the aspect ratio
    img_resized = cv2.resize(img, (new_w, new_h))

    # Create a black canvas of the target size
    img_padded = np.zeros((img_size, img_size, 3), dtype=np.uint8)

    # Place the resized image at the center
    y_offset = (img_size - new_h) // 2
    x_offset = (img_size - new_w) // 2
    img_padded[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = img_resized

    img_rgb = cv2.cvtColor(img_padded, cv2.COLOR_BGR2RGB)
    img_fp32 = img_rgb.astype(np.float32)

    # Normalize
    img_normalized = (img_fp32 / 255.0 - image_mean) / image_std

    # Transpose to NCHW format
    img_nchw = img_normalized.transpose(2, 0, 1)  # HWC to CHW
    img_batch = img_nchw[np.newaxis, :, :, :]  # Add batch dimension -> NCHW

    return img_batch.astype(np.float32), img_size

def main():
    global rk_runtime, inference_start_time, inference_count, first_token_received, user_prompt

    # --- 1. Initialize ONNX Runtime for Vision Models ---
    print("Loading ONNX vision encoder model...")
    vision_session = ort.InferenceSession(VISION_ENCODER_PATH)
    vision_input_name = vision_session.get_inputs()[0].name
    vision_output_name = vision_session.get_outputs()[0].name
    print(f"ONNX vision encoder loaded. Input: '{vision_input_name}', Output: '{vision_output_name}'")

    print("Loading ONNX mm_projector model...")
    mm_projector_session = ort.InferenceSession(MM_PROJECTOR_PATH)
    mm_projector_input_name = mm_projector_session.get_inputs()[0].name
    mm_projector_output_name = mm_projector_session.get_outputs()[0].name
    print(f"ONNX mm_projector loaded. Input: '{mm_projector_input_name}', Output: '{mm_projector_output_name}'")

    # --- 2. Initialize RKLLM ---
    print("Initializing RKLLM...")
    rk_runtime = RKLLMRuntime()

    param = rk_runtime.create_default_param()
    param.model_path = LLM_MODEL_PATH.encode('utf-8')
    param.img_start = "<image>".encode('utf-8')
    param.img_end = "".encode('utf-8')
    param.img_content = "<unk>".encode('utf-8')

    extend_param = RKLLMExtendParam()
    extend_param.base_domain_id = 1
    extend_param.embed_flash = 1
    extend_param.enabled_cpus_num = 8
    extend_param.enabled_cpus_mask = 0xffffffff
    param.extend_param = extend_param

    model_size_llm = os.path.getsize(LLM_MODEL_PATH)
    print(f"Start loading language model (size: {model_size_llm / 1024 / 1024:.2f} MB)")
    start_time_llm_load = time.time()

    try:
        rk_runtime.init(param, result_callback)
    except RuntimeError as e:
        print(f"RKLLM init failed: {e}")
        if rk_runtime:
            try:
                rk_runtime.destroy()
            except Exception as e_destroy:
                print(f"Error destroying RKLLM after init failure: {e_destroy}")
        return

    end_time_llm_load = time.time()
    print(f"Language model loaded in {end_time_llm_load - start_time_llm_load:.2f} seconds")

    # --- 3. Load and Preprocess Image ---
    print(f"Loading and preprocessing image: {IMAGE_PATH}")
    preprocessed_image, original_img_dim = load_and_preprocess_image(IMAGE_PATH, PREPROCESSOR_CONFIG_PATH)
    print(f"Input image shape for ONNX vision model: {preprocessed_image.shape}")

    # --- 4. Vision Encoder Inference (ONNX) ---
    start_time_vision = time.time()
    vision_outputs = vision_session.run([vision_output_name], {vision_input_name: preprocessed_image})
    image_features_from_vision = vision_outputs[0]
    end_time_vision = time.time()
    print(f"ONNX Vision encoder inference time: {end_time_vision - start_time_vision:.2f} seconds")
    print(f"Vision encoder output shape: {image_features_from_vision.shape}")

    # --- 5. MM Projector Inference (ONNX) ---
    start_time_projector = time.time()
    projector_outputs = mm_projector_session.run([mm_projector_output_name], {mm_projector_input_name: image_features_from_vision})
    projected_image_embeddings_np = projector_outputs[0]
    end_time_projector = time.time()
    print(f"ONNX MM projector inference time: {end_time_projector - start_time_projector:.2f} seconds")
    print(f"Projected image embeddings shape: {projected_image_embeddings_np.shape}")

    # Ensure C-contiguous and float32 for ctypes
    projected_image_embeddings_np = np.ascontiguousarray(projected_image_embeddings_np, dtype=np.float32)

    # --- 6. Prepare Prompt and RKLLMInput ---
    # The prompt should contain the <image> placeholder where the image features will be inserted.
    # prompt = f"""<|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # {param.img_start.decode()}
    # {user_prompt}<|im_end|>
    # <|im_start|>assistant
    # """

    # RKLLM now loads its own chat template, so we don't need to include that.
    prompt = f"""{param.img_start.decode()}
{user_prompt}"""

    print(f"\nUsing prompt:\n{prompt}")

    rkllm_input = RKLLMInput()
    rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL

    multimodal_payload = RKLLMMultiModelInput()
    multimodal_payload.prompt = prompt.encode('utf-8')

    # projected_image_embeddings_np has shape (1, num_tokens, hidden_dim)
    num_image_tokens = projected_image_embeddings_np.shape[1]
    # The C API expects a flat pointer to the embedding data.
    embedding_data_flat = projected_image_embeddings_np.flatten()

    multimodal_payload.image_embed = embedding_data_flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    multimodal_payload.n_image_tokens = num_image_tokens
    multimodal_payload.n_image = 1  # Number of images processed
    multimodal_payload.image_width = original_img_dim   # Width of the (resized before processing) image
    multimodal_payload.image_height = original_img_dim  # Height of the (resized before processing) image

    rkllm_input._union_data.multimodal_input = multimodal_payload

    # --- 7. Create Inference Parameters ---
    infer_param = RKLLMInferParam()
    infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE.value  # Ensure this is an int for the C API
    # infer_param.keep_history = 1  # Or 0; the default is usually 0 (false) in create_default_param or the C struct.
    #                               # Check rkllm.h or the binding for the default if not setting explicitly.
    #                               # RKLLMInferParam from the binding has keep_history as c_int.

    # --- 8. Run RKLLM Inference ---
    print("Starting RKLLM inference...")
    inference_start_time = time.time()
    inference_count = 0
    first_token_received = False

    try:
        # The RKLLMRuntime.run method takes input and infer_param objects directly.
        rk_runtime.run(rkllm_input, infer_param, None)  # Userdata is None
    except RuntimeError as e:
        print(f"RKLLM run failed: {e}")

    # --- 9. Clean up ---
    # Normal cleanup if not interrupted by Ctrl-C.
    # The signal handler also attempts to destroy the instance.
    if rk_runtime and rk_runtime.llm_handle and rk_runtime.llm_handle.value:
        try:
            rk_runtime.destroy()
            print("RKLLM instance destroyed at script end.")
        except RuntimeError as e:
            print(f"Error during RKLLM destroy at script end: {e}")
        except Exception as e:
            print(f"Unexpected error during RKLLM destroy at script end: {e}")

    print("Script finished.")

if __name__ == "__main__":
    # rk_runtime (global) will be initialized inside main()
    main()
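load_and_preprocess_image receives config_path but currently hardcodes the 1024 resolution and identity mean/std. A hedged sketch of reading those values from the preprocessor_config.json written by export_onnx.py; the key names ('size'/'shortest_edge', 'image_mean', 'image_std') follow the usual Hugging Face image-processor layout and should be verified against the generated file.

# Sketch: derive preprocessing parameters from preprocessor_config.json instead of hardcoding them.
import json

def load_preprocess_params(config_path, default_size=1024):
    with open(config_path, "r") as f:
        cfg = json.load(f)
    size = cfg.get("size", {})
    img_size = size.get("shortest_edge", default_size) if isinstance(size, dict) else default_size
    image_mean = cfg.get("image_mean", [0.0, 0.0, 0.0])
    image_std = cfg.get("image_std", [1.0, 1.0, 1.0])
    return img_size, image_mean, image_std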
test.jpg
ADDED
(image stored via Git LFS)
ztu_somemodelruntime_rknnlite2.py
ADDED
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 模块级常量和函数
|
2 |
+
from rknnlite.api import RKNNLite
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import warnings
|
6 |
+
import logging
|
7 |
+
from typing import List, Dict, Union, Optional
|
8 |
+
|
9 |
+
try:
|
10 |
+
import onnxruntime as ort
|
11 |
+
HAS_ORT = True
|
12 |
+
except ImportError:
|
13 |
+
HAS_ORT = False
|
14 |
+
warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning)
|
15 |
+
|
16 |
+
# 配置日志
|
17 |
+
logger = logging.getLogger("somemodelruntime_rknnlite2")
|
18 |
+
logger.setLevel(logging.ERROR) # 默认只输出错误信息
|
19 |
+
if not logger.handlers:
|
20 |
+
handler = logging.StreamHandler()
|
21 |
+
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
22 |
+
logger.addHandler(handler)
|
23 |
+
|
24 |
+
# ONNX Runtime日志级别到Python logging级别的映射
|
25 |
+
_LOGGING_LEVEL_MAP = {
|
26 |
+
0: logging.DEBUG, # Verbose
|
27 |
+
1: logging.INFO, # Info
|
28 |
+
2: logging.WARNING, # Warning
|
29 |
+
3: logging.ERROR, # Error
|
30 |
+
4: logging.CRITICAL # Fatal
|
31 |
+
}
|
32 |
+
|
33 |
+
# 检查环境变量中的日志级别设置
|
34 |
+
try:
|
35 |
+
env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
|
36 |
+
if env_log_level is not None:
|
37 |
+
log_level = int(env_log_level)
|
38 |
+
if log_level in _LOGGING_LEVEL_MAP:
|
39 |
+
logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
|
40 |
+
logger.info(f"从环境变量设置日志级别: {log_level}")
|
41 |
+
else:
|
42 |
+
logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {log_level}, 应该是0-4之间的整数")
|
43 |
+
except ValueError:
|
44 |
+
logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {env_log_level}, 应该是0-4之间的整数")
|
45 |
+
|
46 |
+
|
47 |
+
def set_default_logger_severity(level: int) -> None:
|
48 |
+
"""
|
49 |
+
Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
|
50 |
+
|
51 |
+
Args:
|
52 |
+
level: 日志级别(0-4)
|
53 |
+
"""
|
54 |
+
if level not in _LOGGING_LEVEL_MAP:
|
55 |
+
raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
|
56 |
+
logger.setLevel(_LOGGING_LEVEL_MAP[level])
|
57 |
+
|
58 |
+
def set_default_logger_verbosity(level: int) -> None:
|
59 |
+
"""
|
60 |
+
Sets the default logging verbosity level. To activate the verbose log,
|
61 |
+
you need to set the default logging severity to 0:Verbose level.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
level: 日志级别(0-4)
|
65 |
+
"""
|
66 |
+
set_default_logger_severity(level)
|
67 |
+
|
68 |
+
# RKNN tensor type到numpy dtype的映射
|
69 |
+
RKNN_DTYPE_MAP = {
|
70 |
+
0: np.float32, # RKNN_TENSOR_FLOAT32
|
71 |
+
1: np.float16, # RKNN_TENSOR_FLOAT16
|
72 |
+
2: np.int8, # RKNN_TENSOR_INT8
|
73 |
+
3: np.uint8, # RKNN_TENSOR_UINT8
|
74 |
+
4: np.int16, # RKNN_TENSOR_INT16
|
75 |
+
5: np.uint16, # RKNN_TENSOR_UINT16
|
76 |
+
6: np.int32, # RKNN_TENSOR_INT32
|
77 |
+
7: np.uint32, # RKNN_TENSOR_UINT32
|
78 |
+
8: np.int64, # RKNN_TENSOR_INT64
|
79 |
+
9: bool, # RKNN_TENSOR_BOOL
|
80 |
+
10: np.int8, # RKNN_TENSOR_INT4 (用int8表示)
|
81 |
+
}
|
82 |
+
|
def get_available_providers() -> List[str]:
    """
    Get the list of available execution providers (placeholder kept for interface compatibility)

    Returns:
        list: always ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
    """
    return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]


def get_device() -> str:
    """
    Get the current device

    Returns:
        str: current device
    """
    return "RKNN2"

def get_version_info() -> Dict[str, str]:
    """
    Get version information

    Returns:
        dict: dictionary with the API and driver versions
    """
    runtime = RKNNLite()
    version = runtime.get_sdk_version()
    return {
        "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
        "driver_version": version.split('\n')[3].split(': ')[1]
    }

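# Sketch of the module-level helpers (illustrative only): they mirror the
# onnxruntime functions of the same name, so code that probes providers or
# versions keeps working when this module is swapped in.
#
#   import ztu_somemodelruntime_rknnlite2 as smr   # "smr" alias is hypothetical
#   print(smr.get_available_providers())
#   print(smr.get_device())        # "RKNN2"
#   print(smr.get_version_info())  # parsed from RKNNLite.get_sdk_version()
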
class IOTensor:
    """Wrapper describing an input/output tensor"""
    def __init__(self, name, shape, type=None):
        self.name = name.decode() if isinstance(name, bytes) else name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"

class SessionOptions:
    """Session options"""
    def __init__(self):
        self.enable_profiling = False     # whether to enable profiling
        self.intra_op_num_threads = 1     # number of RKNN NPU cores to use, mapped to the rknn core_mask
        self.log_severity_level = -1      # alternative way to set the log level
        self.log_verbosity_level = -1     # alternative way to set the log level


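# Minimal configuration sketch (assumes an RK3588-style NPU with three cores):
# intra_op_num_threads is repurposed to pick the RKNN core mask, so 1 means
# NPU_CORE_AUTO, 2 means cores 0-1 and 3 means cores 0-1-2.
#
#   opts = SessionOptions()
#   opts.intra_op_num_threads = 3   # schedule on all three NPU cores
#   opts.log_severity_level = 2     # warnings and above
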
class InferenceSession:
    """
    Wrapper around the RKNNLite runtime with an ONNX Runtime style API
    """

    def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        processed_path = InferenceSession._process_model_path(model_path, sess_options)
        if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
            logger.info("Loading the model with ONNX Runtime")
            if not HAS_ORT:
                raise RuntimeError("onnxruntime is not installed, cannot load ONNX models")
            return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
        else:
            # Not an ONNX model: create an InferenceSession instance via the parent __new__
            instance = super().__new__(cls)
            # Remember the resolved path
            instance._processed_path = processed_path
            return instance

    def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        """
        Initialize the runtime and load the model

        Args:
            model_path: path to the model file (.rknn or .onnx)
            sess_options: session options
            **kwargs: additional initialization arguments
        """
        options = sess_options or SessionOptions()

        # Only honour the log level from SessionOptions when the environment variable is not set
        if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
            if options.log_severity_level != -1:
                set_default_logger_severity(options.log_severity_level)
            if options.log_verbosity_level != -1:
                set_default_logger_verbosity(options.log_verbosity_level)

        # Use the path resolved in __new__
        model_path = getattr(self, '_processed_path', model_path)
        if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
            # Avoid loading the ONNX model twice
            return

        # RKNN model loading and initialization
        self.model_path = model_path
        if not os.path.exists(self.model_path):
            logger.error(f"Model file not found: {self.model_path}")
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        self.runtime = RKNNLite(verbose=options.enable_profiling)

        logger.debug(f"Loading model: {self.model_path}")
        ret = self.runtime.load_rknn(self.model_path)
        if ret != 0:
            logger.error(f"Failed to load RKNN model: {self.model_path}")
            raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
        logger.debug("Model loaded successfully")

        if options.intra_op_num_threads == 1:
            core_mask = RKNNLite.NPU_CORE_AUTO
        elif options.intra_op_num_threads == 2:
            core_mask = RKNNLite.NPU_CORE_0_1
        elif options.intra_op_num_threads == 3:
            core_mask = RKNNLite.NPU_CORE_0_1_2
        else:
            raise ValueError(f"Invalid intra_op_num_threads value: {options.intra_op_num_threads}, must be 1, 2 or 3")

        logger.debug("Initializing the runtime environment")
        ret = self.runtime.init_runtime(core_mask=core_mask)
        if ret != 0:
            logger.error("Failed to initialize the runtime environment")
            raise RuntimeError('Failed to initialize the runtime environment')
        logger.debug("Runtime environment initialized successfully")

        self._init_io_info()
        self.options = options

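    # Construction sketch (illustrative; fastvithd.rknn is one of the converted
    # models shipped alongside this module): passing an .onnx path instead
    # falls back to onnxruntime whenever no sibling .rknn file is found.
    #
    #   sess = InferenceSession("fastvithd.rknn")
    #   print([str(t) for t in sess.get_inputs()])
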
    def get_performance_info(self) -> Dict[str, float]:
        """
        Get performance information

        Returns:
            dict: dictionary with performance figures
        """
        if not self.options.enable_profiling:
            raise RuntimeError("Profiling is not enabled; set enable_profiling=True in SessionOptions")

        perf = self.runtime.rknn_runtime.get_run_perf()
        return {
            "run_duration": perf.run_duration / 1000.0  # converted to milliseconds
        }

    def set_core_mask(self, core_mask: int) -> None:
        """
        Set the NPU core usage mode

        Args:
            core_mask: NPU core mask, one of the NPU_CORE_* constants
        """
        ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
        if ret != 0:
            raise RuntimeError("Failed to set the NPU core mode")

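    # Usage sketch: the mask values come straight from RKNNLite, e.g. the
    # constants already used in __init__ (NPU_CORE_AUTO, NPU_CORE_0_1,
    # NPU_CORE_0_1_2).
    #
    #   from rknnlite.api import RKNNLite
    #   sess.set_core_mask(RKNNLite.NPU_CORE_0_1_2)   # pin to all three cores
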
    @staticmethod
    def _process_model_path(model_path, sess_options):
        """
        Resolve the model path; both .onnx and .rknn files are supported

        Args:
            model_path: path to the model file
        """
        # For ONNX files, check whether a converted RKNN model should be loaded instead
        if model_path.lower().endswith('.onnx'):
            logger.info("ONNX model file detected")

            # Models listed here skip the automatic RKNN substitution
            skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
            if skip_models:
                skip_list = [m.strip() for m in skip_models.split(',')]
                # Match on the file name only (without the directory part)
                model_name = os.path.basename(model_path)
                if model_name.lower() in [m.lower() for m in skip_list]:
                    logger.info(f"Model {model_name} is in the skip list, using ONNX Runtime")
                    return model_path

            # Build the path of the sibling RKNN file
            rknn_path = os.path.splitext(model_path)[0] + '.rknn'
            if os.path.exists(rknn_path):
                logger.info(f"Found a matching RKNN model, using RKNN: {rknn_path}")
                return rknn_path
            else:
                logger.info("No matching RKNN model found, using ONNX Runtime")
                return model_path

        return model_path

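    # Resolution sketch for a hypothetical "encoder.onnx": the skip list from
    # ZTU_MODELRT_RKNNL2_SKIP is honoured first, then a sibling "encoder.rknn"
    # is preferred if it exists, and plain onnxruntime is used otherwise.
    #
    #   # export ZTU_MODELRT_RKNNL2_SKIP=decoder.onnx,postprocess.onnx
    #   sess = InferenceSession("encoder.onnx")   # loads encoder.rknn when present
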
    def _convert_nhwc_to_nchw(self, shape):
        """Convert an NHWC shape to NCHW"""
        if len(shape) == 4:
            # NHWC -> NCHW
            n, h, w, c = shape
            return [n, c, h, w]
        return shape

    def _init_io_info(self):
        """Collect the model's input and output tensor information"""
        runtime = self.runtime.rknn_runtime

        # Number of inputs and outputs
        n_input, n_output = runtime.get_in_out_num()

        # Input information
        self.input_tensors = []
        for i in range(n_input):
            attr = runtime.get_tensor_attr(i)
            shape = [attr.dims[j] for j in range(attr.n_dims)]
            # Convert 4-D inputs from NHWC to NCHW
            shape = self._convert_nhwc_to_nchw(shape)
            # Resolve the dtype
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.input_tensors.append(tensor)

        # Output information
        self.output_tensors = []
        for i in range(n_output):
            attr = runtime.get_tensor_attr(i, is_output=True)
            shape = runtime.get_output_shape(i)
            # Resolve the dtype
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.output_tensors.append(tensor)

    def get_inputs(self):
        """
        Get the model input information

        Returns:
            list: list of input descriptors
        """
        return self.input_tensors

    def get_outputs(self):
        """
        Get the model output information

        Returns:
            list: list of output descriptors
        """
        return self.output_tensors

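    # Feed-construction sketch in the usual onnxruntime style (illustrative):
    # names, shapes and dtypes all come from the IOTensor wrappers above.
    #
    #   feed = {}
    #   for t in sess.get_inputs():
    #       feed[t.name] = np.zeros(t.shape, dtype=t.type or np.float32)
    #   outputs = sess.run(None, feed)
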
    def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
        """
        Run inference

        Args:
            output_names: list of output node names to return
            input_feed: input data as a dict or a list
            data_format: input data layout, "nchw" or "nhwc"
            **kwargs: additional runtime arguments

        Returns:
            list: model outputs; only the requested ones when output_names is given
        """
        if input_feed is None:
            logger.error("input_feed must not be None")
            raise ValueError("input_feed must not be None")

        # Prepare the input data
        if isinstance(input_feed, dict):
            # Dict inputs are reordered to match the model's input order
            inputs = []
            for tensor in self.input_tensors:
                if tensor.name not in input_feed:
                    raise ValueError(f"Missing input: {tensor.name}")
                inputs.append(input_feed[tensor.name])
        elif isinstance(input_feed, (list, tuple)):
            # List inputs must match the expected count
            if len(input_feed) != len(self.input_tensors):
                raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
            inputs = list(input_feed)
        else:
            logger.error("input_feed must be a dict or a list")
            raise ValueError("input_feed must be a dict or a list")

        # Run inference
        try:
            logger.debug("Starting inference")
            all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)

            # Return all outputs when output_names is not given
            if output_names is None:
                return all_outputs

            # Pick the requested outputs
            output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
            selected_outputs = []
            for name in output_names:
                if name not in output_map:
                    raise ValueError(f"Unknown output node: {name}")
                selected_outputs.append(all_outputs[output_map[name]])

            return selected_outputs

        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise RuntimeError(f"Inference failed: {str(e)}")

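    # Inference sketch (illustrative; the 1x3x1024x1024 shape is an assumption,
    # substitute the real shape reported by get_inputs()): random data stands
    # in for a preprocessed image.
    #
    #   img = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
    #   out = sess.run(None, {sess.get_inputs()[0].name: img})
    #   print(out[0].shape)
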
    def close(self):
        """
        Close the session and release resources
        """
        if self.runtime is not None:
            logger.info("Releasing runtime resources")
            self.runtime.release()
            self.runtime = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def end_profiling(self) -> Optional[str]:
        """
        Stub for ending profiling

        Returns:
            Optional[str]: None
        """
        warnings.warn("end_profiling() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return None

    def get_profiling_start_time_ns(self) -> int:
        """
        Stub for querying the profiling start time

        Returns:
            int: 0
        """
        warnings.warn("get_profiling_start_time_ns() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return 0

    def get_modelmeta(self) -> Dict[str, str]:
        """
        Stub for querying model metadata

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_modelmeta() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_options(self) -> SessionOptions:
        """
        Get the session options

        Returns:
            SessionOptions: the current session options
        """
        return self.options

    def get_providers(self) -> List[str]:
        """
        Stub for querying the providers in use

        Returns:
            List[str]: ["CPUExecutionProvider"]
        """
        warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
        return ["CPUExecutionProvider"]

    def get_provider_options(self) -> Dict[str, Dict[str, str]]:
        """
        Stub for querying provider options

        Returns:
            Dict[str, Dict[str, str]]: empty dict
        """
        warnings.warn("get_provider_options() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_config(self) -> Dict[str, str]:
        """
        Stub for querying the session configuration

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_session_config() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_state(self) -> Dict[str, str]:
        """
        Stub for querying the session state

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_session_state() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return {}

    def set_session_config(self, config: Dict[str, str]) -> None:
        """
        Stub for setting the session configuration

        Args:
            config: session configuration dict
        """
        warnings.warn("set_session_config() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)

    def get_memory_info(self) -> Dict[str, int]:
        """
        Stub for querying memory usage

        Returns:
            Dict[str, int]: empty dict
        """
        warnings.warn("get_memory_info() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return {}

    def set_memory_pattern(self, enable: bool) -> None:
        """
        Stub for enabling the memory pattern

        Args:
            enable: whether to enable the memory pattern
        """
        warnings.warn("set_memory_pattern() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)

    def disable_memory_pattern(self) -> None:
        """
        Stub for disabling the memory pattern
        """
        warnings.warn("disable_memory_pattern() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)

    def get_optimization_level(self) -> int:
        """
        Stub for querying the optimization level

        Returns:
            int: 0
        """
        warnings.warn("get_optimization_level() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return 0

    def set_optimization_level(self, level: int) -> None:
        """
        Stub for setting the optimization level

        Args:
            level: optimization level
        """
        warnings.warn("set_optimization_level() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)

    def get_model_metadata(self) -> Dict[str, str]:
        """
        Stub for querying model metadata (separate interface from get_modelmeta)

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_model_metadata() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return {}

    def get_model_path(self) -> str:
        """
        Get the model path

        Returns:
            str: path to the model file
        """
        return self.model_path

    def get_input_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for querying input type information

        Returns:
            List[Dict[str, str]]: empty list
        """
        warnings.warn("get_input_type_info() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return []

    def get_output_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for querying output type information

        Returns:
            List[Dict[str, str]]: empty list
        """
        warnings.warn("get_output_type_info() is a stub and provides no actual functionality", RuntimeWarning, stacklevel=2)
        return []
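
# Illustrative end-to-end sketch; the default model name below is an assumption
# (any converted .rknn file works) and this block only runs when the file is
# executed directly, never on import.
if __name__ == "__main__":
    import sys

    model_file = sys.argv[1] if len(sys.argv) > 1 else "fastvithd.rknn"
    with InferenceSession(model_file) as sess:
        for t in sess.get_inputs():
            print("input :", t)
        for t in sess.get_outputs():
            print("output:", t)

        # Random data stands in for real preprocessed inputs.
        feed = {}
        for t in sess.get_inputs():
            feed[t.name] = np.random.rand(*t.shape).astype(t.type or np.float32)
        outputs = sess.run(None, feed)
        print(f"got {len(outputs)} output tensor(s)")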