import faulthandler
faulthandler.enable()

import os
import time
import numpy as np
from rkllm_binding import *
import ztu_somemodelruntime_rknnlite2 as ort
import signal
import cv2
import ctypes

# --- Configuration ---
# These paths should point to the directory containing all model files,
# or be absolute paths.
MODEL_DIR = "."  # Assumes models are in the current directory; set a specific path if needed
LLM_MODEL_NAME = "qwen_f16.rkllm"
VISION_ENCODER_ONNX_NAME = "fastvithd.onnx"
MM_PROJECTOR_ONNX_NAME = "mm_projector.onnx"
PREPROCESSOR_CONFIG_NAME = "preprocessor_config.json"  # Generated by export_onnx.py

LLM_MODEL_PATH = os.path.join(MODEL_DIR, LLM_MODEL_NAME)
VISION_ENCODER_PATH = os.path.join(MODEL_DIR, VISION_ENCODER_ONNX_NAME)
MM_PROJECTOR_PATH = os.path.join(MODEL_DIR, MM_PROJECTOR_ONNX_NAME)
PREPROCESSOR_CONFIG_PATH = os.path.join(MODEL_DIR, PREPROCESSOR_CONFIG_NAME)

IMAGE_PATH = "test.jpg"  # Replace with your test image

# user_prompt = "Describe this image in detail."
user_prompt = "仔细描述一下这张图片。"  # "Describe this image in detail." (in Chinese)

# Global RKLLMRuntime instance
rk_runtime = None


# Exit cleanly on Ctrl-C
def signal_handler(sig, frame):
    print("Ctrl-C pressed, exiting...")
    global rk_runtime
    if rk_runtime:
        try:
            print("Attempting to abort RKLLM task...")
            rk_runtime.abort()
            print("RKLLM task aborted.")
        except RuntimeError as e:
            print(f"Note: RKLLM abort failed or task was not running: {e}")
        except Exception as e:
            print(f"Unexpected error during RKLLM abort in signal handler: {e}")
        try:
            print("Attempting to destroy RKLLM instance...")
            rk_runtime.destroy()
            print("RKLLM instance destroyed via signal handler.")
        except RuntimeError as e:
            print(f"Error during RKLLM destroy in signal handler: {e}")
        except Exception as e:  # Catch any other unexpected errors
            print(f"Unexpected error during RKLLM destroy in signal handler: {e}")
    exit(0)

signal.signal(signal.SIGINT, signal_handler)

# Set RKLLM log level if desired
os.environ["RKLLM_LOG_LEVEL"] = "1"

inference_count = 0
inference_start_time = 0
first_token_received = False


def result_callback(result_ptr, userdata, state_enum):
    global inference_start_time, inference_count, first_token_received
    state = LLMCallState(state_enum)  # Convert int to enum
    if result_ptr is None:
        return
    result = result_ptr.contents  # Dereference the pointer

    if state == LLMCallState.RKLLM_RUN_NORMAL:
        if not first_token_received:
            first_token_time = time.time()
            print(f"\nTime to first token: {first_token_time - inference_start_time:.2f} seconds")
            first_token_received = True
        current_text = ""
        if result.text:  # Check that the char pointer is not NULL
            current_text = result.text.decode('utf-8', errors='ignore')
        print(current_text, end="", flush=True)
        inference_count += 1
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n\n(finished)")
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nError occurred during LLM call")
    # Add other states if needed, e.g., RKLLM_RUN_WAITING


def load_and_preprocess_image(image_path, config_path):
    # Values hardcoded to match the exported preprocessor_config.json;
    # config_path is currently unused.
    img_size = 1024
    image_mean = [0.0, 0.0, 0.0]
    image_std = [1.0, 1.0, 1.0]
    print(f"Target image size from config: {img_size}x{img_size}")
    print(f"Using image_mean: {image_mean}, image_std: {image_std}")

    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Compute the scale factor, preserving the aspect ratio
    h, w = img.shape[:2]
    scale = min(img_size / w, img_size / h)
    new_w, new_h = int(w * scale), int(h * scale)

    # Resize while keeping the aspect ratio
    img_resized = cv2.resize(img, (new_w, new_h))

    # Create a black canvas of the target size
    img_padded = np.zeros((img_size, img_size, 3), dtype=np.uint8)

    # Paste the resized image at the center of the canvas
    y_offset = (img_size - new_h) // 2
    x_offset = (img_size - new_w) // 2
    img_padded[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = img_resized

    img_rgb = cv2.cvtColor(img_padded, cv2.COLOR_BGR2RGB)
    img_fp32 = img_rgb.astype(np.float32)

    # Normalize
    img_normalized = (img_fp32 / 255.0 - image_mean) / image_std

    # Transpose to NCHW format
    img_nchw = img_normalized.transpose(2, 0, 1)  # HWC to CHW
    img_batch = img_nchw[np.newaxis, :, :, :]  # Add batch dimension -> NCHW
    return img_batch.astype(np.float32), img_size
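
# A quick way to sanity-check the preprocessing in isolation (illustrative
# snippet, not called by main(); requires IMAGE_PATH to exist):
#
#   batch, dim = load_and_preprocess_image(IMAGE_PATH, PREPROCESSOR_CONFIG_PATH)
#   assert batch.shape == (1, 3, dim, dim) and batch.dtype == np.float32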

def main():
    global rk_runtime, inference_start_time, inference_count, first_token_received, user_prompt

    # --- 1. Initialize ONNX Runtime for Vision Models ---
    print("Loading ONNX vision encoder model...")
    vision_session = ort.InferenceSession(VISION_ENCODER_PATH)
    vision_input_name = vision_session.get_inputs()[0].name
    vision_output_name = vision_session.get_outputs()[0].name
    print(f"ONNX vision encoder loaded. Input: '{vision_input_name}', Output: '{vision_output_name}'")

    print("Loading ONNX mm_projector model...")
    mm_projector_session = ort.InferenceSession(MM_PROJECTOR_PATH)
    mm_projector_input_name = mm_projector_session.get_inputs()[0].name
    mm_projector_output_name = mm_projector_session.get_outputs()[0].name
    print(f"ONNX mm_projector loaded. Input: '{mm_projector_input_name}', Output: '{mm_projector_output_name}'")

    # --- 2. Initialize RKLLM ---
    print("Initializing RKLLM...")
    rk_runtime = RKLLMRuntime()
    param = rk_runtime.create_default_param()
    param.model_path = LLM_MODEL_PATH.encode('utf-8')
    param.img_start = "".encode('utf-8')
    param.img_end = "".encode('utf-8')
    param.img_content = "".encode('utf-8')

    extend_param = RKLLMExtendParam()
    extend_param.base_domain_id = 1
    extend_param.embed_flash = 1
    extend_param.enabled_cpus_num = 8
    extend_param.enabled_cpus_mask = 0xffffffff
    param.extend_param = extend_param

    model_size_llm = os.path.getsize(LLM_MODEL_PATH)
    print(f"Start loading language model (size: {model_size_llm / 1024 / 1024:.2f} MB)")
    start_time_llm_load = time.time()
    try:
        rk_runtime.init(param, result_callback)
    except RuntimeError as e:
        print(f"RKLLM init failed: {e}")
        if rk_runtime:
            try:
                rk_runtime.destroy()
            except Exception as e_destroy:
                print(f"Error destroying RKLLM after init failure: {e_destroy}")
        return
    end_time_llm_load = time.time()
    print(f"Language model loaded in {end_time_llm_load - start_time_llm_load:.2f} seconds")

    # --- 3. Load and Preprocess Image ---
    print(f"Loading and preprocessing image: {IMAGE_PATH}")
    preprocessed_image, original_img_dim = load_and_preprocess_image(IMAGE_PATH, PREPROCESSOR_CONFIG_PATH)
    print(f"Input image shape for ONNX vision model: {preprocessed_image.shape}")

    # --- 4. Vision Encoder Inference (ONNX) ---
    start_time_vision = time.time()
    vision_outputs = vision_session.run([vision_output_name], {vision_input_name: preprocessed_image})
    image_features_from_vision = vision_outputs[0]
    end_time_vision = time.time()
    print(f"ONNX Vision encoder inference time: {end_time_vision - start_time_vision:.2f} seconds")
    print(f"Vision encoder output shape: {image_features_from_vision.shape}")

    # --- 5. MM Projector Inference (ONNX) ---
    start_time_projector = time.time()
    projector_outputs = mm_projector_session.run([mm_projector_output_name], {mm_projector_input_name: image_features_from_vision})
    projected_image_embeddings_np = projector_outputs[0]
    end_time_projector = time.time()
    print(f"ONNX MM projector inference time: {end_time_projector - start_time_projector:.2f} seconds")
    print(f"Projected image embeddings shape: {projected_image_embeddings_np.shape}")

    # Ensure C-contiguous and float32 for ctypes
    projected_image_embeddings_np = np.ascontiguousarray(projected_image_embeddings_np, dtype=np.float32)
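
    # A minimal sanity check, assuming the projector returns a 3-D
    # (1, n_tokens, hidden_dim) tensor as step 6 below expects:
    assert projected_image_embeddings_np.ndim == 3, \
        f"Expected (1, n_tokens, hidden_dim), got shape {projected_image_embeddings_np.shape}"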

    # --- 6. Prepare Prompt and RKLLMInput ---
    # The prompt should contain the placeholder where the image features will be inserted.
    # prompt = f"""<|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # {param.img_start.decode()}
    # {user_prompt}<|im_end|>
    # <|im_start|>assistant
    # """
    # RKLLM now loads its own chat template, so the full template is not needed here.
    prompt = f"""{param.img_start.decode()}
{user_prompt}"""
    print(f"\nUsing prompt:\n{prompt}")

    rkllm_input = RKLLMInput()
    rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL

    multimodal_payload = RKLLMMultiModelInput()
    multimodal_payload.prompt = prompt.encode('utf-8')

    # projected_image_embeddings_np has shape (1, num_tokens, hidden_dim)
    num_image_tokens = projected_image_embeddings_np.shape[1]

    # The C API expects a flat pointer to the embedding data.
    embedding_data_flat = projected_image_embeddings_np.flatten()
    multimodal_payload.image_embed = embedding_data_flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    multimodal_payload.n_image_tokens = num_image_tokens
    multimodal_payload.n_image = 1  # Number of images processed
    multimodal_payload.image_width = original_img_dim   # Width of the padded square input image
    multimodal_payload.image_height = original_img_dim  # Height of the padded square input image

    rkllm_input._union_data.multimodal_input = multimodal_payload

    # --- 7. Create Inference Parameters ---
    infer_param = RKLLMInferParam()
    infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE.value  # The C API expects an int here
    # infer_param.keep_history = 1  # keep_history is a c_int in the binding; the default
    # is usually 0 (false). Check rkllm.h or the binding if not setting it explicitly.

    # --- 8. Run RKLLM Inference ---
    print("Starting RKLLM inference...")
    inference_start_time = time.time()
    inference_count = 0
    first_token_received = False

    try:
        # RKLLMRuntime.run takes the input and infer_param objects directly.
        rk_runtime.run(rkllm_input, infer_param, None)  # Userdata is None
    except RuntimeError as e:
        print(f"RKLLM run failed: {e}")

    # --- 9. Clean up ---
    # Normal cleanup if not interrupted by Ctrl-C; the signal handler also
    # attempts to destroy the instance.
    if rk_runtime and rk_runtime.llm_handle and rk_runtime.llm_handle.value:
        try:
            rk_runtime.destroy()
            print("RKLLM instance destroyed at script end.")
        except RuntimeError as e:
            print(f"Error during RKLLM destroy at script end: {e}")
        except Exception as e:
            print(f"Unexpected error during RKLLM destroy at script end: {e}")

    print("Script finished.")


if __name__ == "__main__":
    # rk_runtime (global) is initialized inside main()
    main()
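
# Example invocation (assuming this file is saved as run_fastvlm.py, a name
# chosen here for illustration, with the .rkllm/.onnx models and test.jpg in
# the same directory):
#
#   python run_fastvlm.py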