|
#include "common.h"

#include "llama.h"

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <shellapi.h>
#endif
|
|
|
// Print the help/usage text for the tokenize tool to standard output.
// `argv0` is the program name as invoked, shown in the usage line.
static void print_usage_information(const char * argv0) {
    printf("usage: %s [options]\n\n", argv0);
    fputs(
        "The tokenize program tokenizes a prompt using a given model,\n"
        "and prints the resulting tokens to standard output.\n\n"
        "It needs a model file, a prompt, and optionally other flags\n"
        "to control the behavior of the tokenizer.\n\n"
        "    The possible options are:\n"
        "\n"
        "    -h, --help                           print this help and exit\n"
        "    -m MODEL_PATH, --model MODEL_PATH    path to model.\n"
        "    --ids                                if given, only print numerical token IDs, and not token strings.\n"
        "                                         The output format looks like [1, 2, 3], i.e. parseable by Python.\n"
        "    -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n"
        "    -p PROMPT, --prompt PROMPT           read prompt from the argument.\n"
        "    --stdin                              read prompt from standard input.\n"
        "    --no-bos                             do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n"
        "    --no-escape                          do not escape input (such as \\n, \\t, etc.).\n"
        "    --no-parse-special                   do not parse control tokens.\n"
        "    --log-disable                        disable logs. Makes stderr quiet when loading the model.\n"
        "    --show-count                         print the total number of tokens.\n",
        stdout);
}
|
|
|
// No-op log callback: swallows all llama.cpp log output.
// Installed via llama_log_set() when --log-disable is given.
static void llama_log_callback_null(ggml_log_level /* level */, const char * /* text */, void * /* user_data */) {
    // intentionally empty
}
|
|
|
// Read the entire contents of the file at `filepath` (opened in binary mode)
// into a string. On success `success` is set to true and the contents are
// returned; on any failure an error is printed to stderr, `success` is set
// to false, and an empty string is returned.
static std::string read_prompt_from_file(const char * filepath, bool & success) {
    success = false;

    std::ifstream file(filepath, std::ios::binary);
    if (!file) {
        fprintf(stderr, "%s: could not open file '%s' for reading: %s\n", __func__, filepath, strerror(errno));
        return std::string();
    }

    // Pull the whole stream through the underlying buffer in one shot.
    std::stringstream contents;
    contents << file.rdbuf();
    if (file.fail()) {
        fprintf(stderr, "%s: could not read the entire file '%s': %s\n", __func__, filepath, strerror(errno));
        return std::string();
    }

    success = true;
    return contents.str();
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Collect the process arguments as a vector of UTF-8 strings.
//
// On Windows the narrow argv passed to main() may have been mangled through
// the ANSI code page, so the command line is re-fetched in wide characters
// with GetCommandLineW() and converted to UTF-8 here. On other platforms the
// raw argv is copied through unchanged (assumed to already be UTF-8).
static std::vector<std::string> ingest_args(int raw_argc, char ** raw_argv) {
    std::vector<std::string> argv;

#if defined(_WIN32)
    int argc;
    const LPWSTR cmdline_wargv = GetCommandLineW();
    LPWSTR * wargv = CommandLineToArgvW(cmdline_wargv, &argc);
    // CommandLineToArgvW returns NULL on failure; previously this was
    // dereferenced unconditionally.
    GGML_ASSERT(wargv != NULL);

    // the narrow arguments are not used on Windows
    (void) raw_argc;
    (void) raw_argv;

    for (int i = 0; i < argc; ++i) {
        // first call computes the required UTF-8 byte count, second call converts
        const int wlen = (int) wcslen(wargv[i]);
        int length_needed = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wlen, 0, 0, NULL, NULL);
        char * output_buf = (char *) calloc(length_needed+1, sizeof(char));
        GGML_ASSERT(output_buf);

        WideCharToMultiByte(CP_UTF8, 0, wargv[i], wlen, output_buf, length_needed, NULL, NULL);
        output_buf[length_needed] = '\0';

        argv.push_back(output_buf);
        free(output_buf);
    }

    LocalFree((HLOCAL) wargv);
#else
    int argc = raw_argc;
    for (int i = 0; i < argc; ++i) {
        argv.push_back(raw_argv[i]);
    }
#endif

    GGML_ASSERT((unsigned int) argc == argv.size());

    return argv;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Print a NUL-terminated UTF-8 string to stdout.
//
// On Windows, a real console needs the text converted to UTF-16 and written
// with WriteConsoleW() to display non-ASCII characters correctly; when stdout
// is not a console (e.g. redirected to a file or pipe), the raw UTF-8 bytes
// are passed through unchanged. If `str` is not valid UTF-8, its bytes are
// printed as hex ("<xx xx ...>") and `invalid_utf8` is set to true; otherwise
// `invalid_utf8` is set to false. On non-Windows platforms the string is
// printed as-is (terminals are assumed to consume UTF-8 natively).
static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
    invalid_utf8 = false;

#if defined(_WIN32)
    HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD dwMode = 0;

    // GetConsoleMode() fails when stdout is redirected (not an actual
    // console); in that case just emit the raw bytes.
    if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
        printf("%s", str);
        return;
    }

    // MultiByteToWideChar() uses a return value of 0 to signal errors, which
    // is also what it would return for an empty input — so handle the empty
    // string before calling it.
    if (*str == 0) {
        return;
    }
    int length_needed = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, strlen(str), NULL, 0);
    if (length_needed == 0) {
        DWORD err = GetLastError();
        if (err == ERROR_NO_UNICODE_TRANSLATION) {
            // not valid UTF-8: dump the bytes as space-separated hex, e.g. "<e0 80>"
            invalid_utf8 = true;
            int len = strlen(str);
            printf("<");
            for (int i = 0; i < len; ++i) {
                if (i > 0) {
                    printf(" ");
                }
                printf("%02x", (uint8_t) str[i]);
            }
            printf(">");
            return;
        }
        GGML_ABORT("MultiByteToWideChar() failed in an unexpected way.");
    }

    LPWSTR wstr = (LPWSTR) calloc(length_needed+1, sizeof(*wstr));
    GGML_ASSERT(wstr);

    // second call performs the actual UTF-8 -> UTF-16 conversion
    MultiByteToWideChar(CP_UTF8, 0, str, strlen(str), wstr, length_needed);
    WriteConsoleW(hConsole, wstr, length_needed, NULL, NULL);

    free(wstr);
