diff --git a/common/arg.cpp b/common/arg.cpp
index f6a775fc4a804..4e89a92976483 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1721,18 +1721,14 @@ static void add_rpc_devices(const std::string & servers) {
     if (!rpc_reg) {
         throw std::invalid_argument("failed to find RPC backend");
     }
-    typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
-    ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-    if (!ggml_backend_rpc_add_device_fn) {
-        throw std::invalid_argument("failed to find RPC device add function");
+    typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint);
+    ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
+    if (!ggml_backend_rpc_add_server_fn) {
+        throw std::invalid_argument("failed to find RPC add server function");
     }
     for (const auto & server : rpc_servers) {
-        ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
-        if (dev) {
-            ggml_backend_device_register(dev);
-        } else {
-            throw std::invalid_argument("failed to register RPC device");
-        }
+        auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
+        ggml_backend_register(reg);
     }
 }
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 62b6d65e51445..f1b740785914e 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -215,6 +215,8 @@ extern "C" {
     // Backend registry
     //
 
+    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
+
     GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
 
     // Backend (reg) enumeration
diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h
index 1e674112767c9..5bf0a2b95f5e1 100644
--- a/ggml/include/ggml-rpc.h
+++ b/ggml/include/ggml-rpc.h
@@ -7,26 +7,25 @@ extern "C" {
 #endif
 
-#define RPC_PROTO_MAJOR_VERSION    2
+#define RPC_PROTO_MAJOR_VERSION    3
 #define RPC_PROTO_MINOR_VERSION    0
 #define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
 
 // backend API
-GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
 GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
 
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
 
-GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
 
-GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
-                                                    const char * cache_dir,
-                                                    size_t free_mem, size_t total_mem);
+GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
+                                                    size_t n_devices, size_t n_threads,
+                                                    ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
 
 GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
-
-GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
 
 #ifdef __cplusplus
 }
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 07784d6f66ce6..6792ba986e8ed 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -209,9 +209,6 @@ extern "C" {
         void * context;
     };
 
-    // Internal backend registry API
-    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-
     // Add backend dynamic loading support to the backend
 
     // Initialize the backend
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index f99681c84cbab..14a6186a0afb0 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -90,7 +90,8 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
 
 // RPC commands
 enum rpc_cmd {
-    RPC_CMD_ALLOC_BUFFER = 0,
+    RPC_CMD_DEVICE_COUNT = 0,
+    RPC_CMD_ALLOC_BUFFER,
     RPC_CMD_GET_ALIGNMENT,
     RPC_CMD_GET_MAX_SIZE,
     RPC_CMD_BUFFER_GET_BASE,
@@ -117,7 +118,12 @@ struct rpc_msg_hello_rsp {
     uint8_t patch;
 };
 
+struct rpc_msg_device_count_rsp {
+    uint32_t device_count;
+};
+
 struct rpc_msg_get_alloc_size_req {
+    uint32_t device;
     rpc_tensor tensor;
 };
 
@@ -130,6 +136,7 @@ struct rpc_msg_init_tensor_req {
 };
 
 struct rpc_msg_alloc_buffer_req {
+    uint32_t device;
     uint64_t size;
 };
 
@@ -138,10 +145,18 @@ struct rpc_msg_alloc_buffer_rsp {
     uint64_t remote_size;
 };
 
+struct rpc_msg_get_alignment_req {
+    uint32_t device;
+};
+
 struct rpc_msg_get_alignment_rsp {
     uint64_t alignment;
 };
 
+struct rpc_msg_get_max_size_req {
+    uint32_t device;
+};
+
 struct rpc_msg_get_max_size_rsp {
     uint64_t max_size;
 };
@@ -192,6 +207,10 @@ struct rpc_msg_graph_compute_rsp {
     uint8_t result;
 };
 
+struct rpc_msg_get_device_memory_req {
+    uint32_t device;
+};
+
 struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
@@ -207,13 +226,15 @@ static ggml_guid_t ggml_backend_rpc_guid() {
 
 struct ggml_backend_rpc_buffer_type_context {
     std::string endpoint;
+    uint32_t    device;
     std::string name;
-    size_t alignment;
-    size_t max_size;
+    size_t      alignment;
+    size_t      max_size;
 };
 
 struct ggml_backend_rpc_context {
     std::string endpoint;
+    uint32_t    device;
     std::string name;
 };
 
@@ -653,7 +674,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
 
 static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    rpc_msg_alloc_buffer_req request = {size};
+    rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
     rpc_msg_alloc_buffer_rsp response;
     auto sock = get_socket(buft_ctx->endpoint);
     bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
@@ -669,9 +690,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
     }
 }
 
-static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
+static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
+    rpc_msg_get_alignment_req request = {device};
     rpc_msg_get_alignment_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
     RPC_STATUS_ASSERT(status);
     return response.alignment;
 }
@@ -681,9 +703,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
     return buft_ctx->alignment;
 }
 
-static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
+static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
+    rpc_msg_get_max_size_req request = {device};
     rpc_msg_get_max_size_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
     RPC_STATUS_ASSERT(status);
     return response.max_size;
 }
@@ -700,7 +723,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
         auto sock = get_socket(buft_ctx->endpoint);
 
         rpc_msg_get_alloc_size_req request;
-
+        request.device = buft_ctx->device;
         request.tensor = serialize_tensor(tensor);
 
         rpc_msg_get_alloc_size_rsp response;
@@ -754,7 +777,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
     tensors.push_back(serialize_tensor(tensor));
 }
 
-static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
+static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
     uint32_t n_nodes = cgraph->n_nodes;
     std::vector<rpc_tensor> tensors;
     std::unordered_set<ggml_tensor*> visited;
@@ -762,24 +785,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
         add_tensor(cgraph->nodes[i], tensors, visited);
     }
     // serialization format:
-    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     uint32_t n_tensors = tensors.size();
-    int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
+    int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
     output.resize(output_size, 0);
-    memcpy(output.data(), &n_nodes, sizeof(n_nodes));
+    uint8_t * dest = output.data();
+    memcpy(dest, &device, sizeof(device));
+    dest += sizeof(device);
+    memcpy(dest, &n_nodes, sizeof(n_nodes));
+    dest += sizeof(n_nodes);
     for (uint32_t i = 0; i < n_nodes; i++) {
-        memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
+        memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
     }
-    uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
-    *out_ntensors = n_tensors;
-    rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
+    dest += n_nodes * sizeof(uint64_t);
+    memcpy(dest, &n_tensors, sizeof(n_tensors));
+    dest += sizeof(n_tensors);
+    rpc_tensor * out_tensors = (rpc_tensor *)dest;
     memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
 }
 
 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
     std::vector<uint8_t> input;
-    serialize_graph(cgraph, input);
+    serialize_graph(rpc_ctx->device, cgraph, input);
     rpc_msg_graph_compute_rsp response;
     auto sock = get_socket(rpc_ctx->endpoint);
     bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
@@ -804,12 +832,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .graph_optimize          = */ NULL,
 };
 
-ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
+ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
+    std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
     // NOTE: buffer types are allocated and never freed; this is by design
     static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
-    auto it = buft_map.find(endpoint);
+    auto it = buft_map.find(dev_name);
     if (it != buft_map.end()) {
         return it->second;
     }
@@ -818,34 +847,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
         GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
         return nullptr;
     }
-    size_t alignment = get_alignment(sock);
-    size_t max_size = get_max_size(sock);
+    size_t alignment = get_alignment(sock, device);
+    size_t max_size = get_max_size(sock, device);
     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
         /* .endpoint  = */ endpoint,
-        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
+        /* .device    = */ device,
+        /* .name      = */ dev_name,
         /* .alignment = */ alignment,
         /* .max_size  = */ max_size
     };
-
+    auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
-        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
+        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
         /* .context = */ buft_ctx
     };
-    buft_map[endpoint] = buft;
+    buft_map[dev_name] = buft;
     return buft;
 }
 
-ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
+ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
+    std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint = */ endpoint,
-        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
+        /* .endpoint = */ endpoint,
+        /* .device   = */ device,
+        /* .name     = */ dev_name
     };
-
+    auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_t backend = new ggml_backend {
         /* .guid    = */ ggml_backend_rpc_guid(),
         /* .iface   = */ ggml_backend_rpc_interface,
-        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
+        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
         /* .context = */ ctx
     };
     return backend;
@@ -855,37 +887,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
 }
 
-static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
+static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
+    rpc_msg_get_device_memory_req request;
+    request.device = device;
     rpc_msg_get_device_memory_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
     RPC_STATUS_ASSERT(status);
     *free = response.free_mem;
     *total = response.total_mem;
 }
 
-void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
+void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
    auto sock = get_socket(endpoint);
    if (sock == nullptr) {
        *free = 0;
        *total = 0;
        return;
    }
-    get_device_memory(sock, free, total);
+    get_device_memory(sock, device, free, total);
 }
 
 // RPC server-side implementation
 
 class rpc_server {
 public:
-    rpc_server(ggml_backend_t backend, const char * cache_dir)
-        : backend(backend), cache_dir(cache_dir) {
+    rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
+        : backends(std::move(backends)), cache_dir(cache_dir) {
     }
     ~rpc_server();
 
     void hello(rpc_msg_hello_rsp & response);
-    void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
-    void get_alignment(rpc_msg_get_alignment_rsp & response);
-    void get_max_size(rpc_msg_get_max_size_rsp & response);
+    bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
+    bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
+    bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
     bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
     bool free_buffer(const rpc_msg_free_buffer_req & request);
     bool buffer_clear(const rpc_msg_buffer_clear_req & request);
@@ -906,7 +940,7 @@ class rpc_server {
                               std::unordered_map<uint64_t, struct ggml_tensor *> & tensor_map);
 
-    ggml_backend_t backend;
+    std::vector<ggml_backend_t> backends;
     const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
@@ -919,6 +953,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
 }
 
 bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
         /*.mem_size   =*/ ggml_tensor_overhead(),
@@ -935,10 +973,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
         GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
         return false;
     }
-    LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
+    LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
     if (tensor->buffer == nullptr) {
         //No buffer allocated.
-        buft = ggml_backend_get_default_buffer_type(backend);
+        buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
     } else {
         buft = tensor->buffer->buft;
     }
@@ -948,33 +986,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
     return true;
 }
 
-void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
     response.remote_ptr = 0;
     response.remote_size = 0;
     if (buffer != nullptr) {
         response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
         response.remote_size = buffer->size;
-        LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
+        LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
+                __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
         buffers.insert(buffer);
     } else {
-        LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
+        LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
     }
+    return true;
 }
 
-void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
     size_t alignment = ggml_backend_buft_get_alignment(buft);
-    LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
+    LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
     response.alignment = alignment;
+    return true;
 }
 
-void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
     size_t max_size = ggml_backend_buft_get_max_size(buft);
-    LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
+    LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
     response.max_size = max_size;
+    return true;
 }
 
 bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
@@ -1332,23 +1386,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
 
 bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
     // serialization format:
-    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
-    if (input.size() < sizeof(uint32_t)) {
+    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    if (input.size() < 2*sizeof(uint32_t)) {
+        return false;
+    }
+    const uint8_t * src = input.data();
+    uint32_t device;
+    memcpy(&device, src, sizeof(device));
+    src += sizeof(device);
+    if (device >= backends.size()) {
         return false;
     }
     uint32_t n_nodes;
-    memcpy(&n_nodes, input.data(), sizeof(n_nodes));
-    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
+    memcpy(&n_nodes, src, sizeof(n_nodes));
+    src += sizeof(n_nodes);
+    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
         return false;
     }
-    const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
+    const uint64_t * nodes = (const uint64_t *)src;
+    src += n_nodes*sizeof(uint64_t);
     uint32_t n_tensors;
-    memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
-    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
+    memcpy(&n_tensors, src, sizeof(n_tensors));
+    src += sizeof(n_tensors);
+    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
         return false;
     }
-    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
-    LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
+    const rpc_tensor * tensors = (const rpc_tensor *)src;
+    LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
 
     size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
 
@@ -1380,7 +1444,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
             return false;
         }
     }
-    ggml_status status = ggml_backend_graph_compute(backend, graph);
+    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
     response.result = status;
     return true;
 }
@@ -1391,9 +1455,9 @@ rpc_server::~rpc_server() {
     }
 }
 
-static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
-                             sockfd_t sockfd, size_t free_mem, size_t total_mem) {
-    rpc_server server(backend, cache_dir);
+static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
+                             sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
+    rpc_server server(backends, cache_dir);
     uint8_t cmd;
     if (!recv_data(sockfd, &cmd, 1)) {
         return;
@@ -1425,13 +1489,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
             // HELLO command is handled above
             return;
         }
+        case RPC_CMD_DEVICE_COUNT: {
+            if (!recv_msg(sockfd, nullptr, 0)) {
+                return;
+            }
+            rpc_msg_device_count_rsp response;
+            response.device_count = backends.size();
+            if (!send_msg(sockfd, &response, sizeof(response))) {
+                return;
+            }
+            break;
+        }
         case RPC_CMD_ALLOC_BUFFER: {
             rpc_msg_alloc_buffer_req request;
             if (!recv_msg(sockfd, &request, sizeof(request))) {
                 return;
             }
             rpc_msg_alloc_buffer_rsp response;
-            server.alloc_buffer(request, response);
+            if (!server.alloc_buffer(request, response)) {
+                return;
+            }
             if (!send_msg(sockfd, &response, sizeof(response))) {
                 return;
             }
@@ -1452,22 +1529,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
             break;
         }
         case RPC_CMD_GET_ALIGNMENT: {
-            if (!recv_msg(sockfd, nullptr, 0)) {
+            rpc_msg_get_alignment_req request;
+            if (!recv_msg(sockfd, &request, sizeof(request))) {
                 return;
             }
             rpc_msg_get_alignment_rsp response;
-            server.get_alignment(response);
+            if (!server.get_alignment(request, response)) {
+                return;
+            }
             if (!send_msg(sockfd, &response, sizeof(response))) {
                 return;
             }
             break;
         }
        case RPC_CMD_GET_MAX_SIZE: {
-            if (!recv_msg(sockfd, nullptr, 0)) {
+            rpc_msg_get_max_size_req request;
+            if (!recv_msg(sockfd, &request, sizeof(request))) {
                return;
            }
            rpc_msg_get_max_size_rsp response;
-            server.get_max_size(response);
+            if (!server.get_max_size(request, response)) {
+                return;
+            }
            if (!send_msg(sockfd, &response, sizeof(response))) {
                return;
            }
            break;
        }
@@ -1593,12 +1676,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
            break;
        }
        case RPC_CMD_GET_DEVICE_MEMORY: {
-            if (!recv_msg(sockfd, nullptr, 0)) {
+            rpc_msg_get_device_memory_req request;
+            if (!recv_msg(sockfd, &request, sizeof(request))) {
+                return;
+            }
+            auto dev_id = request.device;
+            if (dev_id >= backends.size()) {
                return;
            }
            rpc_msg_get_device_memory_rsp response;
-            response.free_mem = free_mem;
-            response.total_mem = total_mem;
+            response.free_mem = free_mem[dev_id];
+            response.total_mem = total_mem[dev_id];
+            LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
+                    response.free_mem, response.total_mem);
            if (!send_msg(sockfd, &response, sizeof(response))) {
                return;
            }
@@ -1612,16 +1702,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
     }
 }
 
-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
-                                   const char * cache_dir,
-                                   size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
+                                   size_t n_devices, size_t n_threads,
+                                   ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
+    if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
+        fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
+        return;
+    }
+    std::vector<ggml_backend_t> backends;
+    std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
+    std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
     printf("Starting RPC server v%d.%d.%d\n",
        RPC_PROTO_MAJOR_VERSION,
        RPC_PROTO_MINOR_VERSION,
        RPC_PROTO_PATCH_VERSION);
     printf("  endpoint       : %s\n", endpoint);
     printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
-    printf("  backend memory : %zu MB\n", free_mem / (1024 * 1024));
+    printf("Devices:\n");
+    for (size_t i = 0; i < n_devices; i++) {
+        auto dev = devices[i];
+        printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
+               total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
+        auto backend = ggml_backend_dev_init(dev, nullptr);
+        if (!backend) {
+            fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
+            return;
+        }
+        backends.push_back(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+        if (reg) {
+            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+            if (ggml_backend_set_n_threads_fn) {
+                ggml_backend_set_n_threads_fn(backend, n_threads);
+            }
+        }
+    }
 
     std::string host;
     int port;
@@ -1649,21 +1764,25 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
             fprintf(stderr, "Failed to accept client connection\n");
             return;
         }
-        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+        printf("Accepted client connection\n");
         fflush(stdout);
-        rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
+        rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
         printf("Client connection closed\n");
         fflush(stdout);
     }
 #ifdef _WIN32
     WSACleanup();
 #endif
+    for (auto backend : backends) {
+        ggml_backend_free(backend);
+    }
 }
 
 // device interface
 
 struct ggml_backend_rpc_device_context {
     std::string endpoint;
+    uint32_t    device;
     std::string name;
 };
 
@@ -1682,9 +1801,7 @@ static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t d
 static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 
-    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
-
-    GGML_UNUSED(dev);
+    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
 }
 
 static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
@@ -1710,7 +1827,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
 
 static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-    return ggml_backend_rpc_init(ctx->endpoint.c_str());
+    return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
 
     GGML_UNUSED(params);
 }
@@ -1718,7 +1835,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
 
 static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
-    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
 
     GGML_UNUSED(dev);
 }
@@ -1736,7 +1853,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
     }
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
-    return buft_ctx->endpoint == dev_ctx->endpoint;
+    return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
 }
 
 static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
@@ -1759,28 +1876,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
 
 // backend reg interface
 
-static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
-    return "RPC";
+struct ggml_backend_rpc_reg_context {
+    std::string name;
+    std::vector<ggml_backend_dev_t> devices;
+};
 
-    GGML_UNUSED(reg);
+static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    return ctx ? ctx->name.c_str() : "RPC";
 }
 
 static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
-    return 0;
-
-    GGML_UNUSED(reg);
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    return ctx ? ctx->devices.size() : 0;
 }
 
 static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(index);
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    if (ctx == nullptr) {
+        GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
+    } else {
+        GGML_ASSERT(index < ctx->devices.size());
+        return ctx->devices[index];
+    }
 }
 
 static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
-        return (void *)ggml_backend_rpc_add_device;
+    if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
+        return (void *)ggml_backend_rpc_add_server;
     }
     if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
         return (void *)ggml_backend_rpc_start_server;
@@ -1807,30 +1930,57 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
     return &ggml_backend_rpc_reg;
 }
 
-ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
-    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
+static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
+    auto sock = get_socket(endpoint);
+    rpc_msg_device_count_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
+    RPC_STATUS_ASSERT(status);
+    return response.device_count;
+}
+
+static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
+    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
+    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
+    static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
-
-    if (dev_map.find(endpoint) != dev_map.end()) {
-        return dev_map[endpoint];
+    if (reg_map.find(endpoint) != reg_map.end()) {
+        return reg_map[endpoint];
     }
-
-    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
-        /* .endpoint = */ endpoint,
-        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
-    };
-
-    ggml_backend_dev_t dev = new ggml_backend_device {
-        /* .iface   = */ ggml_backend_rpc_device_i,
-        /* .reg     = */ ggml_backend_rpc_reg(),
-        /* .context = */ ctx,
+    uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
+    if (dev_count == 0) {
+        return nullptr;
+    }
+    ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
+    ctx->name = "RPC[" + std::string(endpoint) + "]";
+    for (uint32_t dev_id = 0; dev_id < dev_count; dev_id++) {
+        std::string dev_name = "RPC" + std::to_string(dev_id) + "[" + std::string(endpoint) + "]";
+        ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
+            /* .endpoint = */ endpoint,
+            /* .device   = */ dev_id,
+            /* .name     = */ dev_name,
+        };
+
+        ggml_backend_dev_t dev = new ggml_backend_device {
+            /* .iface   = */ ggml_backend_rpc_device_i,
+            /* .reg     = */ ggml_backend_rpc_reg(),
+            /* .context = */ dev_ctx,
+        };
+        ctx->devices.push_back(dev);
+    }
+    ggml_backend_reg_t reg = new ggml_backend_reg {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_rpc_reg_interface,
+        /* .context     = */ ctx
     };
-
-    dev_map[endpoint] = dev;
-
-    return dev;
+    reg_map[endpoint] = reg;
+    return reg;
 }
+
 GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index 275ba367c02f1..f1a4ed7f9cf56 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -168,7 +168,7 @@ static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & val
     return devices;
 }
 
-static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::string & servers) {
+static void register_rpc_server_list(const std::string & servers) {
     auto rpc_servers = string_split<std::string>(servers, ',');
     if (rpc_servers.empty()) {
         throw std::invalid_argument("no RPC servers specified");
@@ -179,36 +179,15 @@ static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::strin
         throw std::invalid_argument("failed to find RPC backend");
     }
 
-    using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint);
-    auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-    if (!ggml_backend_rpc_add_device_fn) {
-        throw std::invalid_argument("failed to find RPC device add function");
+    using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint);
+    auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
+    if (!ggml_backend_rpc_add_server_fn) {
+        throw std::invalid_argument("failed to find RPC add server function");
     }
-
-    static std::unordered_set<std::string> registered;
-    std::vector<ggml_backend_dev_t> devices;
     for (const auto & server : rpc_servers) {
-        ggml_backend_dev_t dev = nullptr;
-
-        std::string name = string_format("RPC[%s]", server.c_str());
-
-        if (registered.find(server) != registered.end()) {
-            dev = ggml_backend_dev_by_name(name.c_str());
-        }
-
-        if (!dev) {
-            dev = ggml_backend_rpc_add_device_fn(server.c_str());
-            if (!dev) {
-                throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str()));
-            }
-            ggml_backend_device_register(dev);
-            registered.insert(server);
-        }
-
-        devices.push_back(dev);
+        auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
+        ggml_backend_register(reg);
     }
-
-    return devices;
 }
 
 static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) {
@@ -714,7 +693,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                     break;
                 }
                 try {
-                    register_rpc_device_list(argv[i]);
+                    register_rpc_server_list(argv[i]);
                 } catch (const std::exception & e) {
                     fprintf(stderr, "error: %s\n", e.what());
                     invalid_param = true;
@@ -1368,12 +1347,19 @@ struct test {
 
     static std::string get_backend() {
         std::vector<std::string> backends;
+        bool rpc_used = false;
         for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
            auto * reg = ggml_backend_reg_get(i);
            std::string name = ggml_backend_reg_name(reg);
-            if (name != "CPU") {
+            if (name != "CPU" && !string_starts_with(name, "RPC")) {
                backends.push_back(ggml_backend_reg_name(reg));
            }
+            if (string_starts_with(name, "RPC[")) {
+                rpc_used = true;
+            }
+        }
+        if (rpc_used) {
+            backends.push_back("RPC");
         }
         return backends.empty() ? "CPU" : join(backends, ",");
     }
diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp
index dc8e077f34a73..dc2c2c1548589 100644
--- a/tools/rpc/rpc-server.cpp
+++ b/tools/rpc/rpc-server.cpp
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include <regex>
 
 namespace fs = std::filesystem;
@@ -131,24 +132,24 @@ static std::string fs_get_cache_directory() {
 }
 
 struct rpc_server_params {
-    std::string host        = "127.0.0.1";
-    int         port        = 50052;
-    size_t      backend_mem = 0;
-    bool        use_cache   = false;
-    int         n_threads   = std::max(1U, std::thread::hardware_concurrency()/2);
-    std::string device;
+    std::string host      = "127.0.0.1";
+    int         port      = 50052;
+    bool        use_cache = false;
+    int         n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
+    std::vector<std::string> devices;
+    std::vector<size_t>      dev_mem;
 };
 
 static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
     fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h, --help              show this help message and exit\n");
-    fprintf(stderr, "  -t, --threads           number of threads for the CPU backend (default: %d)\n", params.n_threads);
-    fprintf(stderr, "  -d DEV, --device        device to use\n");
-    fprintf(stderr, "  -H HOST, --host HOST    host to bind to (default: %s)\n", params.host.c_str());
-    fprintf(stderr, "  -p PORT, --port PORT    port to bind to (default: %d)\n", params.port);
-    fprintf(stderr, "  -m MEM, --mem MEM       backend memory size (in MB)\n");
-    fprintf(stderr, "  -c, --cache             enable local file cache\n");
+    fprintf(stderr, "  -h, --help              show this help message and exit\n");
+    fprintf(stderr, "  -t, --threads N         number of threads for the CPU device (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -d, --device            comma-separated list of devices\n");
+    fprintf(stderr, "  -H, --host HOST         host to bind to (default: %s)\n", params.host.c_str());
+    fprintf(stderr, "  -p, --port PORT         port to bind to (default: %d)\n", params.port);
+    fprintf(stderr, "  -m, --mem               memory size for each device (in MB)\n");
+    fprintf(stderr, "  -c, --cache             enable local file cache\n");
     fprintf(stderr, "\n");
 }
@@ -174,17 +175,17 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
             if (++i >= argc) {
                 return false;
             }
-            params.device = argv[i];
-            if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) {
-                fprintf(stderr, "error: unknown device: %s\n", params.device.c_str());
-                fprintf(stderr, "available devices:\n");
-                for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
-                    auto * dev = ggml_backend_dev_get(i);
-                    size_t free, total;
-                    ggml_backend_dev_memory(dev, &free, &total);
-                    printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
+            const std::regex regex{ R"([,/]+)" };
+            std::string dev_str = argv[i];
+            std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
+            std::sregex_token_iterator end;
+            for ( ; iter != end; ++iter) {
+                try {
+                    params.devices.push_back(*iter);
+                } catch (const std::exception & ) {
+                    fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
+                    return false;
                 }
-                return false;
             }
         } else if (arg == "-p" || arg == "--port") {
             if (++i >= argc) {
@@ -200,7 +201,19 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
             if (++i >= argc) {
                 return false;
             }
-            params.backend_mem = std::stoul(argv[i]) * 1024 * 1024;
+            const std::regex regex{ R"([,/]+)" };
+            std::string mem_str = argv[i];
+            std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
+            std::sregex_token_iterator end;
+            for ( ; iter != end; ++iter) {
+                try {
+                    size_t mem = std::stoul(*iter) * 1024 * 1024;
+                    params.dev_mem.push_back(mem);
+                } catch (const std::exception & ) {
+                    fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
+                    return false;
+                }
+            }
         } else if (arg == "-h" || arg == "--help") {
             print_usage(argc, argv, params);
             exit(0);
@@ -213,45 +226,46 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
     return true;
 }
 
-static ggml_backend_t create_backend(const rpc_server_params & params) {
-    ggml_backend_t backend = nullptr;
-
-    if (!params.device.empty()) {
-        ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str());
-        if (dev) {
-            backend = ggml_backend_dev_init(dev, nullptr);
-            if (!backend) {
-                fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str());
-                return nullptr;
+static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
+    std::vector<ggml_backend_dev_t> devices;
+    if (!params.devices.empty()) {
+        for (auto device : params.devices) {
+            ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
+            if (dev) {
+                devices.push_back(dev);
+            } else {
+                fprintf(stderr, "error: unknown device: %s\n", device.c_str());
+                fprintf(stderr, "available devices:\n");
+                for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+                    auto * dev = ggml_backend_dev_get(i);
+                    size_t free, total;
+                    ggml_backend_dev_memory(dev, &free, &total);
+                    printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
+                }
+                return {};
             }
         }
     }
 
-    if (!backend) {
-        backend = ggml_backend_init_best();
-    }
-
-    if (backend) {
-        fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend));
-
-        // set the number of threads
-        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
-        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
-        if (reg) {
-            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
-            if (ggml_backend_set_n_threads_fn) {
-                ggml_backend_set_n_threads_fn(backend, params.n_threads);
+    // Try non-CPU devices first
+    if (devices.empty()) {
+        for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
+                devices.push_back(dev);
             }
         }
     }
 
-    return backend;
-}
+    // If there are no accelerators, fallback to CPU device
+    if (devices.empty()) {
+        ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        if (dev) {
+            devices.push_back(dev);
+        }
+    }
 
-static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) {
-    ggml_backend_dev_t dev = ggml_backend_get_device(backend);
-    GGML_ASSERT(dev != nullptr);
-    ggml_backend_dev_memory(dev, free_mem, total_mem);
+    return devices;
 }
 
 int main(int argc, char * argv[]) {
@@ -273,18 +287,23 @@ int main(int argc, char * argv[]) {
         fprintf(stderr, "\n");
     }
 
-    ggml_backend_t backend = create_backend(params);
-    if (!backend) {
-        fprintf(stderr, "Failed to create backend\n");
+    auto devices = get_devices(params);
+    if (devices.empty()) {
+        fprintf(stderr, "No devices found\n");
         return 1;
     }
     std::string endpoint = params.host + ":" + std::to_string(params.port);
-    size_t free_mem, total_mem;
-    if (params.backend_mem > 0) {
-        free_mem = params.backend_mem;
-        total_mem = params.backend_mem;
-    } else {
-        get_backend_memory(backend, &free_mem, &total_mem);
+    std::vector<size_t> free_mem, total_mem;
+    for (size_t i = 0; i < devices.size(); i++) {
+        if (i < params.dev_mem.size()) {
+            free_mem.push_back(params.dev_mem[i]);
+            total_mem.push_back(params.dev_mem[i]);
+        } else {
+            size_t free, total;
+            ggml_backend_dev_memory(devices[i], &free, &total);
+            free_mem.push_back(free);
+            total_mem.push_back(total);
+        }
     }
     const char * cache_dir = nullptr;
     std::string cache_dir_str;
@@ -309,8 +328,7 @@ int main(int argc, char * argv[]) {
         return 1;
     }
 
-    start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
-
-    ggml_backend_free(backend);
+    start_server_fn(endpoint.c_str(), cache_dir, devices.size(), params.n_threads,
+                    devices.data(), free_mem.data(), total_mem.data());
     return 0;
 }
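
Usage note (illustrative, not part of the patch): the sketch below shows how a client is expected to use the new registry-based API, based only on the signatures introduced above. It resolves ggml_backend_rpc_add_server() the same way common/arg.cpp and llama-bench now do, registers the returned registry, and enumerates the remote devices. The endpoint address is a placeholder and error handling is kept minimal.

// Hypothetical client-side sketch for the multi-device RPC API (assumes the RPC backend is built in).
#include <cstdio>
#include "ggml-backend.h"
#include "ggml-rpc.h"

int main() {
    const char * endpoint = "192.168.1.10:50052";    // placeholder endpoint
    // one registry per server; it exposes one device per remote device (RPC0[...], RPC1[...], ...)
    ggml_backend_reg_t reg = ggml_backend_rpc_add_server(endpoint);
    if (reg == nullptr) {
        fprintf(stderr, "failed to add RPC server %s\n", endpoint);
        return 1;
    }
    ggml_backend_register(reg);    // now public API, moved out of ggml-backend-impl.h by this change
    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        size_t free, total;
        ggml_backend_dev_memory(dev, &free, &total);    // RPC_CMD_GET_DEVICE_MEMORY now carries the device id
        printf("%s: %zu MiB free / %zu MiB total\n",
               ggml_backend_dev_name(dev), free / (1024 * 1024), total / (1024 * 1024));
    }
    // a backend bound to one specific remote device
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint, /*device =*/ 0);
    ggml_backend_free(backend);
    return 0;
}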
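
Server-side counterpart (also illustrative only): with the new ggml_backend_rpc_start_server() signature the caller enumerates local devices and passes parallel free/total memory arrays, which is essentially what tools/rpc/rpc-server.cpp does after this change. The bind address and thread count below are placeholders.

// Hypothetical embedding of the multi-device RPC server (sketch; error handling omitted).
#include <vector>
#include "ggml-backend.h"
#include "ggml-rpc.h"

void serve_all_local_devices() {
    std::vector<ggml_backend_dev_t> devices;
    std::vector<size_t> free_mem, total_mem;
    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        size_t free, total;
        ggml_backend_dev_memory(dev, &free, &total);
        devices.push_back(dev);
        free_mem.push_back(free);
        total_mem.push_back(total);
    }
    // a single server instance now exposes every device; clients address them by index
    ggml_backend_rpc_start_server("0.0.0.0:50052", /*cache_dir =*/ nullptr,
                                  devices.size(), /*n_threads =*/ 8,
                                  devices.data(), free_mem.data(), total_mem.data());
}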