import ctypes import enum import os # Define constants from the header CPU0 = (1 << 0) # 0x01 CPU1 = (1 << 1) # 0x02 CPU2 = (1 << 2) # 0x04 CPU3 = (1 << 3) # 0x08 CPU4 = (1 << 4) # 0x10 CPU5 = (1 << 5) # 0x20 CPU6 = (1 << 6) # 0x40 CPU7 = (1 << 7) # 0x80 # --- Enums --- class LLMCallState(enum.IntEnum): RKLLM_RUN_NORMAL = 0 RKLLM_RUN_WAITING = 1 RKLLM_RUN_FINISH = 2 RKLLM_RUN_ERROR = 3 class RKLLMInputType(enum.IntEnum): RKLLM_INPUT_PROMPT = 0 RKLLM_INPUT_TOKEN = 1 RKLLM_INPUT_EMBED = 2 RKLLM_INPUT_MULTIMODAL = 3 class RKLLMInferMode(enum.IntEnum): RKLLM_INFER_GENERATE = 0 RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1 RKLLM_INFER_GET_LOGITS = 2 # --- Structures --- class RKLLMExtendParam(ctypes.Structure): _fields_ = [ ("base_domain_id", ctypes.c_int32), ("embed_flash", ctypes.c_int8), ("enabled_cpus_num", ctypes.c_int8), ("enabled_cpus_mask", ctypes.c_uint32), ("reserved", ctypes.c_uint8 * 106) ] class RKLLMParam(ctypes.Structure): _fields_ = [ ("model_path", ctypes.c_char_p), ("max_context_len", ctypes.c_int32), ("max_new_tokens", ctypes.c_int32), ("top_k", ctypes.c_int32), ("n_keep", ctypes.c_int32), ("top_p", ctypes.c_float), ("temperature", ctypes.c_float), ("repeat_penalty", ctypes.c_float), ("frequency_penalty", ctypes.c_float), ("presence_penalty", ctypes.c_float), # Note: This was missing in the provided text but is in typical LLM params ("mirostat", ctypes.c_int32), ("mirostat_tau", ctypes.c_float), ("mirostat_eta", ctypes.c_float), ("skip_special_token", ctypes.c_bool), ("is_async", ctypes.c_bool), ("img_start", ctypes.c_char_p), ("img_end", ctypes.c_char_p), ("img_content", ctypes.c_char_p), # This seems like it should be more structured for actual image data ("extend_param", RKLLMExtendParam) ] class RKLLMLoraAdapter(ctypes.Structure): _fields_ = [ ("lora_adapter_path", ctypes.c_char_p), ("lora_adapter_name", ctypes.c_char_p), ("scale", ctypes.c_float) ] class RKLLMEmbedInput(ctypes.Structure): _fields_ = [ ("embed", ctypes.POINTER(ctypes.c_float)), ("n_tokens", ctypes.c_size_t) ] class RKLLMTokenInput(ctypes.Structure): _fields_ = [ ("input_ids", ctypes.POINTER(ctypes.c_int32)), ("n_tokens", ctypes.c_size_t) ] class RKLLMMultiModelInput(ctypes.Structure): _fields_ = [ ("prompt", ctypes.c_char_p), ("image_embed", ctypes.POINTER(ctypes.c_float)), ("n_image_tokens", ctypes.c_size_t), ("n_image", ctypes.c_size_t), ("image_width", ctypes.c_size_t), ("image_height", ctypes.c_size_t) ] class _RKLLMInputUnion(ctypes.Union): _fields_ = [ ("prompt_input", ctypes.c_char_p), ("embed_input", RKLLMEmbedInput), ("token_input", RKLLMTokenInput), ("multimodal_input", RKLLMMultiModelInput) ] class RKLLMInput(ctypes.Structure): _fields_ = [ ("input_type", ctypes.c_int), # Enum will be passed as int, changed RKLLMInputType to ctypes.c_int ("_union_data", _RKLLMInputUnion) ] # Properties to make accessing union members easier @property def prompt_input(self): if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT: return self._union_data.prompt_input raise AttributeError("Not a prompt input") @prompt_input.setter def prompt_input(self, value): if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT: self._union_data.prompt_input = value else: raise AttributeError("Not a prompt input") # Similar properties can be added for embed_input, token_input, multimodal_input class RKLLMLoraParam(ctypes.Structure): # For inference _fields_ = [ ("lora_adapter_name", ctypes.c_char_p) ] class RKLLMPromptCacheParam(ctypes.Structure): # For inference _fields_ = [ ("save_prompt_cache", ctypes.c_int), # bool-like ("prompt_cache_path", ctypes.c_char_p) ] class RKLLMInferParam(ctypes.Structure): _fields_ = [ ("mode", ctypes.c_int), # Enum will be passed as int, changed RKLLMInferMode to ctypes.c_int ("lora_params", ctypes.POINTER(RKLLMLoraParam)), ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)), ("keep_history", ctypes.c_int) # bool-like ] class RKLLMResultLastHiddenLayer(ctypes.Structure): _fields_ = [ ("hidden_states", ctypes.POINTER(ctypes.c_float)), ("embd_size", ctypes.c_int), ("num_tokens", ctypes.c_int) ] class RKLLMResultLogits(ctypes.Structure): _fields_ = [ ("logits", ctypes.POINTER(ctypes.c_float)), ("vocab_size", ctypes.c_int), ("num_tokens", ctypes.c_int) ] class RKLLMResult(ctypes.Structure): _fields_ = [ ("text", ctypes.c_char_p), ("token_id", ctypes.c_int32), ("last_hidden_layer", RKLLMResultLastHiddenLayer), ("logits", RKLLMResultLogits) ] # --- Typedefs --- LLMHandle = ctypes.c_void_p # --- Callback Function Type --- LLMResultCallback = ctypes.CFUNCTYPE( None, # return type: void ctypes.POINTER(RKLLMResult), ctypes.c_void_p, # userdata ctypes.c_int # enum, will be passed as int. Changed LLMCallState to ctypes.c_int ) class RKLLMRuntime: def __init__(self, library_path="./librkllmrt.so"): try: self.lib = ctypes.CDLL(library_path) except OSError as e: raise OSError(f"Failed to load RKLLM library from {library_path}. " f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}") self._setup_functions() self.llm_handle = LLMHandle() self._c_callback = None # To keep the callback object alive def _setup_functions(self): # RKLLMParam rkllm_createDefaultParam(); self.lib.rkllm_createDefaultParam.restype = RKLLMParam self.lib.rkllm_createDefaultParam.argtypes = [] # int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback); self.lib.rkllm_init.restype = ctypes.c_int self.lib.rkllm_init.argtypes = [ ctypes.POINTER(LLMHandle), ctypes.POINTER(RKLLMParam), LLMResultCallback ] # int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter); self.lib.rkllm_load_lora.restype = ctypes.c_int self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)] # int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path); self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p] # int rkllm_release_prompt_cache(LLMHandle handle); self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle] # int rkllm_destroy(LLMHandle handle); self.lib.rkllm_destroy.restype = ctypes.c_int self.lib.rkllm_destroy.argtypes = [LLMHandle] # int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata); self.lib.rkllm_run.restype = ctypes.c_int self.lib.rkllm_run.argtypes = [ LLMHandle, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p # userdata ] # int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata); # Assuming async also takes userdata for the callback context self.lib.rkllm_run_async.restype = ctypes.c_int self.lib.rkllm_run_async.argtypes = [ LLMHandle, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p # userdata ] # int rkllm_abort(LLMHandle handle); self.lib.rkllm_abort.restype = ctypes.c_int self.lib.rkllm_abort.argtypes = [LLMHandle] # int rkllm_is_running(LLMHandle handle); self.lib.rkllm_is_running.restype = ctypes.c_int # 0 if running, non-zero otherwise self.lib.rkllm_is_running.argtypes = [LLMHandle] # int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt); self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int self.lib.rkllm_clear_kv_cache.argtypes = [LLMHandle, ctypes.c_int] # int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix); self.lib.rkllm_set_chat_template.restype = ctypes.c_int self.lib.rkllm_set_chat_template.argtypes = [ LLMHandle, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p ] def create_default_param(self) -> RKLLMParam: """Creates a default RKLLMParam structure.""" return self.lib.rkllm_createDefaultParam() def init(self, param: RKLLMParam, callback_func) -> int: """ Initializes the LLM. :param param: RKLLMParam structure. :param callback_func: A Python function that matches the signature: def my_callback(result_ptr, userdata_ptr, state_enum): result = result_ptr.contents # RKLLMResult # Process result # userdata can be retrieved if passed during run, or ignored # state = LLMCallState(state_enum) :return: 0 for success, non-zero for failure. """ if not callable(callback_func): raise ValueError("callback_func must be a callable Python function.") # Keep a reference to the ctypes callback object to prevent it from being garbage collected self._c_callback = LLMResultCallback(callback_func) ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback) if ret != 0: raise RuntimeError(f"rkllm_init failed with error code {ret}") return ret def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int: """Loads a Lora adapter.""" ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter)) if ret != 0: raise RuntimeError(f"rkllm_load_lora failed with error code {ret}") return ret def load_prompt_cache(self, prompt_cache_path: str) -> int: """Loads a prompt cache from a file.""" c_path = prompt_cache_path.encode('utf-8') ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path) if ret != 0: raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}") return ret def release_prompt_cache(self) -> int: """Releases the prompt cache from memory.""" ret = self.lib.rkllm_release_prompt_cache(self.llm_handle) if ret != 0: raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}") return ret def destroy(self) -> int: """Destroys the LLM instance and releases resources.""" if self.llm_handle and self.llm_handle.value: # Check if handle is not NULL ret = self.lib.rkllm_destroy(self.llm_handle) self.llm_handle = LLMHandle() # Reset handle if ret != 0: # Don't raise here as it might be called in __del__ print(f"Warning: rkllm_destroy failed with error code {ret}") return ret return 0 # Already destroyed or not initialized def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int: """Runs an LLM inference task synchronously.""" # userdata can be a ctypes.py_object if you want to pass Python objects, # then cast to c_void_p. Or simply None. c_userdata = ctypes.cast(ctypes.py_object(userdata), ctypes.c_void_p) if userdata is not None else None ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata) if ret != 0: raise RuntimeError(f"rkllm_run failed with error code {ret}") return ret def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int: """Runs an LLM inference task asynchronously.""" c_userdata = ctypes.cast(ctypes.py_object(userdata), ctypes.c_void_p) if userdata is not None else None ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata) if ret != 0: raise RuntimeError(f"rkllm_run_async failed with error code {ret}") return ret def abort(self) -> int: """Aborts an ongoing LLM task.""" ret = self.lib.rkllm_abort(self.llm_handle) if ret != 0: raise RuntimeError(f"rkllm_abort failed with error code {ret}") return ret def is_running(self) -> bool: """Checks if an LLM task is currently running. Returns True if running.""" # The C API returns 0 if running, non-zero otherwise. # This is a bit counter-intuitive for a boolean "is_running". return self.lib.rkllm_is_running(self.llm_handle) == 0 def clear_kv_cache(self, keep_system_prompt: bool) -> int: """Clears the key-value cache.""" ret = self.lib.rkllm_clear_kv_cache(self.llm_handle, ctypes.c_int(1 if keep_system_prompt else 0)) if ret != 0: raise RuntimeError(f"rkllm_clear_kv_cache failed with error code {ret}") return ret def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int: """Sets the chat template for the LLM.""" c_system = system_prompt.encode('utf-8') if system_prompt else None c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else None c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else None ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix) if ret != 0: raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}") return ret def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.destroy() def __del__(self): self.destroy() # Ensure resources are freed if object is garbage collected # --- Example Usage (Illustrative) --- if __name__ == "__main__": # This is a placeholder for how you might use it. # You'll need a valid .rkllm model and librkllmrt.so in your path. # Global list to store results from callback for demonstration results_buffer = [] def my_python_callback(result_ptr, userdata_ptr, state_enum): """ Callback function to be called by the C library. """ global results_buffer state = LLMCallState(state_enum) result = result_ptr.contents current_text = "" if result.text: # Check if the char_p is not NULL current_text = result.text.decode('utf-8', errors='ignore') print(f"Callback: State={state.name}, TokenID={result.token_id}, Text='{current_text}'") results_buffer.append(current_text) if state == LLMCallState.RKLLM_RUN_FINISH: print("Inference finished.") elif state == LLMCallState.RKLLM_RUN_ERROR: print("Inference error.") # Example: Accessing logits if available (and if mode was set to get logits) # if result.logits.logits and result.logits.vocab_size > 0: # print(f" Logits (first 5 of vocab_size {result.logits.vocab_size}):") # for i in range(min(5, result.logits.vocab_size)): # print(f" {result.logits.logits[i]:.4f}", end=" ") # print() # --- Attempt to use the wrapper --- try: print("Initializing RKLLMRuntime...") # Adjust library_path if librkllmrt.so is not in default search paths # e.g., library_path="./path/to/librkllmrt.so" rk_llm = RKLLMRuntime() print("Creating default parameters...") params = rk_llm.create_default_param() # --- Configure parameters --- # THIS IS CRITICAL: model_path must point to an actual .rkllm file # For this example to run, you need a model file. # Let's assume a dummy path for now, this will fail at init if not valid. model_file = "dummy_model.rkllm" if not os.path.exists(model_file): print(f"Warning: Model file '{model_file}' does not exist. Init will likely fail.") # Create a dummy file for the example to proceed further, though init will still fail # with a real library unless it's a valid model. with open(model_file, "w") as f: f.write("dummy content") params.model_path = model_file.encode('utf-8') params.max_context_len = 512 params.max_new_tokens = 128 params.top_k = 1 # Greedy params.temperature = 0.7 params.repeat_penalty = 1.1 # ... set other params as needed print(f"Initializing LLM with model: {params.model_path.decode()}...") # This will likely fail if dummy_model.rkllm is not a valid model recognized by the library try: rk_llm.init(params, my_python_callback) print("LLM Initialized.") except RuntimeError as e: print(f"Error during LLM initialization: {e}") print("This is expected if 'dummy_model.rkllm' is not a valid model.") print("Replace 'dummy_model.rkllm' with a real model path to test further.") exit() # --- Prepare input --- print("Preparing input...") rk_input = RKLLMInput() rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT prompt_text = "Translate the following English text to French: 'Hello, world!'" c_prompt = prompt_text.encode('utf-8') rk_input._union_data.prompt_input = c_prompt # Accessing union member directly # --- Prepare inference parameters --- print("Preparing inference parameters...") infer_params = RKLLMInferParam() infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE infer_params.keep_history = 1 # True # infer_params.lora_params = None # or set up RKLLMLoraParam if using LoRA # infer_params.prompt_cache_params = None # or set up RKLLMPromptCacheParam # --- Run inference --- print(f"Running inference with prompt: '{prompt_text}'") results_buffer.clear() try: rk_llm.run(rk_input, infer_params) # Userdata is None by default print("\n--- Full Response ---") print("".join(results_buffer)) print("---------------------\n") except RuntimeError as e: print(f"Error during LLM run: {e}") # --- Example: Set chat template (if model supports it) --- # print("Setting chat template...") # try: # rk_llm.set_chat_template("You are a helpful assistant.", ": ", ": ") # print("Chat template set.") # except RuntimeError as e: # print(f"Error setting chat template: {e}") # --- Example: Clear KV Cache --- # print("Clearing KV cache (keeping system prompt if any)...") # try: # rk_llm.clear_kv_cache(keep_system_prompt=True) # print("KV cache cleared.") # except RuntimeError as e: # print(f"Error clearing KV cache: {e}") except OSError as e: print(f"OSError: {e}. Could not load the RKLLM library.") print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.") except Exception as e: print(f"An unexpected error occurred: {e}") finally: if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value: print("Destroying LLM instance...") rk_llm.destroy() print("LLM instance destroyed.") if os.path.exists(model_file) and model_file == "dummy_model.rkllm": os.remove(model_file) # Clean up dummy file print("Example finished.")