import faulthandler

faulthandler.enable()

import sys
import os

# Keep RKLLM runtime logging enabled; this must be set before the library is loaded.
os.environ["RKLLM_LOG_LEVEL"] = "1"

import ctypes
import argparse
import time

import cv2
import numpy as np

import ztu_somemodelruntime_rknnlite2 as ort
from rkllm_binding import (
    RKLLMRuntime,
    RKLLMParam,
    RKLLMInput,
    RKLLMInferParam,
    LLMCallState,
    RKLLMInputType,
    RKLLMInferMode,
    RKLLMResult,
)

IMAGE_HEIGHT = 476
IMAGE_WIDTH = 476
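
# Pipeline overview: the input image is padded to a square, resized to
# 476x476, CLIP-normalized, and run through the ONNX vision encoder; the
# resulting embedding is then handed to the RKLLM runtime as multimodal
# input alongside the user's prompt.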


def expand2square(img, background_color):
    """Expand the image into a square, filling the new area with background_color."""
    height, width, _ = img.shape
    if width == height:
        return img.copy()

    size = max(width, height)
    square_img = np.full((size, size, 3), background_color, dtype=np.uint8)

    # Center the original image inside the square canvas.
    x_offset = (size - width) // 2
    y_offset = (size - height) // 2

    square_img[y_offset:y_offset + height, x_offset:x_offset + width] = img
    return square_img


def llm_callback(result_ptr, userdata_ptr, state_enum):
    """Handle streaming results from the RKLLM runtime."""
    state = LLMCallState(state_enum)

    if state == LLMCallState.RKLLM_RUN_NORMAL:
        # Only dereference the result pointer while generation is in progress.
        result = result_ptr.contents
        if result.text:
            print(result.text.decode('utf-8', errors='ignore'), end='', flush=True)
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n", flush=True)
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nrun error", flush=True)

    return 0


def main():
    parser = argparse.ArgumentParser(
        description="Run RKLLM visual language model inference based on the C++ example."
    )
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    parser.add_argument("encoder_model_path", type=str, help="Path to the ONNX vision encoder model.")
    parser.add_argument("llm_model_path", type=str, help="Path to the .rkllm language model.")
    parser.add_argument("max_new_tokens", type=int, help="Maximum number of new tokens to generate.")
    parser.add_argument("max_context_len", type=int, help="Maximum context length.")
    parser.add_argument("rknn_core_num", type=int, help="Number of NPU cores to use for the vision encoder.")
    args = parser.parse_args()
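
    # Example invocation (the script and model file names are illustrative, not
    # part of this repo):
    #   python3 multimodal_chat.py demo.jpg vision_encoder.onnx qwen2-vl-2b.rkllm 512 4096 3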
print("Initializing ONNX Runtime for vision encoder...") |
|
|
try: |
|
|
sess_options = ort.SessionOptions() |
|
|
sess_options.intra_op_num_threads = args.rknn_core_num |
|
|
ort_session = ort.InferenceSession(args.encoder_model_path, sess_options=sess_options) |
|
|
except Exception as e: |
|
|
print(f"Failed to load ONNX model: {e}") |
|
|
sys.exit(1) |
|
|
print("Vision encoder loaded successfully.") |
|
|
|
|
|
input_name = ort_session.get_inputs()[0].name |
|
|
output_name = ort_session.get_outputs()[0].name |
|
|
print(f"ONNX Input: {input_name}, ONNX Output: {output_name}") |
|
|
|
|
|
|
|
|
print("Initializing RKLLM Runtime...") |
|
|
rk_llm = RKLLMRuntime() |
|
|
param = rk_llm.create_default_param() |
|
|
|
|
|
param.model_path = args.llm_model_path.encode('utf-8') |
|
|
param.top_k = 1 |
|
|
param.max_new_tokens = args.max_new_tokens |
|
|
param.max_context_len = args.max_context_len |
|
|
param.skip_special_token = True |
|
|
param.img_start = b"<|vision_start|>" |
|
|
param.img_end = b"<|vision_end|>" |
|
|
param.img_content = b"<|image_pad|>" |
|
|
param.extend_param.base_domain_id = 1 |
|
|
|
|
|
try: |
|
|
rk_llm.init(param, llm_callback) |
|
|
print("RKLLM initialized successfully.") |
|
|
except RuntimeError as e: |
|
|
print(f"RKLLM init failed: {e}") |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
print("Preprocessing image...") |
|
|
img = cv2.imread(args.image_path) |
|
|
if img is None: |
|
|
print(f"Failed to read image from {args.image_path}") |
|
|
sys.exit(1) |
|
|
|
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
|
|
|
|
background_color = (127.5, 127.5, 127.5) |
|
|
square_img = expand2square(img, background_color) |
|
|
resized_img = cv2.resize(square_img, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
|
|
|
input_tensor = resized_img.astype(np.float32) |
|
|
|
|
|
input_tensor = (input_tensor / 255.0 - np.array([0.48145466, 0.4578275, 0.40821073])) / np.array([0.26862954, 0.26130258, 0.27577711]) |
|
|
|
|
|
input_tensor = np.transpose(input_tensor, (2, 0, 1)) |
|
|
input_tensor = np.expand_dims(input_tensor, axis=0) |
|
|
|
|
|
|
|
|
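
    # At this point input_tensor has shape (1, 3, IMAGE_HEIGHT, IMAGE_WIDTH).
    # Note the normalization above promotes it to float64 (the mean/std arrays
    # default to float64), which is why it is cast back to float32 right
    # before inference below.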
print("Running vision encoder...") |
|
|
import time |
|
|
start_time = time.time() |
|
|
try: |
|
|
img_vec_output = ort_session.run([output_name], {input_name: input_tensor.astype(np.float32)})[0] |
|
|
elapsed_time = time.time() - start_time |
|
|
print(f"视觉编码器推理耗时: {elapsed_time:.4f} 秒") |
|
|
|
|
|
img_vec = img_vec_output.flatten().astype(np.float32) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Failed to run vision encoder inference: {e}") |
|
|
rk_llm.destroy() |
|
|
sys.exit(1) |
|
|
|
|
|
print("Image encoded successfully.") |
|
|
|
|
|
|
|
|
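
    # Note: n_image_tokens (set in the chat loop below) is taken from
    # img_vec_output.shape[0], which assumes the encoder emits
    # (n_tokens, hidden_dim) without a batch axis; if your encoder returns
    # (1, n_tokens, hidden_dim), use shape[1] instead.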

    rkllm_infer_params = RKLLMInferParam()
    rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
    rkllm_infer_params.keep_history = 0  # each turn starts from a fresh context

    # Qwen2-style chat template wrapped around every user prompt.
    rk_llm.set_chat_template(
        system_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
        prompt_prefix="<|im_start|>user\n",
        prompt_postfix="<|im_end|>\n<|im_start|>assistant\n"
    )

    pre_input = [
        "Picture 1: <image> What is in the image?",
        "Picture 1: <image> 这张图片中有什么?",  # the same question, in Chinese
    ]
    print("\n********************** Enter the number of a preset question below, or type your own prompt **********************\n")
    for i, p in enumerate(pre_input):
        print(f"[{i}] {p}")
    print("\n*************************************************************************\n")
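
    # Special commands in the chat loop: "exit" quits, "clear" drops the KV
    # cache (keeping the system prompt), and a bare number picks a preset prompt.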

    try:
        while True:
            print("\nuser: ", end="", flush=True)
            input_str = sys.stdin.readline().strip()

            if not input_str:
                continue
            if input_str == "exit":
                break
            if input_str == "clear":
                try:
                    rk_llm.clear_kv_cache(keep_system_prompt=True)
                    print("KV cache cleared.")
                except RuntimeError as e:
                    print(f"Failed to clear KV cache: {e}")
                continue

            # A bare index selects one of the preset prompts; anything else is sent verbatim.
            try:
                idx = int(input_str)
                if 0 <= idx < len(pre_input):
                    input_str = pre_input[idx]
                    print(input_str)
            except ValueError:
                pass

            rkllm_input = RKLLMInput()
            rkllm_input.role = b"user"

            print("robot: ", end="", flush=True)

            if "<image>" in input_str:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL
                rkllm_input.multimodal_input.prompt = input_str.encode('utf-8')
                # Only a raw pointer is passed, so img_vec must stay alive until run() returns.
                rkllm_input.multimodal_input.image_embed = img_vec.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
                rkllm_input.multimodal_input.n_image_tokens = img_vec_output.shape[0]
                print("n_image_tokens: ", rkllm_input.multimodal_input.n_image_tokens)
                rkllm_input.multimodal_input.n_image = 1
                rkllm_input.multimodal_input.image_height = IMAGE_HEIGHT
                rkllm_input.multimodal_input.image_width = IMAGE_WIDTH
            else:
                rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
                rkllm_input.prompt_input = input_str.encode('utf-8')

            try:
                rk_llm.run(rkllm_input, rkllm_infer_params)
            except RuntimeError as e:
                print(f"\nError during rkllm_run: {e}")

    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing resources...")
        rk_llm.destroy()
        print("RKLLM instance destroyed.")


if __name__ == "__main__":
    main()