|
import ctypes |
|
import enum |
|
import os |
|
|
|
|
|
CPU0 = (1 << 0) |
|
CPU1 = (1 << 1) |
|
CPU2 = (1 << 2) |
|
CPU3 = (1 << 3) |
|
CPU4 = (1 << 4) |
|
CPU5 = (1 << 5) |
|
CPU6 = (1 << 6) |
|
CPU7 = (1 << 7) |
|
|
|
|
|
class LLMCallState(enum.IntEnum): |
|
RKLLM_RUN_NORMAL = 0 |
|
RKLLM_RUN_WAITING = 1 |
|
RKLLM_RUN_FINISH = 2 |
|
RKLLM_RUN_ERROR = 3 |
|
|
|
class RKLLMInputType(enum.IntEnum): |
|
RKLLM_INPUT_PROMPT = 0 |
|
RKLLM_INPUT_TOKEN = 1 |
|
RKLLM_INPUT_EMBED = 2 |
|
RKLLM_INPUT_MULTIMODAL = 3 |
|
|
|
class RKLLMInferMode(enum.IntEnum): |
|
RKLLM_INFER_GENERATE = 0 |
|
RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1 |
|
RKLLM_INFER_GET_LOGITS = 2 |
|
|
|
|
|
class RKLLMExtendParam(ctypes.Structure): |
|
base_domain_id: ctypes.c_int32 |
|
embed_flash: ctypes.c_int8 |
|
enabled_cpus_num: ctypes.c_int8 |
|
enabled_cpus_mask: ctypes.c_uint32 |
|
n_batch: ctypes.c_uint8 |
|
use_cross_attn: ctypes.c_int8 |
|
reserved: ctypes.c_uint8 * 104 |
|
|
|
_fields_ = [ |
|
("base_domain_id", ctypes.c_int32), |
|
("embed_flash", ctypes.c_int8), |
|
("enabled_cpus_num", ctypes.c_int8), |
|
("enabled_cpus_mask", ctypes.c_uint32), |
|
("n_batch", ctypes.c_uint8), |
|
("use_cross_attn", ctypes.c_int8), |
|
("reserved", ctypes.c_uint8 * 104) |
|
] |
|
|
|
class RKLLMParam(ctypes.Structure): |
|
model_path: ctypes.c_char_p |
|
max_context_len: ctypes.c_int32 |
|
max_new_tokens: ctypes.c_int32 |
|
top_k: ctypes.c_int32 |
|
n_keep: ctypes.c_int32 |
|
top_p: ctypes.c_float |
|
temperature: ctypes.c_float |
|
repeat_penalty: ctypes.c_float |
|
frequency_penalty: ctypes.c_float |
|
presence_penalty: ctypes.c_float |
|
mirostat: ctypes.c_int32 |
|
mirostat_tau: ctypes.c_float |
|
mirostat_eta: ctypes.c_float |
|
skip_special_token: ctypes.c_bool |
|
is_async: ctypes.c_bool |
|
img_start: ctypes.c_char_p |
|
img_end: ctypes.c_char_p |
|
img_content: ctypes.c_char_p |
|
extend_param: RKLLMExtendParam |
|
|
|
_fields_ = [ |
|
("model_path", ctypes.c_char_p), |
|
("max_context_len", ctypes.c_int32), |
|
("max_new_tokens", ctypes.c_int32), |
|
("top_k", ctypes.c_int32), |
|
("n_keep", ctypes.c_int32), |
|
("top_p", ctypes.c_float), |
|
("temperature", ctypes.c_float), |
|
("repeat_penalty", ctypes.c_float), |
|
("frequency_penalty", ctypes.c_float), |
|
("presence_penalty", ctypes.c_float), |
|
("mirostat", ctypes.c_int32), |
|
("mirostat_tau", ctypes.c_float), |
|
("mirostat_eta", ctypes.c_float), |
|
("skip_special_token", ctypes.c_bool), |
|
("is_async", ctypes.c_bool), |
|
("img_start", ctypes.c_char_p), |
|
("img_end", ctypes.c_char_p), |
|
("img_content", ctypes.c_char_p), |
|
("extend_param", RKLLMExtendParam) |
|
] |
|
|
|
class RKLLMLoraAdapter(ctypes.Structure): |
|
lora_adapter_path: ctypes.c_char_p |
|
lora_adapter_name: ctypes.c_char_p |
|
scale: ctypes.c_float |
|
|
|
_fields_ = [ |
|
("lora_adapter_path", ctypes.c_char_p), |
|
("lora_adapter_name", ctypes.c_char_p), |
|
("scale", ctypes.c_float) |
|
] |
|
|
|
class RKLLMEmbedInput(ctypes.Structure): |
|
embed: ctypes.POINTER(ctypes.c_float) |
|
n_tokens: ctypes.c_size_t |
|
|
|
_fields_ = [ |
|
("embed", ctypes.POINTER(ctypes.c_float)), |
|
("n_tokens", ctypes.c_size_t) |
|
] |
|
|
|
class RKLLMTokenInput(ctypes.Structure): |
|
input_ids: ctypes.POINTER(ctypes.c_int32) |
|
n_tokens: ctypes.c_size_t |
|
|
|
_fields_ = [ |
|
("input_ids", ctypes.POINTER(ctypes.c_int32)), |
|
("n_tokens", ctypes.c_size_t) |
|
] |
|
|
|
class RKLLMMultiModelInput(ctypes.Structure): |
|
prompt: ctypes.c_char_p |
|
image_embed: ctypes.POINTER(ctypes.c_float) |
|
n_image_tokens: ctypes.c_size_t |
|
n_image: ctypes.c_size_t |
|
image_width: ctypes.c_size_t |
|
image_height: ctypes.c_size_t |
|
|
|
_fields_ = [ |
|
("prompt", ctypes.c_char_p), |
|
("image_embed", ctypes.POINTER(ctypes.c_float)), |
|
("n_image_tokens", ctypes.c_size_t), |
|
("n_image", ctypes.c_size_t), |
|
("image_width", ctypes.c_size_t), |
|
("image_height", ctypes.c_size_t) |
|
] |
|
|
|
class RKLLMCrossAttnParam(ctypes.Structure): |
|
""" |
|
交叉注意力参数结构体 |
|
|
|
该结构体用于在解码器中执行交叉注意力时使用。 |
|
它提供编码器输出(键/值缓存)、位置索引和注意力掩码。 |
|
|
|
- encoder_k_cache必须存储在连续内存中,布局为: |
|
[num_layers][num_tokens][num_kv_heads][head_dim] |
|
- encoder_v_cache必须存储在连续内存中,布局为: |
|
[num_layers][num_kv_heads][head_dim][num_tokens] |
|
""" |
|
encoder_k_cache: ctypes.POINTER(ctypes.c_float) |
|
encoder_v_cache: ctypes.POINTER(ctypes.c_float) |
|
encoder_mask: ctypes.POINTER(ctypes.c_float) |
|
encoder_pos: ctypes.POINTER(ctypes.c_int32) |
|
num_tokens: ctypes.c_int |
|
|
|
_fields_ = [ |
|
("encoder_k_cache", ctypes.POINTER(ctypes.c_float)), |
|
("encoder_v_cache", ctypes.POINTER(ctypes.c_float)), |
|
("encoder_mask", ctypes.POINTER(ctypes.c_float)), |
|
("encoder_pos", ctypes.POINTER(ctypes.c_int32)), |
|
("num_tokens", ctypes.c_int) |
|
] |
|
|
|
class RKLLMPerfStat(ctypes.Structure): |
|
""" |
|
性能统计结构体 |
|
|
|
用于保存预填充和生成阶段的性能统计信息。 |
|
""" |
|
prefill_time_ms: ctypes.c_float |
|
prefill_tokens: ctypes.c_int |
|
generate_time_ms: ctypes.c_float |
|
generate_tokens: ctypes.c_int |
|
memory_usage_mb: ctypes.c_float |
|
|
|
_fields_ = [ |
|
("prefill_time_ms", ctypes.c_float), |
|
("prefill_tokens", ctypes.c_int), |
|
("generate_time_ms", ctypes.c_float), |
|
("generate_tokens", ctypes.c_int), |
|
("memory_usage_mb", ctypes.c_float) |
|
] |
|
|
|
class _RKLLMInputUnion(ctypes.Union): |
|
prompt_input: ctypes.c_char_p |
|
embed_input: RKLLMEmbedInput |
|
token_input: RKLLMTokenInput |
|
multimodal_input: RKLLMMultiModelInput |
|
|
|
_fields_ = [ |
|
("prompt_input", ctypes.c_char_p), |
|
("embed_input", RKLLMEmbedInput), |
|
("token_input", RKLLMTokenInput), |
|
("multimodal_input", RKLLMMultiModelInput) |
|
] |
|
|
|
class RKLLMInput(ctypes.Structure): |
|
""" |
|
LLM输入结构体 |
|
|
|
通过联合体表示不同类型的LLM输入。 |
|
""" |
|
role: ctypes.c_char_p |
|
enable_thinking: ctypes.c_bool |
|
input_type: ctypes.c_int |
|
_union_data: _RKLLMInputUnion |
|
|
|
_fields_ = [ |
|
("role", ctypes.c_char_p), |
|
("enable_thinking", ctypes.c_bool), |
|
("input_type", ctypes.c_int), |
|
("_union_data", _RKLLMInputUnion) |
|
] |
|
|
|
@property |
|
def prompt_input(self) -> bytes: |
|
if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT: |
|
return self._union_data.prompt_input |
|
raise AttributeError("Not a prompt input") |
|
@prompt_input.setter |
|
def prompt_input(self, value: bytes): |
|
if self.input_type == RKLLMInputType.RKLLM_INPUT_PROMPT: |
|
self._union_data.prompt_input = value |
|
else: |
|
raise AttributeError("Not a prompt input") |
|
@property |
|
def embed_input(self) -> RKLLMEmbedInput: |
|
if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED: |
|
return self._union_data.embed_input |
|
raise AttributeError("Not an embed input") |
|
@embed_input.setter |
|
def embed_input(self, value: RKLLMEmbedInput): |
|
if self.input_type == RKLLMInputType.RKLLM_INPUT_EMBED: |
|
self._union_data.embed_input = value |
|
else: |
|
raise AttributeError("Not an embed input") |
|
|
|
@property |
|
def token_input(self) -> RKLLMTokenInput: |
|
if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN: |
|
return self._union_data.token_input |
|
raise AttributeError("Not a token input") |
|
@token_input.setter |
|
def token_input(self, value: RKLLMTokenInput): |
|
if self.input_type == RKLLMInputType.RKLLM_INPUT_TOKEN: |
|
self._union_data.token_input = value |
|
else: |
|
raise AttributeError("Not a token input") |
|
|
|
@property |
|
def multimodal_input(self) -> RKLLMMultiModelInput: |
|
if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL: |
|
return self._union_data.multimodal_input |
|
raise AttributeError("Not a multimodal input") |
|
@multimodal_input.setter |
|
def multimodal_input(self, value: RKLLMMultiModelInput): |
|
if self.input_type == RKLLMInputType.RKLLM_INPUT_MULTIMODAL: |
|
self._union_data.multimodal_input = value |
|
else: |
|
raise AttributeError("Not a multimodal input") |
|
|
|
class RKLLMLoraParam(ctypes.Structure): |
|
lora_adapter_name: ctypes.c_char_p |
|
|
|
_fields_ = [ |
|
("lora_adapter_name", ctypes.c_char_p) |
|
] |
|
|
|
class RKLLMPromptCacheParam(ctypes.Structure): |
|
save_prompt_cache: ctypes.c_int |
|
prompt_cache_path: ctypes.c_char_p |
|
|
|
_fields_ = [ |
|
("save_prompt_cache", ctypes.c_int), |
|
("prompt_cache_path", ctypes.c_char_p) |
|
] |
|
|
|
class RKLLMInferParam(ctypes.Structure): |
|
mode: ctypes.c_int |
|
lora_params: ctypes.POINTER(RKLLMLoraParam) |
|
prompt_cache_params: ctypes.POINTER(RKLLMPromptCacheParam) |
|
keep_history: ctypes.c_int |
|
|
|
_fields_ = [ |
|
("mode", ctypes.c_int), |
|
("lora_params", ctypes.POINTER(RKLLMLoraParam)), |
|
("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)), |
|
("keep_history", ctypes.c_int) |
|
] |
|
|
|
class RKLLMResultLastHiddenLayer(ctypes.Structure): |
|
hidden_states: ctypes.POINTER(ctypes.c_float) |
|
embd_size: ctypes.c_int |
|
num_tokens: ctypes.c_int |
|
|
|
_fields_ = [ |
|
("hidden_states", ctypes.POINTER(ctypes.c_float)), |
|
("embd_size", ctypes.c_int), |
|
("num_tokens", ctypes.c_int) |
|
] |
|
|
|
class RKLLMResultLogits(ctypes.Structure): |
|
logits: ctypes.POINTER(ctypes.c_float) |
|
vocab_size: ctypes.c_int |
|
num_tokens: ctypes.c_int |
|
|
|
_fields_ = [ |
|
("logits", ctypes.POINTER(ctypes.c_float)), |
|
("vocab_size", ctypes.c_int), |
|
("num_tokens", ctypes.c_int) |
|
] |
|
|
|
class RKLLMResult(ctypes.Structure): |
|
""" |
|
LLM推理结果结构体 |
|
|
|
表示LLM推理的结果,包含生成的文本、token ID、隐藏层状态、logits和性能统计。 |
|
""" |
|
text: ctypes.c_char_p |
|
token_id: ctypes.c_int32 |
|
last_hidden_layer: RKLLMResultLastHiddenLayer |
|
logits: RKLLMResultLogits |
|
perf: RKLLMPerfStat |
|
|
|
_fields_ = [ |
|
("text", ctypes.c_char_p), |
|
("token_id", ctypes.c_int32), |
|
("last_hidden_layer", RKLLMResultLastHiddenLayer), |
|
("logits", RKLLMResultLogits), |
|
("perf", RKLLMPerfStat) |
|
] |
|
|
|
|
|
LLMHandle = ctypes.c_void_p |
|
|
|
|
|
LLMResultCallback = ctypes.CFUNCTYPE( |
|
ctypes.c_int, |
|
ctypes.POINTER(RKLLMResult), |
|
ctypes.c_void_p, |
|
ctypes.c_int |
|
) |
|
""" |
|
回调函数类型定义 |
|
|
|
用于处理LLM结果的回调函数。 |
|
|
|
参数: |
|
- result: 指向LLM结果的指针 |
|
- userdata: 回调的用户数据指针 |
|
- state: LLM调用状态(例如:完成、错误) |
|
|
|
返回值: |
|
- 0: 正常继续推理 |
|
- 1: 暂停推理。如果用户想要修改或干预结果(例如编辑输出、注入新提示), |
|
返回1以暂停当前推理。稍后,使用更新的内容调用rkllm_run来恢复推理。 |
|
""" |
|
|
|
class RKLLMRuntime: |
|
def __init__(self, library_path="./librkllmrt.so"): |
|
try: |
|
self.lib = ctypes.CDLL(library_path) |
|
except OSError as e: |
|
raise OSError(f"Failed to load RKLLM library from {library_path}. " |
|
f"Ensure it's in your LD_LIBRARY_PATH or provide the full path. Error: {e}") |
|
self._setup_functions() |
|
self.llm_handle = LLMHandle() |
|
self._c_callback = None |
|
|
|
def _setup_functions(self): |
|
|
|
self.lib.rkllm_createDefaultParam.restype = RKLLMParam |
|
self.lib.rkllm_createDefaultParam.argtypes = [] |
|
|
|
|
|
self.lib.rkllm_init.restype = ctypes.c_int |
|
self.lib.rkllm_init.argtypes = [ |
|
ctypes.POINTER(LLMHandle), |
|
ctypes.POINTER(RKLLMParam), |
|
LLMResultCallback |
|
] |
|
|
|
|
|
self.lib.rkllm_load_lora.restype = ctypes.c_int |
|
self.lib.rkllm_load_lora.argtypes = [LLMHandle, ctypes.POINTER(RKLLMLoraAdapter)] |
|
|
|
|
|
self.lib.rkllm_load_prompt_cache.restype = ctypes.c_int |
|
self.lib.rkllm_load_prompt_cache.argtypes = [LLMHandle, ctypes.c_char_p] |
|
|
|
|
|
self.lib.rkllm_release_prompt_cache.restype = ctypes.c_int |
|
self.lib.rkllm_release_prompt_cache.argtypes = [LLMHandle] |
|
|
|
|
|
self.lib.rkllm_destroy.restype = ctypes.c_int |
|
self.lib.rkllm_destroy.argtypes = [LLMHandle] |
|
|
|
|
|
self.lib.rkllm_run.restype = ctypes.c_int |
|
self.lib.rkllm_run.argtypes = [ |
|
LLMHandle, |
|
ctypes.POINTER(RKLLMInput), |
|
ctypes.POINTER(RKLLMInferParam), |
|
ctypes.c_void_p |
|
] |
|
|
|
|
|
|
|
self.lib.rkllm_run_async.restype = ctypes.c_int |
|
self.lib.rkllm_run_async.argtypes = [ |
|
LLMHandle, |
|
ctypes.POINTER(RKLLMInput), |
|
ctypes.POINTER(RKLLMInferParam), |
|
ctypes.c_void_p |
|
] |
|
|
|
|
|
self.lib.rkllm_abort.restype = ctypes.c_int |
|
self.lib.rkllm_abort.argtypes = [LLMHandle] |
|
|
|
|
|
self.lib.rkllm_is_running.restype = ctypes.c_int |
|
self.lib.rkllm_is_running.argtypes = [LLMHandle] |
|
|
|
|
|
self.lib.rkllm_clear_kv_cache.restype = ctypes.c_int |
|
self.lib.rkllm_clear_kv_cache.argtypes = [ |
|
LLMHandle, |
|
ctypes.c_int, |
|
ctypes.POINTER(ctypes.c_int), |
|
ctypes.POINTER(ctypes.c_int) |
|
] |
|
|
|
|
|
self.lib.rkllm_get_kv_cache_size.restype = ctypes.c_int |
|
self.lib.rkllm_get_kv_cache_size.argtypes = [LLMHandle, ctypes.POINTER(ctypes.c_int)] |
|
|
|
|
|
self.lib.rkllm_set_chat_template.restype = ctypes.c_int |
|
self.lib.rkllm_set_chat_template.argtypes = [ |
|
LLMHandle, |
|
ctypes.c_char_p, |
|
ctypes.c_char_p, |
|
ctypes.c_char_p |
|
] |
|
|
|
|
|
self.lib.rkllm_set_function_tools.restype = ctypes.c_int |
|
self.lib.rkllm_set_function_tools.argtypes = [ |
|
LLMHandle, |
|
ctypes.c_char_p, |
|
ctypes.c_char_p, |
|
ctypes.c_char_p |
|
] |
|
|
|
|
|
self.lib.rkllm_set_cross_attn_params.restype = ctypes.c_int |
|
self.lib.rkllm_set_cross_attn_params.argtypes = [LLMHandle, ctypes.POINTER(RKLLMCrossAttnParam)] |
|
|
|
def create_default_param(self) -> RKLLMParam: |
|
"""Creates a default RKLLMParam structure.""" |
|
return self.lib.rkllm_createDefaultParam() |
|
|
|
def init(self, param: RKLLMParam, callback_func) -> int: |
|
""" |
|
Initializes the LLM. |
|
:param param: RKLLMParam structure. |
|
:param callback_func: A Python function that matches the signature: |
|
def my_callback(result_ptr, userdata_ptr, state_enum): |
|
result = result_ptr.contents # RKLLMResult |
|
# Process result |
|
# userdata can be retrieved if passed during run, or ignored |
|
# state = LLMCallState(state_enum) |
|
:return: 0 for success, non-zero for failure. |
|
""" |
|
if not callable(callback_func): |
|
raise ValueError("callback_func must be a callable Python function.") |
|
|
|
|
|
self._c_callback = LLMResultCallback(callback_func) |
|
|
|
ret = self.lib.rkllm_init(ctypes.byref(self.llm_handle), ctypes.byref(param), self._c_callback) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_init failed with error code {ret}") |
|
return ret |
|
|
|
def load_lora(self, lora_adapter: RKLLMLoraAdapter) -> int: |
|
"""Loads a Lora adapter.""" |
|
ret = self.lib.rkllm_load_lora(self.llm_handle, ctypes.byref(lora_adapter)) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_load_lora failed with error code {ret}") |
|
return ret |
|
|
|
def load_prompt_cache(self, prompt_cache_path: str) -> int: |
|
"""Loads a prompt cache from a file.""" |
|
c_path = prompt_cache_path.encode('utf-8') |
|
ret = self.lib.rkllm_load_prompt_cache(self.llm_handle, c_path) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_load_prompt_cache failed for {prompt_cache_path} with error code {ret}") |
|
return ret |
|
|
|
def release_prompt_cache(self) -> int: |
|
"""Releases the prompt cache from memory.""" |
|
ret = self.lib.rkllm_release_prompt_cache(self.llm_handle) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_release_prompt_cache failed with error code {ret}") |
|
return ret |
|
|
|
def destroy(self) -> int: |
|
"""Destroys the LLM instance and releases resources.""" |
|
if self.llm_handle and self.llm_handle.value: |
|
ret = self.lib.rkllm_destroy(self.llm_handle) |
|
self.llm_handle = LLMHandle() |
|
if ret != 0: |
|
|
|
print(f"Warning: rkllm_destroy failed with error code {ret}") |
|
return ret |
|
return 0 |
|
|
|
def run(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int: |
|
"""Runs an LLM inference task synchronously.""" |
|
|
|
|
|
if userdata is not None: |
|
|
|
self._userdata_ref = userdata |
|
c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p) |
|
else: |
|
c_userdata = None |
|
ret = self.lib.rkllm_run(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_run failed with error code {ret}") |
|
return ret |
|
|
|
def run_async(self, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata=None) -> int: |
|
"""Runs an LLM inference task asynchronously.""" |
|
if userdata is not None: |
|
|
|
self._userdata_ref = userdata |
|
c_userdata = ctypes.cast(ctypes.pointer(ctypes.py_object(userdata)), ctypes.c_void_p) |
|
else: |
|
c_userdata = None |
|
ret = self.lib.rkllm_run_async(self.llm_handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), c_userdata) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_run_async failed with error code {ret}") |
|
return ret |
|
|
|
def abort(self) -> int: |
|
"""Aborts an ongoing LLM task.""" |
|
ret = self.lib.rkllm_abort(self.llm_handle) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_abort failed with error code {ret}") |
|
return ret |
|
|
|
def is_running(self) -> bool: |
|
"""Checks if an LLM task is currently running. Returns True if running.""" |
|
|
|
|
|
return self.lib.rkllm_is_running(self.llm_handle) == 0 |
|
|
|
def clear_kv_cache(self, keep_system_prompt: bool, start_pos: list = None, end_pos: list = None) -> int: |
|
""" |
|
清除键值缓存 |
|
|
|
此函数用于清除部分或全部KV缓存。 |
|
|
|
参数: |
|
- keep_system_prompt: 是否在缓存中保留系统提示(True保留,False清除) |
|
如果提供了特定范围[start_pos, end_pos),此标志将被忽略 |
|
- start_pos: 要清除的KV缓存范围的起始位置数组(包含),每个批次一个 |
|
- end_pos: 要清除的KV缓存范围的结束位置数组(不包含),每个批次一个 |
|
如果start_pos和end_pos都设置为None,将清除整个缓存,keep_system_prompt将生效 |
|
如果start_pos[i] < end_pos[i],只有指定的范围会被清除,keep_system_prompt将被忽略 |
|
|
|
注意:start_pos或end_pos只有在keep_history == 0且生成已通过在回调中返回1暂停时才有效 |
|
|
|
返回:0表示缓存清除成功,非零表示失败 |
|
""" |
|
|
|
c_start_pos = None |
|
c_end_pos = None |
|
|
|
if start_pos is not None and end_pos is not None: |
|
if len(start_pos) != len(end_pos): |
|
raise ValueError("start_pos和end_pos数组长度必须相同") |
|
|
|
|
|
c_start_pos = (ctypes.c_int * len(start_pos))(*start_pos) |
|
c_end_pos = (ctypes.c_int * len(end_pos))(*end_pos) |
|
|
|
ret = self.lib.rkllm_clear_kv_cache( |
|
self.llm_handle, |
|
ctypes.c_int(1 if keep_system_prompt else 0), |
|
c_start_pos, |
|
c_end_pos |
|
) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_clear_kv_cache失败,错误代码:{ret}") |
|
return ret |
|
|
|
def set_chat_template(self, system_prompt: str, prompt_prefix: str, prompt_postfix: str) -> int: |
|
"""Sets the chat template for the LLM.""" |
|
c_system = system_prompt.encode('utf-8') if system_prompt else b"" |
|
c_prefix = prompt_prefix.encode('utf-8') if prompt_prefix else b"" |
|
c_postfix = prompt_postfix.encode('utf-8') if prompt_postfix else b"" |
|
|
|
ret = self.lib.rkllm_set_chat_template(self.llm_handle, c_system, c_prefix, c_postfix) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_set_chat_template failed with error code {ret}") |
|
return ret |
|
|
|
def get_kv_cache_size(self, n_batch: int) -> list: |
|
""" |
|
获取给定LLM句柄的键值缓存当前大小 |
|
|
|
此函数返回当前存储在模型KV缓存中的位置总数。 |
|
|
|
参数: |
|
- n_batch: 批次数量,用于确定返回数组的大小 |
|
|
|
返回: |
|
- list: 每个批次的缓存大小列表 |
|
""" |
|
|
|
cache_sizes = (ctypes.c_int * n_batch)() |
|
|
|
ret = self.lib.rkllm_get_kv_cache_size(self.llm_handle, cache_sizes) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_get_kv_cache_size失败,错误代码:{ret}") |
|
|
|
|
|
return [cache_sizes[i] for i in range(n_batch)] |
|
|
|
def set_function_tools(self, system_prompt: str, tools: str, tool_response_str: str) -> int: |
|
""" |
|
为LLM设置函数调用配置,包括系统提示、工具定义和工具响应token |
|
|
|
参数: |
|
- system_prompt: 定义语言模型上下文或行为的系统提示 |
|
- tools: JSON格式的字符串,定义可用的函数,包括它们的名称、描述和参数 |
|
- tool_response_str: 用于识别对话中函数调用结果的唯一标签。它作为标记标签, |
|
允许分词器将工具输出与正常对话轮次分开识别 |
|
|
|
返回:0表示配置设置成功,非零表示错误 |
|
""" |
|
c_system = system_prompt.encode('utf-8') if system_prompt else b"" |
|
c_tools = tools.encode('utf-8') if tools else b"" |
|
c_tool_response = tool_response_str.encode('utf-8') if tool_response_str else b"" |
|
|
|
ret = self.lib.rkllm_set_function_tools(self.llm_handle, c_system, c_tools, c_tool_response) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_set_function_tools失败,错误代码:{ret}") |
|
return ret |
|
|
|
def set_cross_attn_params(self, cross_attn_params: RKLLMCrossAttnParam) -> int: |
|
""" |
|
为LLM解码器设置交叉注意力参数 |
|
|
|
参数: |
|
- cross_attn_params: 包含用于交叉注意力的编码器相关输入数据的结构体 |
|
(详见RKLLMCrossAttnParam说明) |
|
|
|
返回:0表示参数设置成功,非零表示错误 |
|
""" |
|
ret = self.lib.rkllm_set_cross_attn_params(self.llm_handle, ctypes.byref(cross_attn_params)) |
|
if ret != 0: |
|
raise RuntimeError(f"rkllm_set_cross_attn_params失败,错误代码:{ret}") |
|
return ret |
|
|
|
def __enter__(self): |
|
return self |
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
self.destroy() |
|
|
|
def __del__(self): |
|
self.destroy() |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
|
results_buffer = [] |
|
|
|
def my_python_callback(result_ptr, userdata_ptr, state_enum): |
|
""" |
|
回调函数,由C库调用来处理LLM结果 |
|
|
|
参数: |
|
- result_ptr: 指向LLM结果的指针 |
|
- userdata_ptr: 用户数据指针 |
|
- state_enum: LLM调用状态枚举值 |
|
|
|
返回: |
|
- 0: 继续推理 |
|
- 1: 暂停推理 |
|
""" |
|
global results_buffer |
|
state = LLMCallState(state_enum) |
|
result = result_ptr.contents |
|
|
|
current_text = "" |
|
if result.text: |
|
current_text = result.text.decode('utf-8', errors='ignore') |
|
|
|
print(f"回调: State={state.name}, TokenID={result.token_id}, Text='{current_text}'") |
|
|
|
|
|
if result.perf.prefill_tokens > 0 or result.perf.generate_tokens > 0: |
|
print(f" 性能统计: 预填充={result.perf.prefill_tokens}tokens/{result.perf.prefill_time_ms:.1f}ms, " |
|
f"生成={result.perf.generate_tokens}tokens/{result.perf.generate_time_ms:.1f}ms, " |
|
f"内存={result.perf.memory_usage_mb:.1f}MB") |
|
|
|
results_buffer.append(current_text) |
|
|
|
if state == LLMCallState.RKLLM_RUN_FINISH: |
|
print("推理完成。") |
|
elif state == LLMCallState.RKLLM_RUN_ERROR: |
|
print("推理错误。") |
|
|
|
|
|
return 0 |
|
|
|
|
|
try: |
|
print("Initializing RKLLMRuntime...") |
|
|
|
|
|
rk_llm = RKLLMRuntime() |
|
|
|
print("Creating default parameters...") |
|
params = rk_llm.create_default_param() |
|
|
|
|
|
|
|
|
|
|
|
model_file = "language_model.rkllm" |
|
if not os.path.exists(model_file): |
|
raise FileNotFoundError(f"Model file '{model_file}' does not exist.") |
|
|
|
params.model_path = model_file.encode('utf-8') |
|
params.max_context_len = 512 |
|
params.max_new_tokens = 128 |
|
|
|
params.temperature = 0.7 |
|
params.repeat_penalty = 1.1 |
|
|
|
|
|
print(f"Initializing LLM with model: {params.model_path.decode()}...") |
|
|
|
try: |
|
rk_llm.init(params, my_python_callback) |
|
print("LLM Initialized.") |
|
except RuntimeError as e: |
|
print(f"Error during LLM initialization: {e}") |
|
print("This is expected if 'dummy_model.rkllm' is not a valid model.") |
|
print("Replace 'dummy_model.rkllm' with a real model path to test further.") |
|
exit() |
|
|
|
|
|
|
|
print("准备输入...") |
|
rk_input = RKLLMInput() |
|
rk_input.role = b"user" |
|
rk_input.enable_thinking = False |
|
rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT |
|
|
|
prompt_text = "将以下英文文本翻译成中文:'Hello, world!'" |
|
c_prompt = prompt_text.encode('utf-8') |
|
rk_input._union_data.prompt_input = c_prompt |
|
|
|
|
|
print("Preparing inference parameters...") |
|
infer_params = RKLLMInferParam() |
|
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE |
|
infer_params.keep_history = 1 |
|
|
|
|
|
|
|
|
|
print(f"Running inference with prompt: '{prompt_text}'") |
|
results_buffer.clear() |
|
try: |
|
rk_llm.run(rk_input, infer_params) |
|
print("\n--- Full Response ---") |
|
print("".join(results_buffer)) |
|
print("---------------------\n") |
|
except RuntimeError as e: |
|
print(f"Error during LLM run: {e}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
except OSError as e: |
|
print(f"OSError: {e}. Could not load the RKLLM library.") |
|
print("Please ensure 'librkllmrt.so' is in your LD_LIBRARY_PATH or provide the full path.") |
|
except Exception as e: |
|
print(f"An unexpected error occurred: {e}") |
|
finally: |
|
if 'rk_llm' in locals() and rk_llm.llm_handle and rk_llm.llm_handle.value: |
|
print("Destroying LLM instance...") |
|
rk_llm.destroy() |
|
print("LLM instance destroyed.") |
|
|
|
print("Example finished.") |