From 86f68a86e0bd3b7a8cf0662d07386caf9f4185d2 Mon Sep 17 00:00:00 2001
From: Michael Engel
Date: Mon, 17 Feb 2025 09:10:50 +0100
Subject: [PATCH] Added --chat-template-file to llama-run

Relates to: https://github.com/ggml-org/llama.cpp/issues/11178

Added a --chat-template-file CLI option to llama-run. If specified, the
file is read and its content is passed to common_chat_templates_init as an
override for the model's built-in chat template.

Signed-off-by: Michael Engel
---
 examples/run/run.cpp | 75 +++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 70 insertions(+), 5 deletions(-)

diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index ed8644ef78d97..900e159cbd7e7 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -113,6 +113,7 @@ class Opt {
     llama_context_params ctx_params;
     llama_model_params   model_params;
     std::string          model_;
+    std::string          chat_template_file;
     std::string          user;
     bool                 use_jinja = false;
     int                  context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
         return 0;
     }

+    int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = argv[++i];
+
+        return 0;
+    }
+
     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@ class Opt {
                 verbose = true;
             } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
                 use_jinja = true;
+            } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) {
+                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
+                    return 1;
+                }
+                use_jinja = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                 help = true;
                 return 0;
@@ -207,6 +223,11 @@ class Opt {
             "Options:\n"
             "  -c, --context-size <value>\n"
             "      Context size (default: %d)\n"
+            "  --chat-template-file <path>\n"
+            "      Path to the file containing the chat template to use with the model.\n"
+            "      Only supports jinja templates and implicitly sets the --jinja flag.\n"
+            "  --jinja\n"
+            "      Use jinja templating for the chat template of the model\n"
             "  -n, -ngl, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
             "  --temp <value>\n"
@@ -261,13 +282,16 @@ static int get_terminal_width() {
 #endif
 }

-#ifdef LLAMA_USE_CURL
 class File {
   public:
     FILE * file = nullptr;

     FILE * open(const std::string & filename, const char * mode) {
-        file = fopen(filename.c_str(), mode);
+        file = ggml_fopen(filename.c_str(), mode);
+        if (!file) {
+            printe("Error opening file '%s': %s\n", filename.c_str(), strerror(errno));
+            return nullptr;
+        }

         return file;
     }
@@ -303,6 +327,25 @@ class File {
         return 0;
     }

+    std::string read_all(const std::string & filename) {
+        file = open(filename, "r");
+        if (!file) {
+            return "";
+        }
+
+        fseek(file, 0, SEEK_END);
+        size_t size = ftell(file);
+        fseek(file, 0, SEEK_SET);
+
+        std::string out(size, '\0');
+        size_t read_size = fread(&out[0], 1, size, file);
+        if (read_size != size) {
+            printe("Error reading file '%s': %s\n", filename.c_str(), strerror(errno));
+            return "";
+        }
+        return out;
+    }
+
     ~File() {
         if (fd >= 0) {
 #    ifdef _WIN32
@@ -327,6 +370,7 @@ class File {
 #    endif
 };

+#ifdef LLAMA_USE_CURL
 class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -1053,11 +1097,32 @@ static int get_user_input(std::string & user_input, const std::string & user) {
     return 0;
 }

+// Reads the chat template file into a string ("" on failure)
+static std::string read_chat_template_file(const std::string & chat_template_file) {
+    if (chat_template_file.empty()) {
+        return "";
+    }
+
+    File file;
+    std::string chat_template = file.read_all(chat_template_file);
+    if (chat_template.empty()) {
+        printe("Error reading chat template file '%s': %s\n", chat_template_file.c_str(), strerror(errno));
+        return "";
+    }
+    return chat_template;
+}
+
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
+
+    std::string chat_template = "";
+    if (!chat_template_file.empty()) {
+        chat_template = read_chat_template_file(chat_template_file);
+    }
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
+
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -1143,7 +1208,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }

-    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
         return 1;
     }
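For context, the flow the patch adds boils down to: read the template file into a std::string, and treat an empty string as "keep the model's built-in template". Below is a minimal, self-contained sketch of that flow using std::ifstream instead of run.cpp's File helper; the load_template() name and the ./template.jinja path are illustrative only and not part of the patch.

// Self-contained sketch (not part of the patch): load a chat template
// override from disk; an empty return value means "use the model's
// built-in template", mirroring read_chat_template_file() above.
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

static std::string load_template(const std::string & path) {
    if (path.empty()) {
        return "";  // no --chat-template-file given
    }
    std::ifstream in(path, std::ios::binary);
    if (!in) {
        std::cerr << "could not open '" << path << "'\n";
        return "";  // caller falls back to the model's template
    }
    std::ostringstream ss;
    ss << in.rdbuf();  // slurp the whole file
    return ss.str();
}

int main() {
    // ./template.jinja is a hypothetical example path
    const std::string tmpl = load_template("./template.jinja");
    std::cout << (tmpl.empty() ? "using model's built-in template" : tmpl) << "\n";
    return 0;
}

With the patch applied, the equivalent end-to-end invocation from the command line would look something like: llama-run --chat-template-file ./template.jinja <model>.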