happyme531 committed
Commit 6128fc3 · verified · 1 Parent(s): 27253ba

Upload 12 files

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ fastvithd.rknn filter=lfs diff=lfs merge=lfs -text
+ librkllmrt.so filter=lfs diff=lfs merge=lfs -text
+ mm_projector.rknn filter=lfs diff=lfs merge=lfs -text
+ qwen_f16.rkllm filter=lfs diff=lfs merge=lfs -text
+ test.jpg filter=lfs diff=lfs merge=lfs -text
convert_fastvithd.py ADDED
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# ztu_somemodelruntime_rknn2: fastvithd

from rknn.api import RKNN
import os
import numpy as np

def main():
    # Create an RKNN instance
    rknn = RKNN(verbose=False)

    # Path of the source ONNX model
    ONNX_MODEL = "fastvithd.onnx"
    # Path of the exported RKNN model
    RKNN_MODEL = "fastvithd.rknn"

    # Configure conversion parameters
    print("--> Config model")
    ret = rknn.config(target_platform="rk3588",
                      dynamic_input=None)
    if ret != 0:
        print('Config model failed!')
        exit(ret)

    # Load the ONNX model
    print("--> Loading model")
    ret = rknn.load_onnx(model=ONNX_MODEL,
                         inputs=['pixel_values'],
                         input_size_list=[[1, 3, 1024, 1024]])
    if ret != 0:
        print('Load model failed!')
        exit(ret)

    # Build the model
    print("--> Building model")
    ret = rknn.build(do_quantization=False)
    if ret != 0:
        print('Build model failed!')
        exit(ret)

    # Export the RKNN model
    print("--> Export RKNN model")
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)

    print(f'Done! The converted RKNN model has been saved to: {RKNN_MODEL}')
    rknn.release()

if __name__ == '__main__':
    main()
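A quick on-device sanity check of the converted encoder is possible with rknn-toolkit-lite2; this is a sketch under the assumption that fastvithd.rknn sits in the working directory of an RK3588 board (it is not part of this commit):

import numpy as np
from rknnlite.api import RKNNLite

rt = RKNNLite()
assert rt.load_rknn("fastvithd.rknn") == 0
assert rt.init_runtime() == 0
dummy = np.zeros((1, 3, 1024, 1024), dtype=np.float32)   # matches input_size_list above
out = rt.inference(inputs=[dummy], data_format="nchw")
print(out[0].shape)   # expected to match the mm_projector input, i.e. (1, 256, 3072)
rt.release()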
convert_mm_projector.py ADDED
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# ztu_somemodelruntime_rknn2: mm_projector

from rknn.api import RKNN
import os
import numpy as np

def main():
    # Create an RKNN instance
    rknn = RKNN(verbose=False)

    # Path of the source ONNX model
    ONNX_MODEL = "mm_projector.onnx"
    # Path of the exported RKNN model
    RKNN_MODEL = "mm_projector.rknn"

    # Configure conversion parameters
    print("--> Config model")
    ret = rknn.config(target_platform="rk3588",
                      dynamic_input=None)
    if ret != 0:
        print('Config model failed!')
        exit(ret)

    # Load the ONNX model
    print("--> Loading model")
    ret = rknn.load_onnx(model=ONNX_MODEL,
                         inputs=['last_hidden_state'],
                         input_size_list=[[1, 256, 3072]])
    if ret != 0:
        print('Load model failed!')
        exit(ret)

    # Build the model
    print("--> Building model")
    ret = rknn.build(do_quantization=False)
    if ret != 0:
        print('Build model failed!')
        exit(ret)

    # Export the RKNN model
    print("--> Export RKNN model")
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)

    print(f'Done! The converted RKNN model has been saved to: {RKNN_MODEL}')
    rknn.release()

if __name__ == '__main__':
    main()
export_onnx.py ADDED
@@ -0,0 +1,136 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import os
import json
import copy
import argparse

import torch

from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path


def export(args):
    # Load model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path,
                                                                           args.model_base,
                                                                           model_name,
                                                                           device="cpu")

    # Save extra metadata that is not saved during LLaVA training
    # required by HF for auto-loading model and for mlx-vlm preprocessing

    # Save image processing config
    setattr(image_processor, "processor_class", "LlavaProcessor")
    output_path = os.path.join(model_path, "preprocessor_config.json")
    image_processor.to_json_file(output_path)

    # Create processor config
    processor_config = dict()
    processor_config["image_token"] = "<image>"
    processor_config["num_additional_image_tokens"] = 0
    processor_config["processor_class"] = "LlavaProcessor"
    processor_config["patch_size"] = 64
    output_path = os.path.join(model_path, "processor_config.json")
    json.dump(processor_config, open(output_path, "w"), indent=2)

    # Modify tokenizer to include <image> special token.
    tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
    tokenizer_config = json.load(open(tokenizer_config_path, 'r'))
    token_ids = list()
    image_token_is_present = False
    for k, v in tokenizer_config['added_tokens_decoder'].items():
        token_ids.append(int(k))
        if v["content"] == "<image>":
            image_token_is_present = True
            token_ids.pop()

    # Append only if <image> token is not present
    if not image_token_is_present:
        tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}'] = copy.deepcopy(
            tokenizer_config['added_tokens_decoder'][f'{token_ids[0]}'])
        tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}']["content"] = "<image>"
        json.dump(tokenizer_config, open(tokenizer_config_path, 'w'), indent=2)

    # Modify config to contain token id for <image>
    config_path = os.path.join(model_path, "config.json")
    model_config = json.load(open(config_path, 'r'))
    model_config["image_token_index"] = max(token_ids) + 1
    json.dump(model_config, open(config_path, 'w'), indent=2)

    # Export the vision encoder to ONNX
    image_res = image_processor.to_dict()['size']['shortest_edge']
    dummy_vision_input = torch.rand(1, 3, image_res, image_res).float()  # Dummy input tensor

    vision_model = model.get_vision_tower()
    # Ensure model is on CPU, in float precision, and in evaluation mode for ONNX export
    vision_model = vision_model.cpu().float().eval()

    onnx_vision_model_path = os.path.join(model_path, "fastvithd.onnx")

    print(f"Exporting vision model to {onnx_vision_model_path}...")
    torch.onnx.export(
        vision_model,
        dummy_vision_input,  # Pass the dummy input tensor
        onnx_vision_model_path,
        input_names=['pixel_values'],          # Name of the input node in the ONNX graph
        output_names=['last_hidden_state'],    # Name of the output node in the ONNX graph
        # dynamic_axes={
        #     'pixel_values': {0: 'batch_size'},        # dim 0 of 'pixel_values' is a dynamic batch size
        #     'last_hidden_state': {0: 'batch_size'}    # dim 0 of 'last_hidden_state' is a dynamic batch size
        # },
        opset_version=17,          # ONNX opset version
        export_params=True,        # Store the trained parameter weights inside the model file
        do_constant_folding=True   # Apply constant-folding optimization
    )
    print(f"Vision model ONNX export complete: {onnx_vision_model_path}")

    # Generate dummy input for mm_projector by passing dummy_vision_input through vision_model
    # This ensures the mm_projector receives input with the correct shape and characteristics
    with torch.no_grad():
        dummy_mm_projector_input = vision_model(dummy_vision_input)

    # Ensure the input is on CPU and in float32 precision for the projector
    dummy_mm_projector_input = dummy_mm_projector_input.cpu().float()

    # Export the mm_projector to ONNX
    # model.get_model() gives the underlying base model (e.g., LlavaLlamaModel)
    # which contains the mm_projector attribute.
    mm_projector = model.get_model().mm_projector
    mm_projector = mm_projector.cpu().float().eval()

    onnx_mm_projector_path = os.path.join(model_path, "mm_projector.onnx")

    print(f"Exporting mm_projector to {onnx_mm_projector_path}...")
    torch.onnx.export(
        mm_projector,
        dummy_mm_projector_input,
        onnx_mm_projector_path,
        input_names=['last_hidden_state'],
        output_names=['projected_image_features'],
        opset_version=17,
        export_params=True,
        do_constant_folding=True
    )
    print(f"mm_projector ONNX export complete: {onnx_mm_projector_path}")

    # Removed CoreML specific code and intermediate .pt file handling
    # No need for os.remove(pt_name) as pt_name is no longer created


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="qwen_2")

    args = parser.parse_args()

    export(args)
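Before converting to RKNN, the two exported graphs can be chained with onnxruntime to confirm that the vision encoder's output really matches the projector's input. A minimal sketch, assuming onnxruntime is installed and the script above has already written fastvithd.onnx and mm_projector.onnx:

import numpy as np
import onnxruntime

vision = onnxruntime.InferenceSession("fastvithd.onnx")
projector = onnxruntime.InferenceSession("mm_projector.onnx")

pixels = np.random.rand(1, 3, 1024, 1024).astype(np.float32)  # dummy pixel_values
(features,) = vision.run(['last_hidden_state'], {'pixel_values': pixels})
(projected,) = projector.run(['projected_image_features'], {'last_hidden_state': features})
print(features.shape, projected.shape)  # (1, 256, 3072) and (1, 256, <LLM hidden size>)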
fastvithd.rknn ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d65f07758bd5fd610c76d28aae1d07a87bc65de0687f4fee5ef5c5e0f61d52a
size 372732105
librkllmrt.so ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6a9c2de93cf94bb524eb071c27190ad4c83401e01b562534f265dff4cb40da2
size 6710712
mm_projector.rknn ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d3f6d9b06589c9aa7b70d28ef22703579d5d3c07c53e1b6ee72be85ac4ae7ee5
size 14272722
qwen_f16.rkllm ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e13cc4849405b7338b5d21cef0f50a1776ea7c21594685e52665837cfec123c
size 3580141646
rkllm-convert.py ADDED
@@ -0,0 +1,23 @@
from rkllm.api import RKLLM

modelpath = '.'
llm = RKLLM()

ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
if ret != 0:
    print('Load model failed!')
    exit(ret)

qparams = None
ret = llm.build(do_quantization=False, optimization_level=1, quantized_dtype='w8a8_g128',
                quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3,
                extra_qparams=qparams, dataset='calibration_dataset.json')

if ret != 0:
    print('Build model failed!')
    exit(ret)

# Export rkllm model
ret = llm.export_rkllm("./qwen_f16.rkllm")
if ret != 0:
    print('Export model failed!')
    exit(ret)
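Note that with do_quantization=False the build keeps float16 weights (hence the qwen_f16.rkllm name); the quantized_dtype, quantized_algorithm and dataset arguments presumably only take effect when quantization is enabled. A hedged, untested sketch of a quantized build reusing the same call signature, should a smaller model be wanted:

# Untested assumption: only the do_quantization flag needs to change for a w8a8 build.
ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8_g128',
                quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3,
                extra_qparams=None, dataset='calibration_dataset.json')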
rkllm_binding.py ADDED
@@ -0,0 +1,510 @@
import ctypes
import enum
import os

# Define constants from the header
CPU0 = (1 << 0)  # 0x01
CPU1 = (1 << 1)  # 0x02
CPU2 = (1 << 2)  # 0x04
CPU3 = (1 << 3)  # 0x08
CPU4 = (1 << 4)  # 0x10
CPU5 = (1 << 5)  # 0x20
CPU6 = (1 << 6)  # 0x40
CPU7 = (1 << 7)  # 0x80

# --- Enums ---
class LLMCallState(enum.IntEnum):
    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3

class RKLLMInputType(enum.IntEnum):
    RKLLM_INPUT_PROMPT = 0
    RKLLM_INPUT_TOKEN = 1
    RKLLM_INPUT_EMBED = 2
    RKLLM_INPUT_MULTIMODAL = 3

class RKLLMInferMode(enum.IntEnum):
    RKLLM_INFER_GENERATE = 0
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
    RKLLM_INFER_GET_LOGITS = 2

# --- Structures ---
class RKLLMExtendParam(ctypes.Structure):
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("reserved", ctypes.c_uint8 * 106)
    ]

class RKLLMParam(ctypes.Structure):
    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        ("top_k", ctypes.c_int32),
        ("n_keep", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),  # Note: This was missing in the provided text but is in typical LLM params
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),  # This seems like it should be more structured for actual image data
        ("extend_param", RKLLMExtendParam)
    ]

class RKLLMLoraAdapter(ctypes.Structure):
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]

class RKLLMEmbedInput(ctypes.Structure):
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMTokenInput(ctypes.Structure):
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMMultiModelInput(ctypes.Structure):
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t)
    ]

class _RKLLMInputUnion(ctypes.Union):
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]

class RKLLMInput(ctypes.Structure):
    _fields_ = [
        ("input_type", ctypes.c_int),  # Enum will be passed as int, changed RKLLMInputType to ctypes.c_int
        ("_union_data", _RKLLMInputUnion)
    ]
    # Properties to make accessing union members easier
    @property
    def prompt_input(self):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            return self._union_data.prompt_input
        raise AttributeError("Not a prompt input")

    @prompt_input.setter
    def prompt_input(self, value):
        if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT:
            self._union_data.prompt_input = value
        else:
            raise AttributeError("Not a prompt input")

    # Similar properties can be added for embed_input, token_input, multimodal_input

class RKLLMLoraParam(ctypes.Structure):  # For inference
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]

class RKLLMPromptCacheParam(ctypes.Structure):  # For inference
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),  # bool-like
        ("prompt_cache_path", ctypes.c_char_p)
    ]

class RKLLMInferParam(ctypes.Structure):
    _fields_ = [
        ("mode", ctypes.c_int),  # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int)  # bool-like
    ]

class RKLLMResultLastHiddenLayer(ctypes.Structure):
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResultLogits(ctypes.Structure):
    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResult(ctypes.Structure):
    _fields_ = [
        ("text", ctypes.c_char_p),
        ("token_id", ctypes.c_int32),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
        ("logits", RKLLMResultLogits)
    ]

# --- Typedefs ---
LLMHandle = ctypes.c_void_p

# --- Callback Function Type ---
LLMResultCallback = ctypes.CFUNCTYPE(
    None,  # return type: void
    ctypes.POINTER(RKLLMResult),
    ctypes.c_void_p,  # userdata
    ctypes.c_int  # enum, will be passed as int. Changed LLMCallState to ctypes.c_int
)


class RKLLMRuntime:
    def __init__(self, library_path="./librkllmrt.so"):
        try:
            self.lib = ctypes.CDLL(library_path)
        except OSError as e:
            raise OSError(f"Failed to load RKLLM library from {library_path}. "
                          f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}")
        self._setup_functions()
        self.llm_handle = LLMHandle()
        self._c_callback = None  # To keep the callback object alive

    def _setup_functions(self):
        # RKLLMParam rkllm_createDefaultParam();
        self.lib.rkllm_createDefaultParam.restype = RKLLMParam
        self.lib.rkllm_createDefaultParam.argtypes = []

        # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback);
        self.lib.rkllm_init.restype = ctypes.c_int
        self.lib.rkllm_init.argtypes = [
            ctypes.POINTER(LLMHandle),
            ctypes.POINTER(RKLLMParam),
            LLMResultCallback
        ]

        # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter);
        self.lib.rkllm_load_lora.restype = ctypes.c_int
        self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)]

        # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path);
        self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p]

        # int rkllm_release_prompt_cache(LLMHandle handle);
        self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int
        self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle]

        # int rkllm_destroy(LLMHandle handle);
        self.lib.rkllm_destroy.restype = ctypes.c_int
        self.lib.rkllm_destroy.argtypes = [LLMHandle]

        # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        self.lib.rkllm_run.restype = ctypes.c_int
        self.lib.rkllm_run.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata);
        # Assuming async also takes userdata for the callback context
        self.lib.rkllm_run_async.restype = ctypes.c_int
        self.lib.rkllm_run_async.argtypes = [
            LLMHandle,
            ctypes.POINTER(RKLLMInput),
            ctypes.POINTER(RKLLMInferParam),
            ctypes.c_void_p  # userdata
        ]

        # int rkllm_abort(LLMHandle handle);
        self.lib.rkllm_abort.restype = ctypes.c_int
        self.lib.rkllm_abort.argtypes = [LLMHandle]

        # int rkllm_is_running(LLMHandle handle);
        self.lib.rkllm_is_running.restype = ctypes.c_int  # 0 if running, non-zero otherwise
        self.lib.rkllm_is_running.argtypes = [LLMHandle]

        # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt);
        self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int
        self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int]

        # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix);
        self.lib.rkllm_set_chat_template.restype = ctypes.c_int
        self.lib.rkllm_set_chat_template.argtypes = [
            LLMHandle,
            ctypes.c_char_p,
            ctypes.c_char_p,
            ctypes.c_char_p
        ]

    def create_default_param(self) -> RKLLMParam:
        """Creates a default RKLLMParam structure."""
        return self.lib.rkllm_createDefaultParam()

    def init(self, param: RKLLMParam, callback_func) -> int:
        """
        Initializes the LLM.
        :param param: RKLLMParam structure.
        :param callback_func: A Python function that matches the signature:
                              def my_callback(result_ptr, userdata_ptr, state_enum):
                                  result = result_ptr.contents  # RKLLMResult
                                  # Process result
                                  # userdata can be retrieved if passed during run, or ignored
                                  # state = LLMCallState(state_enum)
        :return: 0 for success, non-zero for failure.
        """
        if not callable(callback_func):
            raise ValueError("callback_func must be a callable Python function.")

        # Keep a reference to the ctypes callback object to prevent it from being garbage collected
        self._c_callback = LLMResultCallback(callback_func)

        ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with error code {ret}")
        return ret

    def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int:
        """Loads a Lora adapter."""
        ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter))
        if ret != 0:
            raise RuntimeError(f"rkllm_load_lora failed with error code {ret}")
        return ret

    def load_prompt_cache(self, prompt_cache_path: str) -> int:
        """Loads a prompt cache from a file."""
        c_path = prompt_cache_path.encode('utf-8')
        ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path)
        if ret != 0:
            raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}")
        return ret

    def release_prompt_cache(self) -> int:
        """Releases the prompt cache from memory."""
        ret = self.lib.rkllm_release_prompt_cache(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}")
        return ret

    def destroy(self) -> int:
        """Destroys the LLM instance and releases resources."""
        if self.llm_handle and self.llm_handle.value:  # Check if handle is not NULL
            ret = self.lib.rkllm_destroy(self.llm_handle)
            self.llm_handle = LLMHandle()  # Reset handle
            if ret != 0:
                # Don't raise here as it might be called in __del__
                print(f"Warning: rkllm_destroy failed with error code {ret}")
            return ret
        return 0  # Already destroyed or not initialized

    def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task synchronously."""
        # userdata can be a ctypes.py_object if you want to pass Python objects,
        # then cast to c_void_p. Or simply None.
        c_userdata = ctypes.cast(ctypes.py_object(userdata), ctypes.c_void_p) if userdata is not None else None
        ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run failed with error code {ret}")
        return ret

    def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int:
        """Runs an LLM inference task asynchronously."""
        c_userdata = ctypes.cast(ctypes.py_object(userdata), ctypes.c_void_p) if userdata is not None else None
        ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata)
        if ret != 0:
            raise RuntimeError(f"rkllm_run_async failed with error code {ret}")
        return ret

    def abort(self) -> int:
        """Aborts an ongoing LLM task."""
        ret = self.lib.rkllm_abort(self.llm_handle)
        if ret != 0:
            raise RuntimeError(f"rkllm_abort failed with error code {ret}")
        return ret

    def is_running(self) -> bool:
        """Checks if an LLM task is currently running. Returns True if running."""
        # The C API returns 0 if running, non-zero otherwise.
        # This is a bit counter-intuitive for a boolean "is_running".
        return self.lib.rkllm_is_running(self.llm_handle) == 0

    def clear_kv_cache(self, keep_system_prompt: bool) -> int:
        """Clears the key-value cache."""
        ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0))
        if ret != 0:
            raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}")
        return ret

    def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int:
        """Sets the chat template for the LLM."""
        c_system = system_prompt.encode('utf-8') if system_prompt else None
        c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else None
        c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else None

        ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix)
        if ret != 0:
            raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}")
        return ret

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()

    def __del__(self):
        self.destroy()  # Ensure resources are freed if object is garbage collected

# --- Example Usage (Illustrative) ---
if __name__ == "__main__":
    # This is a placeholder for how you might use it.
    # You'll need a valid .rkllm model and librkllmrt.so in your path.

    # Global list to store results from callback for demonstration
    results_buffer = []

    def my_python_callback(result_ptr, userdata_ptr, state_enum):
        """
        Callback function to be called by the C library.
        """
        global results_buffer
        state = LLMCallState(state_enum)
        result = result_ptr.contents

        current_text = ""
        if result.text:  # Check if the char_p is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')

        print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'")
        results_buffer.append(current_text)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            print("Inference finished.")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("Inference error.")

        # Example: Accessing logits if available (and if mode was set to get logits)
        # if result.logits.logits and result.logits.vocab_size > 0:
        #     print(f"  Logits (first 5 of vocab_size {result.logits.vocab_size}):")
        #     for i in range(min(5, result.logits.vocab_size)):
        #         print(f"    {result.logits.logits[i]:.4f}", end=" ")
        #     print()


    # --- Attempt to use the wrapper ---
    try:
        print("Initializing RKLLMRuntime...")
        # Adjust library_path if librkllmrt.so is not in default search paths
        # e.g., library_path="./path/to/librkllmrt.so"
        rk_llm = RKLLMRuntime()

        print("Creating default parameters...")
        params = rk_llm.create_default_param()

        # --- Configure parameters ---
        # THIS IS CRITICAL: model_path must point to an actual .rkllm file
        # For this example to run, you need a model file.
        # Let's assume a dummy path for now, this will fail at init if not valid.
        model_file = "dummy_model.rkllm"
        if not os.path.exists(model_file):
            print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.")
            # Create a dummy file for the example to proceed further, though init will still fail
            # with a real library unless it's a valid model.
            with open(model_file, "w") as f:
                f.write("dummy content")

        params.model_path = model_file.encode('utf-8')
        params.max_context_len = 512
        params.max_new_tokens = 128
        params.top_k = 1  # Greedy
        params.temperature = 0.7
        params.repeat_penalty = 1.1
        # ... set other params as needed

        print(f"Initializing LLM with model: {params.model_path.decode()}...")
        # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library
        try:
            rk_llm.init(params, my_python_callback)
            print("LLM Initialized.")
        except RuntimeError as e:
            print(f"Error during LLM initialization: {e}")
            print("This is expected if 'dummy_model.rkllm' is not a valid model.")
            print("Replace 'dummy_model.rkllm' with a real model path to test further.")
            exit()


        # --- Prepare input ---
        print("Preparing input...")
        rk_input = RKLLMInput()
        rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT

        prompt_text = "Translate the following English text to French: 'Hello, world!'"
        c_prompt = prompt_text.encode('utf-8')
        rk_input._union_data.prompt_input = c_prompt  # Accessing union member directly

        # --- Prepare inference parameters ---
        print("Preparing inference parameters...")
        infer_params = RKLLMInferParam()
        infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        infer_params.keep_history = 1  # True
        # infer_params.lora_params = None  # or set up RKLLMLoraParam if using LoRA
        # infer_params.prompt_cache_params = None  # or set up RKLLMPromptCacheParam

        # --- Run inference ---
        print(f"Running inference with prompt: '{prompt_text}'")
        results_buffer.clear()
        try:
            rk_llm.run(rk_input, infer_params)  # Userdata is None by default
            print("\n--- Full Response ---")
            print("".join(results_buffer))
            print("---------------------\n")
        except RuntimeError as e:
            print(f"Error during LLM run: {e}")


        # --- Example: Set chat template (if model supports it) ---
        # print("Setting chat template...")
        # try:
        #     rk_llm.set_chat_template("You are a helpful assistant.", "<user>: ", "<assistant>: ")
        #     print("Chat template set.")
        # except RuntimeError as e:
        #     print(f"Error setting chat template: {e}")

        # --- Example: Clear KV Cache ---
        # print("Clearing KV cache (keeping system prompt if any)...")
        # try:
        #     rk_llm.clear_kv_cache(keep_system_prompt=True)
        #     print("KV cache cleared.")
        # except RuntimeError as e:
        #     print(f"Error clearing KV cache: {e}")

    except OSError as e:
        print(f"OSError: {e}. Could not load the RKLLM library.")
        print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value:
            print("Destroying LLM instance...")
            rk_llm.destroy()
            print("LLM instance destroyed.")
        if os.path.exists(model_file) and model_file == "dummy_model.rkllm":
            os.remove(model_file)  # Clean up dummy file

    print("Example finished.")
run_rknn.py ADDED
@@ -0,0 +1,274 @@
import faulthandler
faulthandler.enable()
import os
import time
import numpy as np
from rkllm_binding import *
import ztu_somemodelruntime_rknnlite2 as ort
import signal
import cv2
import ctypes

# --- Configuration ---
# These paths should point to the directory containing all model files
# or be absolute paths.
MODEL_DIR = "."  # Assuming models are in the current directory or provide a specific path
LLM_MODEL_NAME = "qwen_f16.rkllm"
VISION_ENCODER_ONNX_NAME = "fastvithd.onnx"
MM_PROJECTOR_ONNX_NAME = "mm_projector.onnx"
PREPROCESSOR_CONFIG_NAME = "preprocessor_config.json"  # Generated by export_onnx.py

LLM_MODEL_PATH = os.path.join(MODEL_DIR, LLM_MODEL_NAME)
VISION_ENCODER_PATH = os.path.join(MODEL_DIR, VISION_ENCODER_ONNX_NAME)
MM_PROJECTOR_PATH = os.path.join(MODEL_DIR, MM_PROJECTOR_ONNX_NAME)
PREPROCESSOR_CONFIG_PATH = os.path.join(MODEL_DIR, PREPROCESSOR_CONFIG_NAME)

IMAGE_PATH = "test.jpg"  # Replace with your test image
# user_prompt = "Describe this image in detail."
user_prompt = "仔细描述一下这张图片。"  # Chinese for "Describe this image in detail."

# Global RKLLMRuntime instance
rk_runtime = None

# Exit on Ctrl-C
def signal_handler(signal, frame):
    print("Ctrl-C pressed, exiting...")
    global rk_runtime
    if rk_runtime:
        try:
            print("Attempting to abort RKLLM task...")
            rk_runtime.abort()
            print("RKLLM task aborted.")
        except RuntimeError as e:
            print(f"Note: RKLLM abort failed or task was not running: {e}")
        except Exception as e:
            print(f"Unexpected error during RKLLM abort in signal handler: {e}")

        try:
            print("Attempting to destroy RKLLM instance...")
            rk_runtime.destroy()
            print("RKLLM instance destroyed via signal handler.")
        except RuntimeError as e:
            print(f"Error during RKLLM destroy in signal handler: {e}")
        except Exception as e:  # Catch any other unexpected errors
            print(f"Unexpected error during RKLLM destroy in signal handler: {e}")
    exit(0)

signal.signal(signal.SIGINT, signal_handler)

# Set RKLLM log level if desired
os.environ["RKLLM_LOG_LEVEL"] = "1"

inference_count = 0
inference_start_time = 0
first_token_received = False

def result_callback(result_ptr, userdata, state_enum):
    global inference_start_time, inference_count, first_token_received
    state = LLMCallState(state_enum)  # Convert int to enum
    if result_ptr is None:
        return
    result = result_ptr.contents  # Dereference the pointer

    if state == LLMCallState.RKLLM_RUN_NORMAL:
        if not first_token_received:
            first_token_time = time.time()
            print(f"\nTime to first token: {first_token_time - inference_start_time:.2f} seconds")
            first_token_received = True

        current_text = ""
        if result.text:  # Check if char_p is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')
        print(current_text, end="", flush=True)
        inference_count += 1
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n\n(finished)")
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nError occurred during LLM call")
    # Add other states if needed, e.g., RKLLM_RUN_WAITING

def load_and_preprocess_image(image_path, config_path):
    img_size = 1024
    image_mean = [0.0, 0.0, 0.0]
    image_std = [1.0, 1.0, 1.0]

    print(f"Target image size from config: {img_size}x{img_size}")
    print(f"Using image_mean: {image_mean}, image_std: {image_std}")

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Compute the scale factor while keeping the aspect ratio
    h, w = img.shape[:2]
    scale = min(img_size / w, img_size / h)
    new_w, new_h = int(w * scale), int(h * scale)

    # Resize while preserving the aspect ratio
    img_resized = cv2.resize(img, (new_w, new_h))

    # Create a black canvas of the target size
    img_padded = np.zeros((img_size, img_size, 3), dtype=np.uint8)

    # Paste the resized image at the center
    y_offset = (img_size - new_h) // 2
    x_offset = (img_size - new_w) // 2
    img_padded[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = img_resized

    img_rgb = cv2.cvtColor(img_padded, cv2.COLOR_BGR2RGB)
    img_fp32 = img_rgb.astype(np.float32)

    # Normalize
    img_normalized = (img_fp32 / 255.0 - image_mean) / image_std

    # Transpose to NCHW format
    img_nchw = img_normalized.transpose(2, 0, 1)  # HWC to CHW
    img_batch = img_nchw[np.newaxis, :, :, :]  # Add batch dimension -> NCHW

    return img_batch.astype(np.float32), img_size

def main():
    global rk_runtime, inference_start_time, inference_count, first_token_received, user_prompt

    # --- 1. Initialize ONNX Runtime for Vision Models ---
    print("Loading ONNX vision encoder model...")
    vision_session = ort.InferenceSession(VISION_ENCODER_PATH)
    vision_input_name = vision_session.get_inputs()[0].name
    vision_output_name = vision_session.get_outputs()[0].name
    print(f"ONNX vision encoder loaded. Input: '{vision_input_name}', Output: '{vision_output_name}'")

    print("Loading ONNX mm_projector model...")
    mm_projector_session = ort.InferenceSession(MM_PROJECTOR_PATH)
    mm_projector_input_name = mm_projector_session.get_inputs()[0].name
    mm_projector_output_name = mm_projector_session.get_outputs()[0].name
    print(f"ONNX mm_projector loaded. Input: '{mm_projector_input_name}', Output: '{mm_projector_output_name}'")

    # --- 2. Initialize RKLLM ---
    print("Initializing RKLLM...")
    rk_runtime = RKLLMRuntime()

    param = rk_runtime.create_default_param()
    param.model_path = LLM_MODEL_PATH.encode('utf-8')
    param.img_start = "<image>".encode('utf-8')
    param.img_end = "".encode('utf-8')
    param.img_content = "<unk>".encode('utf-8')

    extend_param = RKLLMExtendParam()
    extend_param.base_domain_id = 1
    extend_param.embed_flash = 1
    extend_param.enabled_cpus_num = 8
    extend_param.enabled_cpus_mask = 0xffffffff
    param.extend_param = extend_param

    model_size_llm = os.path.getsize(LLM_MODEL_PATH)
    print(f"Start loading language model (size: {model_size_llm / 1024 / 1024:.2f} MB)")
    start_time_llm_load = time.time()

    try:
        rk_runtime.init(param, result_callback)
    except RuntimeError as e:
        print(f"RKLLM init failed: {e}")
        if rk_runtime:
            try:
                rk_runtime.destroy()
            except Exception as e_destroy:
                print(f"Error destroying RKLLM after init failure: {e_destroy}")
        return

    end_time_llm_load = time.time()
    print(f"Language model loaded in {end_time_llm_load - start_time_llm_load:.2f} seconds")

    # --- 3. Load and Preprocess Image ---
    print(f"Loading and preprocessing image: {IMAGE_PATH}")
    preprocessed_image, original_img_dim = load_and_preprocess_image(IMAGE_PATH, PREPROCESSOR_CONFIG_PATH)
    print(f"Input image shape for ONNX vision model: {preprocessed_image.shape}")

    # --- 4. Vision Encoder Inference (ONNX) ---
    start_time_vision = time.time()
    vision_outputs = vision_session.run([vision_output_name], {vision_input_name: preprocessed_image})
    image_features_from_vision = vision_outputs[0]
    end_time_vision = time.time()
    print(f"ONNX Vision encoder inference time: {end_time_vision - start_time_vision:.2f} seconds")
    print(f"Vision encoder output shape: {image_features_from_vision.shape}")

    # --- 5. MM Projector Inference (ONNX) ---
    start_time_projector = time.time()
    projector_outputs = mm_projector_session.run([mm_projector_output_name], {mm_projector_input_name: image_features_from_vision})
    projected_image_embeddings_np = projector_outputs[0]
    end_time_projector = time.time()
    print(f"ONNX MM projector inference time: {end_time_projector - start_time_projector:.2f} seconds")
    print(f"Projected image embeddings shape: {projected_image_embeddings_np.shape}")

    # Ensure C-contiguous and float32 for ctypes
    projected_image_embeddings_np = np.ascontiguousarray(projected_image_embeddings_np, dtype=np.float32)

    # --- 6. Prepare Prompt and RKLLMInput ---
    # The prompt should contain the <image> placeholder where the image features will be inserted.
    # prompt = f"""<|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # {param.img_start.decode()}
    # {user_prompt}<|im_end|>
    # <|im_start|>assistant
    # """

    # RKLLM now loads its own chat template, so we don't need to include that.
    prompt = f"""{param.img_start.decode()}
{user_prompt}"""

    print(f"\nUsing prompt:\n{prompt}")

    rkllm_input = RKLLMInput()
    rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL

    multimodal_payload = RKLLMMultiModelInput()
    multimodal_payload.prompt = prompt.encode('utf-8')

    # projected_image_embeddings_np has shape (1, num_tokens, hidden_dim)
    num_image_tokens = projected_image_embeddings_np.shape[1]
    # The C API expects a flat pointer to the embedding data.
    embedding_data_flat = projected_image_embeddings_np.flatten()

    multimodal_payload.image_embed = embedding_data_flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    multimodal_payload.n_image_tokens = num_image_tokens
    multimodal_payload.n_image = 1  # Number of images processed
    multimodal_payload.image_width = original_img_dim  # Width of the (resized before processing) image
    multimodal_payload.image_height = original_img_dim  # Height of the (resized before processing) image

    rkllm_input._union_data.multimodal_input = multimodal_payload

    # --- 7. Create Inference Parameters ---
    infer_param = RKLLMInferParam()
    infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE.value  # Ensure this is an int for C API
    # infer_param.keep_history = 1  # Or 0, default is usually 0 (false) in create_default_param or C struct.
    #                               # Check rkllm.h or binding for default if not setting explicitly.
    #                               # RKLLMInferParam from binding has keep_history as c_int.

    # --- 8. Run RKLLM Inference ---
    print("Starting RKLLM inference...")
    inference_start_time = time.time()
    inference_count = 0
    first_token_received = False

    try:
        # The RKLLMRuntime.run method takes input and infer_param objects directly.
        rk_runtime.run(rkllm_input, infer_param, None)  # Userdata is None
    except RuntimeError as e:
        print(f"RKLLM run failed: {e}")

    # --- 9. Clean up ---
    # Normal cleanup if not interrupted by Ctrl-C.
    # The signal handler also attempts to destroy the instance.
    if rk_runtime and rk_runtime.llm_handle and rk_runtime.llm_handle.value:
        try:
            rk_runtime.destroy()
            print("RKLLM instance destroyed at script end.")
        except RuntimeError as e:
            print(f"Error during RKLLM destroy at script end: {e}")
        except Exception as e:
            print(f"Unexpected error during RKLLM destroy at script end: {e}")

    print("Script finished.")

if __name__ == "__main__":
    # rk_runtime (global) will be initialized inside main()
    main()
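For reference, the 256 image tokens passed as n_image_tokens are consistent with the preprocessing above: the encoder receives a 1024×1024 padded image and the processor_config.json written by export_onnx.py uses patch_size 64, so (1024 / 64)² = 256 patches, which matches the mm_projector input shape [1, 256, 3072]. A one-line check:

# Worked check of the image-token count (assumes the 1024 px input and patch size 64 used in this commit).
assert (1024 // 64) ** 2 == 256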
test.jpg ADDED

Git LFS Details

  • SHA256: 6e0ba9e46ff16d0583aa286130a7275c35d05d9600c8cea8a4df4b2e0c46c27b
  • Pointer size: 131 Bytes
  • Size of remote file: 309 kB
ztu_somemodelruntime_rknnlite2.py ADDED
@@ -0,0 +1,569 @@
# Module-level constants and functions
from rknnlite.api import RKNNLite
import numpy as np
import os
import warnings
import logging
from typing import List, Dict, Union, Optional

try:
    import onnxruntime as ort
    HAS_ORT = True
except ImportError:
    HAS_ORT = False
    warnings.warn("onnxruntime is not installed; only the RKNN backend is available", ImportWarning)

# Configure logging
logger = logging.getLogger("somemodelruntime_rknnlite2")
logger.setLevel(logging.ERROR)  # Only output errors by default
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)

# Mapping from ONNX Runtime log levels to Python logging levels
_LOGGING_LEVEL_MAP = {
    0: logging.DEBUG,     # Verbose
    1: logging.INFO,      # Info
    2: logging.WARNING,   # Warning
    3: logging.ERROR,     # Error
    4: logging.CRITICAL   # Fatal
}

# Check the environment variable for a log level setting
try:
    env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
    if env_log_level is not None:
        log_level = int(env_log_level)
        if log_level in _LOGGING_LEVEL_MAP:
            logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
            logger.info(f"Log level set from environment variable: {log_level}")
        else:
            logger.warning(f"Invalid value for ZTU_MODELRT_RKNNL2_LOG_LEVEL: {log_level}, expected an integer between 0 and 4")
except ValueError:
    logger.warning(f"Invalid value for ZTU_MODELRT_RKNNL2_LOG_LEVEL: {env_log_level}, expected an integer between 0 and 4")


def set_default_logger_severity(level: int) -> None:
    """
    Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal

    Args:
        level: log level (0-4)
    """
    if level not in _LOGGING_LEVEL_MAP:
        raise ValueError(f"Invalid log level: {level}, expected an integer between 0 and 4")
    logger.setLevel(_LOGGING_LEVEL_MAP[level])

def set_default_logger_verbosity(level: int) -> None:
    """
    Sets the default logging verbosity level. To activate the verbose log,
    you need to set the default logging severity to 0:Verbose level.

    Args:
        level: log level (0-4)
    """
    set_default_logger_severity(level)

# Mapping from RKNN tensor types to numpy dtypes
RKNN_DTYPE_MAP = {
    0: np.float32,   # RKNN_TENSOR_FLOAT32
    1: np.float16,   # RKNN_TENSOR_FLOAT16
    2: np.int8,      # RKNN_TENSOR_INT8
    3: np.uint8,     # RKNN_TENSOR_UINT8
    4: np.int16,     # RKNN_TENSOR_INT16
    5: np.uint16,    # RKNN_TENSOR_UINT16
    6: np.int32,     # RKNN_TENSOR_INT32
    7: np.uint32,    # RKNN_TENSOR_UINT32
    8: np.int64,     # RKNN_TENSOR_INT64
    9: bool,         # RKNN_TENSOR_BOOL
    10: np.int8,     # RKNN_TENSOR_INT4 (represented as int8)
}

def get_available_providers() -> List[str]:
    """
    Get the list of available execution providers (placeholder kept for interface compatibility).

    Returns:
        list: always ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
    """
    return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]


def get_device() -> str:
    """
    Get the current device.

    Returns:
        str: the current device
    """
    return "RKNN2"

def get_version_info() -> Dict[str, str]:
    """
    Get version information.

    Returns:
        dict: dictionary with API and driver version information
    """
    runtime = RKNNLite()
    version = runtime.get_sdk_version()
    return {
        "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
        "driver_version": version.split('\n')[3].split(': ')[1]
    }

class IOTensor:
    """Wrapper describing an input/output tensor"""
    def __init__(self, name, shape, type=None):
        self.name = name.decode() if isinstance(name, bytes) else name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"

class SessionOptions:
    """Session options"""
    def __init__(self):
        self.enable_profiling = False    # Whether to enable profiling
        self.intra_op_num_threads = 1    # Number of RKNN cores to use; maps to the RKNN core_mask
        self.log_severity_level = -1     # Alternative way to set the log level
        self.log_verbosity_level = -1    # Alternative way to set the log level


class InferenceSession:
    """
    RKNNLite runtime wrapper with an ONNX Runtime-like API
    """

    def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        processed_path = InferenceSession._process_model_path(model_path, sess_options)
        if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
            logger.info("Loading the model with ONNX Runtime")
            if not HAS_ORT:
                raise RuntimeError("onnxruntime is not installed; cannot load an ONNX model")
            return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
        else:
            # Not an ONNX model: fall back to the parent __new__ and create an InferenceSession instance
            instance = super().__new__(cls)
            # Remember the processed path
            instance._processed_path = processed_path
            return instance

    def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        """
        Initialize the runtime and load the model

        Args:
            model_path: model file path (.rknn or .onnx)
            sess_options: session options
            **kwargs: additional initialization arguments
        """
        options = sess_options or SessionOptions()

        # Only use the log level from SessionOptions when the environment variable is not set
        if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
            if options.log_severity_level != -1:
                set_default_logger_severity(options.log_severity_level)
            if options.log_verbosity_level != -1:
                set_default_logger_verbosity(options.log_verbosity_level)

        # Use the path already processed in __new__
        model_path = getattr(self, '_processed_path', model_path)
        if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
            # Avoid loading the ONNX model twice
            return

        # ... existing RKNN model loading and initialization ...
        self.model_path = model_path
        if not os.path.exists(self.model_path):
            logger.error(f"Model file does not exist: {self.model_path}")
            raise FileNotFoundError(f"Model file does not exist: {self.model_path}")

        self.runtime = RKNNLite(verbose=options.enable_profiling)

        logger.debug(f"Loading model: {self.model_path}")
        ret = self.runtime.load_rknn(self.model_path)
        if ret != 0:
            logger.error(f"Failed to load RKNN model: {self.model_path}")
            raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
        logger.debug("Model loaded successfully")

        if options.intra_op_num_threads == 1:
            core_mask = RKNNLite.NPU_CORE_AUTO
        elif options.intra_op_num_threads == 2:
            core_mask = RKNNLite.NPU_CORE_0_1
        elif options.intra_op_num_threads == 3:
            core_mask = RKNNLite.NPU_CORE_0_1_2
        else:
            raise ValueError(f"Invalid value for intra_op_num_threads: {options.intra_op_num_threads}, must be 1, 2 or 3")

        logger.debug("Initializing runtime environment")
        ret = self.runtime.init_runtime(core_mask=core_mask)
        if ret != 0:
            logger.error("Failed to initialize runtime environment")
            raise RuntimeError('Failed to initialize runtime environment')
        logger.debug("Runtime environment initialized successfully")

        self._init_io_info()
        self.options = options

    def get_performance_info(self) -> Dict[str, float]:
        """
        Get performance information

        Returns:
            dict: dictionary with performance information
        """
        if not self.options.enable_profiling:
            raise RuntimeError("Profiling is not enabled; set enable_profiling=True in SessionOptions")

        perf = self.runtime.rknn_runtime.get_run_perf()
        return {
            "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
        }

    def set_core_mask(self, core_mask: int) -> None:
        """
        Set the NPU core usage mode

        Args:
            core_mask: NPU core mask, use the NPU_CORE_* constants
        """
        ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
        if ret != 0:
            raise RuntimeError("Failed to set NPU core mode")

    @staticmethod
    def _process_model_path(model_path, sess_options):
        """
        Process the model path; supports .onnx and .rknn files

        Args:
            model_path: model file path
        """
        # For ONNX files, check whether a matching RKNN model should be loaded instead
        if model_path.lower().endswith('.onnx'):
            logger.info("Detected an ONNX model file")

            # Get the list of models for which automatic RKNN loading should be skipped
            skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
            if skip_models:
                skip_list = [m.strip() for m in skip_models.split(',')]
                # Use the file name (without path) for matching
                model_name = os.path.basename(model_path)
                if model_name.lower() in [m.lower() for m in skip_list]:
                    logger.info(f"Model {model_name} is in the skip list; using ONNX Runtime")
                    return model_path

            # Build the RKNN file path
            rknn_path = os.path.splitext(model_path)[0] + '.rknn'
            if os.path.exists(rknn_path):
                logger.info(f"Found a matching RKNN model, using RKNN: {rknn_path}")
                return rknn_path
            else:
                logger.info("No matching RKNN model found; using ONNX Runtime")
                return model_path

        return model_path

    def _convert_nhwc_to_nchw(self, shape):
        """Convert an NHWC shape to NCHW"""
        if len(shape) == 4:
            # NHWC -> NCHW
            n, h, w, c = shape
            return [n, c, h, w]
        return shape

    def _init_io_info(self):
        """Initialize the model's input/output information"""
        runtime = self.runtime.rknn_runtime

        # Get the number of inputs and outputs
        n_input, n_output = runtime.get_in_out_num()

        # Collect input information
        self.input_tensors = []
        for i in range(n_input):
            attr = runtime.get_tensor_attr(i)
            shape = [attr.dims[j] for j in range(attr.n_dims)]
            # Convert 4-D inputs from NHWC to NCHW
            shape = self._convert_nhwc_to_nchw(shape)
            # Get dtype
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.input_tensors.append(tensor)

        # Collect output information
        self.output_tensors = []
        for i in range(n_output):
            attr = runtime.get_tensor_attr(i, is_output=True)
            shape = runtime.get_output_shape(i)
            # Get dtype
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.output_tensors.append(tensor)

    def get_inputs(self):
        """
        Get the model's input information

        Returns:
            list: list of input descriptors
        """
        return self.input_tensors

    def get_outputs(self):
        """
        Get the model's output information

        Returns:
            list: list of output descriptors
        """
        return self.output_tensors

    def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
        """
        Run model inference

        Args:
            output_names: list of output node names specifying which outputs to return
            input_feed: input data dict or list
            data_format: input data format, "nchw" or "nhwc"
            **kwargs: additional runtime arguments

        Returns:
            list: model outputs; if output_names is given, only those outputs are returned
        """
        if input_feed is None:
            logger.error("input_feed must not be None")
            raise ValueError("input_feed must not be None")

        # Prepare input data
        if isinstance(input_feed, dict):
            # If a dict, order the inputs to match the model inputs
            inputs = []
            input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)}
            for tensor in self.input_tensors:
                if tensor.name not in input_feed:
                    raise ValueError(f"Missing input: {tensor.name}")
                inputs.append(input_feed[tensor.name])
        elif isinstance(input_feed, (list, tuple)):
            # If a list, make sure the length matches
            if len(input_feed) != len(self.input_tensors):
                raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
            inputs = list(input_feed)
        else:
            logger.error("input_feed must be a dict or a list")
            raise ValueError("input_feed must be a dict or a list")

        # Run inference
        try:
            logger.debug("Starting inference")
            all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)

            # If no output_names were given, return all outputs
            if output_names is None:
                return all_outputs

            # Select the requested outputs
            output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
            selected_outputs = []
            for name in output_names:
                if name not in output_map:
                    raise ValueError(f"Output node not found: {name}")
                selected_outputs.append(all_outputs[output_map[name]])

            return selected_outputs

        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise RuntimeError(f"Inference failed: {str(e)}")

    def close(self):
        """
        Close the session and release resources
        """
        if self.runtime is not None:
            logger.info("Releasing runtime resources")
            self.runtime.release()
            self.runtime = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def end_profiling(self) -> Optional[str]:
        """
        Stub for ending profiling

        Returns:
            Optional[str]: None
        """
        warnings.warn("end_profiling() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return None

    def get_profiling_start_time_ns(self) -> int:
        """
        Stub for getting the profiling start time

        Returns:
            int: 0
        """
        warnings.warn("get_profiling_start_time_ns() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return 0

    def get_modelmeta(self) -> Dict[str, str]:
        """
        Stub for getting model metadata

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_modelmeta() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_options(self) -> SessionOptions:
        """
        Get the session options

        Returns:
            SessionOptions: the current session options
        """
        return self.options

    def get_providers(self) -> List[str]:
        """
        Stub for getting the providers in use

        Returns:
            List[str]: ["CPUExecutionProvider"]
        """
        warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
        return ["CPUExecutionProvider"]

    def get_provider_options(self) -> Dict[str, Dict[str, str]]:
        """
        Stub for getting provider options

        Returns:
            Dict[str, Dict[str, str]]: empty dict
        """
        warnings.warn("get_provider_options() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_config(self) -> Dict[str, str]:
        """
        Stub for getting the session config

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_state(self) -> Dict[str, str]:
        """
        Stub for getting the session state

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_session_state() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return {}

    def set_session_config(self, config: Dict[str, str]) -> None:
        """
        Stub for setting the session config

        Args:
            config: session config dict
        """
        warnings.warn("set_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)

    def get_memory_info(self) -> Dict[str, int]:
        """
        Stub for getting memory usage information

        Returns:
            Dict[str, int]: empty dict
        """
        warnings.warn("get_memory_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return {}

    def set_memory_pattern(self, enable: bool) -> None:
        """
        Stub for setting the memory pattern

        Args:
            enable: whether to enable the memory pattern
        """
        warnings.warn("set_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)

    def disable_memory_pattern(self) -> None:
        """
        Stub for disabling the memory pattern
        """
        warnings.warn("disable_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)

    def get_optimization_level(self) -> int:
        """
        Stub for getting the optimization level

        Returns:
            int: 0
        """
        warnings.warn("get_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return 0

    def set_optimization_level(self, level: int) -> None:
        """
        Stub for setting the optimization level

        Args:
            level: optimization level
        """
        warnings.warn("set_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)

    def get_model_metadata(self) -> Dict[str, str]:
        """
        Stub for getting model metadata (separate interface from get_modelmeta)

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_model_metadata() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return {}

    def get_model_path(self) -> str:
        """
        Get the model path

        Returns:
            str: model file path
        """
        return self.model_path

    def get_input_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for getting input type information

        Returns:
            List[Dict[str, str]]: empty list
        """
        warnings.warn("get_input_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return []

    def get_output_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for getting output type information

        Returns:
            List[Dict[str, str]]: empty list
        """
        warnings.warn("get_output_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
        return []
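A minimal usage sketch of the wrapper above (an illustration, not part of the commit): an .onnx path is passed in and the session transparently switches to the sibling .rknn file when one exists, which is exactly what run_rknn.py relies on. This assumes mm_projector.rknn sits next to mm_projector.onnx on the device.

import numpy as np
import ztu_somemodelruntime_rknnlite2 as ort

# Picks up mm_projector.rknn automatically because it sits next to mm_projector.onnx.
sess = ort.InferenceSession("mm_projector.onnx")
print(sess.get_inputs()[0])                          # IOTensor(name=..., shape=..., type=...)
dummy = np.zeros((1, 256, 3072), dtype=np.float32)   # matches the shape used during conversion
outputs = sess.run(None, {sess.get_inputs()[0].name: dummy})
print(outputs[0].shape)
sess.close()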