#else
    printf("%s", str);
#endif
}
|
|
|
int main(int raw_argc, char ** raw_argv) { |
|
const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv); |
|
const int argc = argv.size(); |
|
|
|
if (argc <= 1) { |
|
print_usage_information(argv[0].c_str()); |
|
return 1; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
bool printing_ids = false; |
|
bool no_bos = false; |
|
bool no_escape = false; |
|
bool no_parse_special = false; |
|
bool disable_logging = false; |
|
bool show_token_count = false; |
|
const char * model_path = NULL; |
|
const char * prompt_path = NULL; |
|
const char * prompt_arg = NULL; |
|
|
|
|
|
|
|
bool model_path_set = false; |
|
bool prompt_path_set = false; |
|
bool prompt_set = false; |
|
bool stdin_set = false; |
|
|
|
int iarg = 1; |
|
for (; iarg < argc; ++iarg) { |
|
std::string arg{argv[iarg]}; |
|
if (arg == "-h" || arg == "--help") { |
|
print_usage_information(argv[0].c_str()); |
|
return 0; |
|
} |
|
else if (arg == "--ids") { |
|
printing_ids = true; |
|
} |
|
else if (arg == "-m" || arg == "--model") { |
|
if (model_path_set) { |
|
fprintf(stderr, "Error: -m or --model specified multiple times.\n"); |
|
return 1; |
|
} |
|
model_path = argv[++iarg].c_str(); |
|
model_path_set = true; |
|
} |
|
else if (arg == "--no-bos") { |
|
no_bos = true; |
|
} |
|
else if (arg == "--no-escape") { |
|
no_escape = true; |
|
} |
|
else if (arg == "--no-parse-special") { |
|
no_parse_special = true; |
|
} |
|
else if (arg == "-p" || arg == "--prompt") { |
|
if (prompt_set) { |
|
fprintf(stderr, "Error: -p or --prompt specified multiple times.\n"); |
|
return 1; |
|
} |
|
prompt_arg = argv[++iarg].c_str(); |
|
prompt_set = true; |
|
} |
|
else if (arg == "-f" || arg == "--file") { |
|
if (prompt_path_set) { |
|
fprintf(stderr, "Error: -f or --file specified multiple times.\n"); |
|
return 1; |
|
} |
|
prompt_path = argv[++iarg].c_str(); |
|
prompt_path_set = true; |
|
} |
|
else if (arg == "--stdin") { |
|
stdin_set = true; |
|
} |
|
else if (arg == "--log-disable") { |
|
disable_logging = true; |
|
} |
|
else if (arg == "--show-count") { |
|
show_token_count = true; |
|
} |
|
else { |
|
fprintf(stderr, "Error: unknown option '%s'\n", argv[iarg].c_str()); |
|
return 1; |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
if (model_path_set && model_path == NULL) { |
|
fprintf(stderr, "Error: --model requires an argument.\n"); |
|
return 1; |
|
} |
|
if (!model_path_set) { |
|
fprintf(stderr, "Error: must specify --model.\n"); |
|
return 1; |
|
} |
|
if (prompt_path_set && prompt_path == NULL) { |
|
fprintf(stderr, "Error: --file requires an argument.\n"); |
|
return 1; |
|
} |
|
if (prompt_set && prompt_arg == NULL) { |
|
fprintf(stderr, "Error: --prompt requires an argument.\n"); |
|
return 1; |
|
} |
|
const int prompts_set = !!(prompt_path_set) + !!(prompt_set) + !!(stdin_set); |
|
if (prompts_set > 1) { |
|
fprintf(stderr, "Error: --stdin, --file and --prompt are mutually exclusive.\n"); |
|
return 1; |
|
} |
|
|
|
if (prompts_set == 0) { |
|
fprintf(stderr, "Error: must specify one of: --stdin, --file or --prompt.\n"); |
|
return 1; |
|
} |
|
|
|
GGML_ASSERT(model_path); |
|
GGML_ASSERT(prompt_path || prompt_arg || stdin_set); |
|
|
|
|
|
|
|
|
|
|
|
std::string prompt; |
|
if (prompt_path_set) { |
|
bool success = false; |
|
prompt = read_prompt_from_file(prompt_path, success); |
|
if (!success) { |
|
return 1; |
|
} |
|
} else if (prompt_set) { |
|
prompt = prompt_arg; |
|
} else { |
|
GGML_ASSERT(stdin_set); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (disable_logging) { |
|
llama_log_set(llama_log_callback_null, NULL); |
|
} |
|
|
|
llama_backend_init(); |
|
|
|
llama_model_params model_params = llama_model_default_params(); |
|
model_params.vocab_only = true; |
|
llama_model * model = llama_model_load_from_file(model_path, model_params); |
|
if (!model) { |
|
fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path); |
|
return 1; |
|
} |
|
|
|
const llama_vocab * vocab = llama_model_get_vocab(model); |
|
|
|
llama_context_params ctx_params = llama_context_default_params(); |
|
llama_context * ctx = llama_init_from_model(model, ctx_params); |
|
if (!ctx) { |
|
fprintf(stderr, "Error: could not create context.\n"); |
|
return 1; |
|
} |
|
|
|
|
|
if (stdin_set) { |
|
GGML_ASSERT(!prompt_path_set && !prompt_set); |
|
|
|
std::stringstream stdin_buffer; |
|
stdin_buffer << std::cin.rdbuf(); |
|
if (std::cin.fail()) { |
|
fprintf(stderr, "Error: could not read the entire standard input.\n"); |
|
return 1; |
|
} |
|
|
|
prompt = stdin_buffer.str(); |
|
} |
|
|
|
const bool model_wants_add_bos = llama_vocab_get_add_bos(vocab); |
|
const bool add_bos = model_wants_add_bos && !no_bos; |
|
const bool parse_special = !no_parse_special; |
|
const bool escape = !no_escape; |
|
|
|
if (escape) { |
|
string_process_escapes(prompt); |
|
} |
|
|
|
std::vector<llama_token> tokens; |
|
tokens = common_tokenize(vocab, prompt, add_bos, parse_special); |
|
|
|
if (printing_ids) { |
|
printf("["); |
|
} |
|
|
|
for (int i = 0; i < (int) tokens.size(); i++) { |
|
if (printing_ids) { |
|
if (i > 0) { |
|
printf(", "); |
|
} |
|
printf("%d", tokens[i]); |
|
} else { |
|
bool invalid_utf8 = false; |
|
printf("%6d -> '", tokens[i]); |
|
write_utf8_cstr_to_stdout(common_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8); |
|
if (invalid_utf8) { |
|
printf("' (utf-8 decode failure)\n"); |
|
} else { |
|
printf("'\n"); |
|
} |
|
} |
|
} |
|
|
|
if (printing_ids) { |
|
printf("]\n"); |
|
} |
|
|
|
if (show_token_count) { |
|
printf("Total number of tokens: %zu\n", tokens.size()); |
|
} |
|
|
|
llama_free(ctx); |
|
llama_model_free(model); |
|
|
|
return 0; |
|
} |
|
|