diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 377b26846b6..0182767c2b3 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -73,6 +73,8 @@ add_library(${TARGET} STATIC ngram-cache.h peg-parser.cpp peg-parser.h + preset.cpp + preset.h regex-partial.cpp regex-partial.h sampling.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 210ef8d6214..b333f45c96a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -47,6 +47,7 @@ #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 using json = nlohmann::ordered_json; +using namespace common_arg_utils; static std::initializer_list mmproj_examples = { LLAMA_EXAMPLE_MTMD, @@ -64,6 +65,15 @@ static std::string read_file(const std::string & fname) { return content; } +static const std::vector & get_common_arg_defs() { + static const std::vector options = [] { + common_params params; + auto ctx = common_params_parser_init(params, LLAMA_EXAMPLE_SERVER, nullptr); + return ctx.options; + }(); + return options; +} + common_arg & common_arg::set_examples(std::initializer_list examples) { this->examples = examples; return *this; @@ -134,7 +144,7 @@ static std::vector break_str_into_lines(std::string input, size_t m return result; } -std::string common_arg::to_string() { +std::string common_arg::to_string() const { // params for printing to console const static int n_leading_spaces = 40; const static int n_char_per_line_help = 70; // TODO: detect this based on current console @@ -647,6 +657,53 @@ static void add_rpc_devices(const std::string & servers) { } } +bool common_params_parse(int argc, char ** argv, llama_example ex, std::map & out_map) { + common_params dummy_params; + common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr); + + std::unordered_map arg_to_options; + for (auto & opt : ctx_arg.options) { + for (const auto & arg : opt.args) { + arg_to_options[arg] = &opt; + } + } + + // TODO @ngxson : find a way to deduplicate this code + + // handle 
command line arguments + auto check_arg = [&](int i) { + if (i+1 >= argc) { + throw std::invalid_argument("expected value for argument"); + } + }; + + for (int i = 1; i < argc; i++) { + const std::string arg_prefix = "--"; + + std::string arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + if (arg_to_options.find(arg) == arg_to_options.end()) { + throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str())); + } + auto opt = *arg_to_options[arg]; + std::string val; + if (opt.value_hint != nullptr) { + // arg with single value + check_arg(i); + val = argv[++i]; + } + if (opt.value_hint_2 != nullptr) { + // TODO: support arg with 2 values + throw std::invalid_argument("error: argument with 2 values is not yet supported\n"); + } + out_map[opt] = val; + } + + return true; +} + bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { auto ctx_arg = common_params_parser_init(params, ex, print_usage); const common_params params_org = ctx_arg.params; // the example can modify the default params @@ -692,15 +749,15 @@ static std::string list_builtin_chat_templates() { return msg.str(); } -static bool is_truthy(const std::string & value) { +bool common_arg_utils::is_truthy(const std::string & value) { return value == "on" || value == "enabled" || value == "1"; } -static bool is_falsey(const std::string & value) { +bool common_arg_utils::is_falsey(const std::string & value) { return value == "off" || value == "disabled" || value == "0"; } -static bool is_autoy(const std::string & value) { +bool common_arg_utils::is_autoy(const std::string & value) { return value == "auto" || value == "-1"; } @@ -2543,6 +2600,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.models_dir = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_DIR")); + 
add_opt(common_arg( + {"--models-preset"}, "PATH", + "path to INI file containing model presets for the router server (default: disabled)", + [](common_params & params, const std::string & value) { + params.models_preset = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_PRESET")); add_opt(common_arg( {"--models-max"}, "N", string_format("for router server, maximum number of models to load simultaneously (default: %d, 0 = unlimited)", params.models_max), diff --git a/common/arg.h b/common/arg.h index 7ab7e2cea43..219c115e635 100644 --- a/common/arg.h +++ b/common/arg.h @@ -3,8 +3,10 @@ #include "common.h" #include +#include #include #include +#include // // CLI argument parsing @@ -24,6 +26,8 @@ struct common_arg { void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr; void (*handler_int) (common_params & params, int) = nullptr; + common_arg() = default; + common_arg( const std::initializer_list & args, const char * value_hint, @@ -61,9 +65,29 @@ struct common_arg { bool is_exclude(enum llama_example ex); bool get_value_from_env(std::string & output) const; bool has_value_from_env() const; - std::string to_string(); + std::string to_string() const; + + // for using as key in std::map + bool operator<(const common_arg& other) const { + if (args.empty() || other.args.empty()) { + return false; + } + return strcmp(args[0], other.args[0]) < 0; + } + bool operator==(const common_arg& other) const { + if (args.empty() || other.args.empty()) { + return false; + } + return strcmp(args[0], other.args[0]) == 0; + } }; +namespace common_arg_utils { + bool is_truthy(const std::string & value); + bool is_falsey(const std::string & value); + bool is_autoy(const std::string & value); +} + struct common_params_context { enum llama_example ex = LLAMA_EXAMPLE_COMMON; common_params & params; @@ -76,7 +100,11 @@ struct common_params_context { // if one argument has invalid value, it will automatically display 
usage of the specific argument (and not the full usage message) bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); -// function to be used by test-arg-parser +// parse input arguments from CLI into a map +// TODO: support repeated args in the future +bool common_params_parse(int argc, char ** argv, llama_example ex, std::map & out_map); + +// initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); struct common_remote_params { diff --git a/common/common.h b/common/common.h index ad79f5b425c..6119adcc0f8 100644 --- a/common/common.h +++ b/common/common.h @@ -484,9 +484,10 @@ struct common_params { bool endpoint_metrics = false; // router server configs - std::string models_dir = ""; // directory containing models for the router server - int models_max = 4; // maximum number of models to load simultaneously - bool models_autoload = true; // automatically load models when requested via the router server + std::string models_dir = ""; // directory containing models for the router server + std::string models_preset = ""; // directory containing model presets for the router server + int models_max = 4; // maximum number of models to load simultaneously + bool models_autoload = true; // automatically load models when requested via the router server bool log_json = false; diff --git a/common/preset.cpp b/common/preset.cpp new file mode 100644 index 00000000000..09ac171b720 --- /dev/null +++ b/common/preset.cpp @@ -0,0 +1,180 @@ +#include "arg.h" +#include "preset.h" +#include "peg-parser.h" +#include "log.h" + +#include +#include +#include + +static std::string rm_leading_dashes(const std::string & str) { + size_t pos = 0; + while (pos < str.size() && str[pos] == '-') { + ++pos; + } + return str.substr(pos); +} + +std::vector 
common_preset::to_args() const { + std::vector args; + + for (const auto & [opt, value] : options) { + args.push_back(opt.args.back()); // use the last arg as the main arg + if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) { + // flag option, no value + if (common_arg_utils::is_falsey(value)) { + // skip the flag + args.pop_back(); + } + } + if (opt.value_hint != nullptr) { + // single value + args.push_back(value); + } + if (opt.value_hint != nullptr && opt.value_hint_2 != nullptr) { + throw std::runtime_error(string_format( + "common_preset::to_args(): option '%s' has two values, which is not supported yet", + opt.args.back() + )); + } + } + + return args; +} + +std::string common_preset::to_ini() const { + std::ostringstream ss; + + ss << "[" << name << "]\n"; + for (const auto & [opt, value] : options) { + auto espaced_value = value; + string_replace_all(espaced_value, "\n", "\\\n"); + ss << rm_leading_dashes(opt.args.back()) << " = "; + ss << espaced_value << "\n"; + } + ss << "\n"; + + return ss.str(); +} + +static std::map> parse_ini_from_file(const std::string & path) { + std::map> parsed; + + if (!std::filesystem::exists(path)) { + throw std::runtime_error("preset file does not exist: " + path); + } + + std::ifstream file(path); + if (!file.good()) { + throw std::runtime_error("failed to open server preset file: " + path); + } + + std::string contents((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + + static const auto parser = build_peg_parser([](auto & p) { + // newline ::= "\r\n" / "\n" / "\r" + auto newline = p.rule("newline", p.literal("\r\n") | p.literal("\n") | p.literal("\r")); + + // ws ::= [ \t]* + auto ws = p.rule("ws", p.chars("[ \t]", 0, -1)); + + // comment ::= [;#] (!newline .)* + auto comment = p.rule("comment", p.chars("[;#]", 1, 1) + p.zero_or_more(p.negate(newline) + p.any())); + + // eol ::= ws comment? 
(newline / EOF) + auto eol = p.rule("eol", ws + p.optional(comment) + (newline | p.end())); + + // ident ::= [a-zA-Z_] [a-zA-Z0-9_.-]* + auto ident = p.rule("ident", p.chars("[a-zA-Z_]", 1, 1) + p.chars("[a-zA-Z0-9_.-]", 0, -1)); + + // value ::= (!eol-start .)* + auto eol_start = p.rule("eol-start", ws + (p.chars("[;#]", 1, 1) | newline | p.end())); + auto value = p.rule("value", p.zero_or_more(p.negate(eol_start) + p.any())); + + // header-line ::= "[" ws ident ws "]" eol + auto header_line = p.rule("header-line", "[" + ws + p.tag("section-name", p.chars("[^]]")) + ws + "]" + eol); + + // kv-line ::= ident ws "=" ws value eol + auto kv_line = p.rule("kv-line", p.tag("key", ident) + ws + "=" + ws + p.tag("value", value) + eol); + + // comment-line ::= ws comment (newline / EOF) + auto comment_line = p.rule("comment-line", ws + comment + (newline | p.end())); + + // blank-line ::= ws (newline / EOF) + auto blank_line = p.rule("blank-line", ws + (newline | p.end())); + + // line ::= header-line / kv-line / comment-line / blank-line + auto line = p.rule("line", header_line | kv_line | comment_line | blank_line); + + // ini ::= line* EOF + auto ini = p.rule("ini", p.zero_or_more(line) + p.end()); + + return ini; + }); + + common_peg_parse_context ctx(contents); + const auto result = parser.parse(ctx); + if (!result.success()) { + throw std::runtime_error("failed to parse server config file: " + path); + } + + std::string current_section = COMMON_PRESET_DEFAULT_NAME; + std::string current_key; + + ctx.ast.visit(result, [&](const auto & node) { + if (node.tag == "section-name") { + const std::string section = std::string(node.text); + current_section = section; + parsed[current_section] = {}; + } else if (node.tag == "key") { + const std::string key = std::string(node.text); + current_key = key; + } else if (node.tag == "value" && !current_key.empty() && !current_section.empty()) { + parsed[current_section][current_key] = std::string(node.text); + current_key.clear(); + 
} + }); + + return parsed; +} + +static std::map get_map_key_opt(common_params_context & ctx_params) { + std::map mapping; + for (const auto & opt : ctx_params.options) { + if (opt.env != nullptr) { + mapping[opt.env] = opt; + } + for (const auto & arg : opt.args) { + mapping[rm_leading_dashes(arg)] = opt; + } + } + return mapping; +} + +common_presets common_presets_load(const std::string & path, common_params_context & ctx_params) { + common_presets out; + auto key_to_opt = get_map_key_opt(ctx_params); + auto ini_data = parse_ini_from_file(path); + + for (auto section : ini_data) { + common_preset preset; + if (section.first.empty()) { + preset.name = COMMON_PRESET_DEFAULT_NAME; + } else { + preset.name = section.first; + } + LOG_DBG("loading preset: %s\n", preset.name.c_str()); + for (const auto & [key, value] : section.second) { + LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str()); + if (key_to_opt.find(key) != key_to_opt.end()) { + preset.options[key_to_opt[key]] = value; + LOG_DBG("accepted option: %s = %s\n", key.c_str(), value.c_str()); + } else { + // TODO: maybe warn about unknown key? 
+ } + } + out[preset.name] = preset; + } + + return out; +} diff --git a/common/preset.h b/common/preset.h new file mode 100644 index 00000000000..dceb849eb81 --- /dev/null +++ b/common/preset.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common.h" +#include "arg.h" + +#include +#include +#include + +// +// INI preset parser and writer +// + +constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default"; + +struct common_preset { + std::string name; + // TODO: support repeated args in the future + std::map options; + + // convert preset to CLI argument list + std::vector to_args() const; + + // convert preset to INI format string + std::string to_ini() const; + + // TODO: maybe implement to_env() if needed +}; + +// interface for multiple presets in one file +using common_presets = std::map; +common_presets common_presets_load(const std::string & path, common_params_context & ctx_params); diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index a39b4c5b35f..ae1a497be6d 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -38,6 +38,14 @@ set(TARGET_SRCS server-http.h server-models.cpp server-models.h + server-task.cpp + server-task.h + server-queue.cpp + server-queue.h + server-common.cpp + server-common.h + server-context.cpp + server-context.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/README.md b/tools/server/README.md index f98fb44c7bc..d6b9b87dcf7 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1369,6 +1369,11 @@ llama-server ### Model sources +There are 3 possible sources for model files: +1. Cached models (controlled by the `LLAMA_CACHE` environment variable) +2. Custom model directory (set via the `--models-dir` argument) +3. Custom preset (set via the `--models-preset` argument) + By default, the router looks for models in the cache. 
You can add Hugging Face models to the cache with: ```sh @@ -1413,6 +1418,51 @@ llama-server -ctx 8192 -n 1024 -np 2 Note: model instances inherit both command line arguments and environment variables from the router server. +Alternatively, you can also add a GGUF-based preset (see next section). + +### Model presets + +Model presets allow advanced users to define custom configurations using an `.ini` file: + +```sh +llama-server --models-preset ./my-models.ini +``` + +Each section in the file defines a new preset. Keys within a section correspond to command-line arguments (without leading dashes). For example, the argument `--n-gpu-layer 123` is written as `n-gpu-layer = 123`. + +Short argument forms (e.g., `c`, `ngl`) and environment variable names (e.g., `LLAMA_ARG_N_GPU_LAYERS`) are also supported as keys. + +Example: + +```ini +version = 1 + +; If the key corresponds to an existing model on the server, +; this will be used as the default config for that model +[ggml-org/MY-MODEL-GGUF:Q8_0] +; string value +chat-template = chatml +; numeric value +n-gpu-layer = 123 +; flag value (for certain flags, you need to use the "no-" prefix for negation) +jinja = true +; shorthand argument (for example, context size) +c = 4096 +; environment variable name +LLAMA_ARG_CACHE_RAM = 0 +; file paths are relative to server's CWD +model-draft = ./my-models/draft.gguf +; but it's RECOMMENDED to use absolute path +model-draft = /Users/abc/my-models/draft.gguf + +; If the key does NOT correspond to an existing model, +; you need to specify at least the model path +[custom_model] +model = /Users/abc/my-awesome-model-Q4_K_M.gguf +``` + +Note: some arguments are controlled by the router (e.g., host, port, API key, HF repo, model alias). They will be removed or overwritten upon loading. + ### Routing requests Requests are routed according to the requested model name. 
diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 6f88e93c4bb..6c618a673c9 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1,6 +1,7 @@ #include "server-common.h" #include "server-models.h" +#include "preset.h" #include "download.h" #include // TODO: remove this once we use HTTP client from download.h @@ -33,6 +34,10 @@ #define CMD_EXIT "exit" +// address for child process, this is needed because router may run on 0.0.0.0 +// ref: https://github.com/ggml-org/llama.cpp/issues/17862 +#define CHILD_ADDR "127.0.0.1" + static std::filesystem::path get_server_exec_path() { #if defined(_WIN32) wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths @@ -132,6 +137,93 @@ static std::vector list_local_models(const std::string & dir) { return models; } +// +// server_presets +// + + +server_presets::server_presets(int argc, char ** argv, common_params & base_params, const std::string & presets_path) + : ctx_params(common_params_parser_init(base_params, LLAMA_EXAMPLE_SERVER)) { + if (!presets_path.empty()) { + presets = common_presets_load(presets_path, ctx_params); + SRV_INF("Loaded %zu presets from %s\n", presets.size(), presets_path.c_str()); + } + + // populate reserved args (will be appended by the router) + for (auto & opt : ctx_params.options) { + if (opt.env == nullptr) { + continue; + } + std::string env = opt.env; + if (env == "LLAMA_ARG_PORT" || + env == "LLAMA_ARG_HOST" || + env == "LLAMA_ARG_ALIAS" || + env == "LLAMA_ARG_API_KEY" || + env == "LLAMA_ARG_MODELS_DIR" || + env == "LLAMA_ARG_MODELS_MAX" || + env == "LLAMA_ARG_MODELS_PRESET" || + env == "LLAMA_ARG_MODEL" || + env == "LLAMA_ARG_MMPROJ" || + env == "LLAMA_ARG_HF_REPO" || + env == "LLAMA_ARG_NO_MODELS_AUTOLOAD") { + control_args[env] = opt; + } + } + + // read base args from router's argv + common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args); + + // remove any router-controlled args from base_args + for (const auto & 
cargs : control_args) { + auto it = base_args.find(cargs.second); + if (it != base_args.end()) { + base_args.erase(it); + } + } +} + +common_preset server_presets::get_preset(const std::string & name) { + auto it = presets.find(name); + if (it != presets.end()) { + return it->second; + } + return common_preset(); +} + +void server_presets::render_args(server_model_meta & meta) { + common_preset preset = meta.preset; // copy + // merging 3 kinds of args: + // 1. model-specific args (from preset) + // force removing control args if any + for (auto & cargs : control_args) { + if (preset.options.find(cargs.second) != preset.options.end()) { + SRV_WRN("Preset '%s' contains reserved arg '%s', removing it\n", preset.name.c_str(), cargs.second.args[0]); + preset.options.erase(cargs.second); + } + } + // 2. base args (from router) + // inherit from base args + for (const auto & [arg, value] : base_args) { + preset.options[arg] = value; + } + // 3. control args (from router) + // set control values + preset.options[control_args["LLAMA_ARG_HOST"]] = CHILD_ADDR; + preset.options[control_args["LLAMA_ARG_PORT"]] = std::to_string(meta.port); + preset.options[control_args["LLAMA_ARG_ALIAS"]] = meta.name; + if (meta.in_cache) { + preset.options[control_args["LLAMA_ARG_HF_REPO"]] = meta.name; + } else { + preset.options[control_args["LLAMA_ARG_MODEL"]] = meta.path; + if (!meta.path_mmproj.empty()) { + preset.options[control_args["LLAMA_ARG_MMPROJ"]] = meta.path_mmproj; + } + } + meta.args = preset.to_args(); + // add back the binary path at the front + meta.args.insert(meta.args.begin(), get_server_exec_path().string()); +} + // // server_models // @@ -140,7 +232,7 @@ server_models::server_models( const common_params & params, int argc, char ** argv, - char ** envp) : base_params(params) { + char ** envp) : base_params(params), presets(argc, argv, base_params, params.models_preset) { for (int i = 0; i < argc; i++) { base_args.push_back(std::string(argv[i])); } @@ -155,11 +247,58 @@ 
server_models::server_models( LOG_WRN("failed to get server executable path: %s\n", e.what()); LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str()); } - // TODO: allow refreshing cached model list - // add cached models + load_models(); +} + +void server_models::add_model(server_model_meta && meta) { + if (mapping.find(meta.name) != mapping.end()) { + throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str())); + } + presets.render_args(meta); // populate meta.args + std::string name = meta.name; + mapping[name] = instance_t{ + /* subproc */ std::make_shared(), + /* th */ std::thread(), + /* meta */ std::move(meta) + }; +} + +static std::vector list_custom_path_models(server_presets & presets) { + // detect any custom-path models in presets + std::vector custom_models; + for (auto & [model_name, preset] : presets.presets) { + local_model model; + model.name = model_name; + std::vector to_erase; + for (auto & [arg, value] : preset.options) { + std::string env(arg.env ? arg.env : ""); + if (env == "LLAMA_ARG_MODEL") { + model.path = value; + to_erase.push_back(arg); + } + if (env == "LLAMA_ARG_MMPROJ") { + model.path_mmproj = value; + to_erase.push_back(arg); + } + } + for (auto & arg : to_erase) { + preset.options.erase(arg); + } + if (!model.name.empty() && !model.path.empty()) { + custom_models.push_back(model); + } + } + return custom_models; +} + +// TODO: allow refreshing cached model list +void server_models::load_models() { + // loading models from 3 sources: + // 1. 
cached models auto cached_models = common_list_cached_models(); for (const auto & model : cached_models) { server_model_meta meta{ + /* preset */ presets.get_preset(model.to_string()), /* name */ model.to_string(), /* path */ model.manifest_path, /* path_mmproj */ "", // auto-detected when loading @@ -170,21 +309,18 @@ server_models::server_models( /* args */ std::vector(), /* exit_code */ 0 }; - mapping[meta.name] = instance_t{ - /* subproc */ std::make_shared(), - /* th */ std::thread(), - /* meta */ meta - }; + add_model(std::move(meta)); } - // add local models specificed via --models-dir - if (!params.models_dir.empty()) { - auto local_models = list_local_models(params.models_dir); + // 2. local models specified via --models-dir + if (!base_params.models_dir.empty()) { + auto local_models = list_local_models(base_params.models_dir); for (const auto & model : local_models) { if (mapping.find(model.name) != mapping.end()) { // already exists in cached models, skip continue; } server_model_meta meta{ + /* preset */ presets.get_preset(model.name), /* name */ model.name, /* path */ model.path, /* path_mmproj */ model.path_mmproj, @@ -195,13 +331,31 @@ server_models::server_models( /* args */ std::vector(), /* exit_code */ 0 }; - mapping[meta.name] = instance_t{ - /* subproc */ std::make_shared(), - /* th */ std::thread(), - /* meta */ meta - }; + add_model(std::move(meta)); } } + // 3. 
custom-path models specified in presets + auto custom_models = list_custom_path_models(presets); + for (const auto & model : custom_models) { + server_model_meta meta{ + /* preset */ presets.get_preset(model.name), + /* name */ model.name, + /* path */ model.path, + /* path_mmproj */ model.path_mmproj, + /* in_cache */ false, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector(), + /* exit_code */ 0 + }; + add_model(std::move(meta)); + } + // log available models + SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size()); + for (const auto & [name, inst] : mapping) { + SRV_INF(" %c %s\n", inst.meta.preset.name.empty() ? ' ' : '*', name.c_str()); + } } void server_models::update_meta(const std::string & name, const server_model_meta & meta) { @@ -335,19 +489,7 @@ void server_models::unload_lru() { } } -static void add_or_replace_arg(std::vector & args, const std::string & key, const std::string & value) { - for (size_t i = 0; i < args.size(); i++) { - if (args[i] == key && i + 1 < args.size()) { - args[i + 1] = value; - return; - } - } - // not found, append - args.push_back(key); - args.push_back(value); -} - -void server_models::load(const std::string & name, bool auto_load) { +void server_models::load(const std::string & name) { if (!has_model(name)) { throw std::runtime_error("model name=" + name + " is not found"); } @@ -376,26 +518,10 @@ void server_models::load(const std::string & name, bool auto_load) { { SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); - std::vector child_args; - if (auto_load && !meta.args.empty()) { - child_args = meta.args; // copy previous args - } else { - child_args = base_args; // copy - if (inst.meta.in_cache) { - add_or_replace_arg(child_args, "-hf", inst.meta.name); - } else { - add_or_replace_arg(child_args, "-m", inst.meta.path); - if (!inst.meta.path_mmproj.empty()) { - add_or_replace_arg(child_args, 
"--mmproj", inst.meta.path_mmproj); - } - } - } - - // set model args - add_or_replace_arg(child_args, "--port", std::to_string(inst.meta.port)); - add_or_replace_arg(child_args, "--alias", inst.meta.name); + presets.render_args(inst.meta); // update meta.args - std::vector child_env = base_env; // copy + std::vector child_args = inst.meta.args; // copy + std::vector child_env = base_env; // copy child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); SRV_INF("%s", "spawning server instance with args:\n"); @@ -541,7 +667,7 @@ bool server_models::ensure_model_loaded(const std::string & name) { } if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); - load(name, true); + load(name); } SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); @@ -571,7 +697,7 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port); auto proxy = std::make_unique( method, - base_params.hostname, + CHILD_ADDR, meta->port, req.path, req.headers, @@ -724,38 +850,6 @@ void server_models_routes::init_routes() { return models.proxy_request(req, method, name, true); // update last usage for POST request only }; - this->get_router_models = [this](const server_http_req &) { - auto res = std::make_unique(); - json models_json = json::array(); - auto all_models = models.get_all_meta(); - std::time_t t = std::time(0); - for (const auto & meta : all_models) { - json status { - {"value", server_model_status_to_string(meta.status)}, - {"args", meta.args}, - }; - if (meta.is_failed()) { - status["exit_code"] = meta.exit_code; - status["failed"] = true; - } - models_json.push_back(json { - {"id", meta.name}, - {"object", "model"}, // for OAI-compat - {"owned_by", "llamacpp"}, // for OAI-compat - {"created", t}, // for OAI-compat - {"in_cache", meta.in_cache}, - {"path", meta.path}, - 
{"status", status}, - // TODO: add other fields, may require reading GGUF metadata - }); - } - res_ok(res, { - {"data", models_json}, - {"object", "list"}, - }); - return res; - }; - this->post_router_models_load = [this](const server_http_req & req) { auto res = std::make_unique(); json body = json::parse(req.body); @@ -769,7 +863,7 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); return res; } - models.load(name, false); + models.load(name); res_ok(res, {{"success", true}}); return res; }; @@ -793,9 +887,12 @@ void server_models_routes::init_routes() { std::time_t t = std::time(0); for (const auto & meta : all_models) { json status { - {"value", server_model_status_to_string(meta.status)}, - {"args", meta.args}, + {"value", server_model_status_to_string(meta.status)}, + {"args", meta.args}, }; + if (!meta.preset.name.empty()) { + status["preset"] = meta.preset.to_ini(); + } if (meta.is_failed()) { status["exit_code"] = meta.exit_code; status["failed"] = true; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 526e7488dc9..9cdbbad9b6a 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -1,6 +1,7 @@ #pragma once #include "common.h" +#include "preset.h" #include "server-http.h" #include @@ -47,6 +48,7 @@ static std::string server_model_status_to_string(server_model_status status) { } struct server_model_meta { + common_preset preset; std::string name; std::string path; std::string path_mmproj; // only available if in_cache=false @@ -54,7 +56,7 @@ struct server_model_meta { int port = 0; server_model_status status = SERVER_MODEL_STATUS_UNLOADED; int64_t last_used = 0; // for LRU unloading - std::vector args; // additional args passed to the model instance (used for debugging) + std::vector args; // args passed to the model instance, will be populated by render_args() int exit_code = 0; // exit code of the model instance 
process (only valid if status == FAILED) bool is_active() const { @@ -66,6 +68,19 @@ struct server_model_meta { } }; +// the server_presets struct holds the presets read from presets.ini +// as well as base args from the router server +struct server_presets { + common_presets presets; + common_params_context ctx_params; + std::map base_args; + std::map control_args; // args reserved for server control + + server_presets(int argc, char ** argv, common_params & base_params, const std::string & models_dir); + common_preset get_preset(const std::string & name); + void render_args(server_model_meta & meta); +}; + struct subprocess_s; struct server_models { @@ -85,14 +100,21 @@ struct server_models { std::vector base_args; std::vector base_env; + server_presets presets; + void update_meta(const std::string & name, const server_model_meta & meta); // unload least recently used models if the limit is reached void unload_lru(); + // not thread-safe, caller must hold mutex + void add_model(server_model_meta && meta); + public: server_models(const common_params & params, int argc, char ** argv, char ** envp); + void load_models(); + // check if a model instance exists bool has_model(const std::string & name); @@ -102,8 +124,7 @@ struct server_models { // return a copy of all model metadata std::vector get_all_meta(); - // if auto_load is true, load the model with previous args if any - void load(const std::string & name, bool auto_load); + void load(const std::string & name); void unload(const std::string & name); void unload_all();