From 64ddda46d44d6255ccc13e2c9afa1802b4bdec42 Mon Sep 17 00:00:00 2001 From: Feng Ren Date: Thu, 2 Jan 2025 16:38:24 +0800 Subject: [PATCH] [TransferEngine] Refactor code to hide transport logics from user APIs (#51) Co-authored-by: doujiang24 --- doc/en/p2p-store.md | 2 +- doc/en/transfer-engine.md | 62 +- doc/zh/p2p-store.md | 2 +- doc/zh/transfer-engine.md | 72 +- mooncake-integration/vllm/vllm_adaptor.cpp | 51 +- mooncake-p2p-store/src/p2pstore/go.mod | 4 +- .../src/p2pstore/transfer_engine.go | 28 +- .../example/memory_pool.cpp | 12 +- .../example/transfer_engine_bench.cpp | 62 +- mooncake-transfer-engine/include/common.h | 4 +- mooncake-transfer-engine/include/error.h | 1 + .../include/multi_transport.h | 65 ++ mooncake-transfer-engine/include/topology.h | 100 ++- .../include/transfer_engine.h | 67 +- .../include/transfer_engine_c.h | 22 +- .../include/transfer_metadata.h | 39 +- .../include/transfer_metadata_plugin.h | 57 ++ .../transport/rdma_transport/rdma_transport.h | 11 +- .../transport/tcp_transport/tcp_transport.h | 8 +- .../include/transport/transport.h | 9 +- .../rust/src/transfer_engine.rs | 10 +- .../src/multi_transport.cpp | 182 ++++ mooncake-transfer-engine/src/topology.cpp | 182 +++- .../src/transfer_engine.cpp | 171 ++-- .../src/transfer_engine_c.cpp | 44 +- .../src/transfer_metadata.cpp | 799 +++--------------- .../src/transfer_metadata_plugin.cpp | 594 +++++++++++++ .../rdma_transport/endpoint_store.cpp | 4 +- .../transport/rdma_transport/rdma_context.cpp | 143 ++-- .../rdma_transport/rdma_endpoint.cpp | 79 +- .../rdma_transport/rdma_transport.cpp | 142 ++-- .../transport/rdma_transport/worker_pool.cpp | 40 +- .../transport/tcp_transport/tcp_transport.cpp | 44 +- .../tests/rdma_transport_test.cpp | 54 +- .../tests/rdma_transport_test2.cpp | 35 +- .../tests/tcp_transport_test.cpp | 110 ++- .../tests/topology_test.cpp | 12 +- 37 files changed, 1937 insertions(+), 1386 deletions(-) create mode 100644 mooncake-transfer-engine/include/multi_transport.h create mode 100644 mooncake-transfer-engine/include/transfer_metadata_plugin.h create mode 100644 mooncake-transfer-engine/src/multi_transport.cpp create mode 100644 mooncake-transfer-engine/src/transfer_metadata_plugin.cpp diff --git a/doc/en/p2p-store.md b/doc/en/p2p-store.md index ff2f8c7..934a70e 100644 --- a/doc/en/p2p-store.md +++ b/doc/en/p2p-store.md @@ -42,7 +42,7 @@ func NewP2PStore(metadataUri string, localSegmentName string, nicPriorityMatrix Creates an instance of `P2PStore`, which internally starts a Transfer Engine service. - `metadataUri`: The hostname or IP address of the metadata server/etcd service. - `localSegmentName`: The local server name (hostname/IP address:port), ensuring uniqueness within the cluster. -- `nicPriorityMatrix`: The network interface card priority order matrix, see the related description in the Transfer Engine API documentation (`TransferEngine::installOrGetTransport`). +- `nicPriorityMatrix`: The network interface card priority order matrix, see the related description in the Transfer Engine API documentation (`TransferEngine::installTransport`). - Return value: If successful, returns a pointer to the `P2PStore` instance, otherwise returns `error`. ```go diff --git a/doc/en/transfer-engine.md b/doc/en/transfer-engine.md index fa6f3fe..7b0f025 100644 --- a/doc/en/transfer-engine.md +++ b/doc/en/transfer-engine.md @@ -122,17 +122,18 @@ After successfully compiling Transfer Engine, the test program `transfer_engine_ # This is 10.0.0.2 export MC_GID_INDEX=n ./transfer_engine_bench --mode=target \ - --metadata_server=10.0.0.1:2379 \ + --metadata_server=etcd://10.0.0.1:2379 \ --local_server_name=10.0.0.2:12345 \ --device_name=erdma_0 ``` The meanings of the various parameters are as follows: - - The default value of the parameter corresponding to the environment variable `MC_GID_INDEX` is 0, which means that the Transfer Engine selects a GID that is most likely to be connected. - If the connection is hung, the user still needs to set the value of such a environment variable manually. + - The default value of the parameter corresponding to the environment variable `MC_GID_INDEX` is 0, which means that the Transfer Engine selects a GID that is most likely to be connected. Since this parameter depends on the specific network environment, the user has to set the value of the environment variable manually if the connection is hung. The environment variable `NCCL_IB_GID_INDEX` is equivalent to this function. - `--mode=target` indicates the start of the target node. The target node does not initiate read/write requests; it passively supplies or writes data as required by the initiator node. > Note: In actual applications, there is no need to distinguish between target nodes and initiator nodes; each node can freely initiate read/write requests to other nodes in the cluster. - - `--metadata_server` is the address of the metadata server (the full address of the etcd service). - > Change `--metadata_server` to `--metadata_server=http://10.0.0.1:8080/metadata` and add `--metadata_type=http` when using `http` as the `metadata` service. + - `--metadata_server` is the address of the metadata server. Its form is `[proto]://[hostname:port]`. For example, the following addresses are VALID: + - Use `etcd` as metadata storage: `"10.0.0.1:2379"`, `"etcd://10.0.0.1:2379"` or `"etcd://10.0.0.1:2379,10.0.0.2:2379"` + - Use `redis` as metadata storage: `"redis://10.0.0.1:6379"` + - Use `http` as metadata storage: `"http://10.0.0.1:8080/metadata"` - `--local_server_name` represents the address of this machine, which does not need to be set in most cases. If this option is not set, the value is equivalent to the hostname of this machine (i.e., `hostname(2)`). Other nodes in the cluster will use this address to attempt out-of-band communication with this node to establish RDMA connections. > Note: If out-of-band communication fails, the connection cannot be established. Therefore, if necessary, you need to modify the `/etc/hosts` file on all nodes in the cluster to locate the correct node through the hostname. - `--device_name` indicates the name of the RDMA network card used in the transfer process. @@ -168,11 +169,11 @@ Transfer Engine provides interfaces through the `TransferEngine` class (located ### Data Transfer -#### Transport::TransferRequest +#### TransferEngine::TransferRequest -The core API provided by Mooncake Transfer Engine is submitting a group of asynchronous `Transport::TransferRequest` tasks through the `Transport::submitTransfer` interface, and querying their status through the `Transport::getTransferStatus` interface. Each `Transport::TransferRequest` specifies reading or writing a continuous data space of `length` starting from the local starting address `source`, to the position starting at `target_offset` in the segment corresponding to `target_id`. +The core API provided by Mooncake Transfer Engine is submitting a group of asynchronous `TransferRequest` tasks through the `submitTransfer` interface, and querying their status through the `getTransferStatus` interface. Each `TransferRequest` specifies reading or writing a continuous data space of `length` starting from the local starting address `source`, to the position starting at `target_offset` in the segment corresponding to `target_id`. -The `Transport::TransferRequest` structure is defined as follows: +The `TransferRequest` structure is defined as follows: ```cpp using SegmentID = int32_t; @@ -194,7 +195,7 @@ struct TransferRequest - NVMeOF space type, where each file corresponds to a segment. In this case, the segment name passed to the `openSegment` interface is equivalent to the unique identifier of the file. `target_offset` is the offset of the target file. - `length` represents the amount of data transferred. TransferEngine may further split this into multiple read/write requests internally. -#### Transport::allocateBatchID +#### TransferEngine::allocateBatchID ```cpp BatchID allocateBatchID(size_t batch_size); @@ -205,7 +206,7 @@ Allocates a `BatchID`. A maximum of `batch_size` `TransferRequest`s can be submi - `batch_size`: The maximum number of `TransferRequest`s that can be submitted under the same `BatchID`; - Return value: If successful, returns `BatchID` (non-negative); otherwise, returns a negative value. -#### Transport::submitTransfer +#### TransferEngine::submitTransfer ```cpp int submitTransfer(BatchID batch_id, const std::vector &entries); @@ -217,7 +218,7 @@ Submits new `TransferRequest` tasks to `batch_id`. The task is asynchronously su - `entries`: Array of `TransferRequest`; - Return value: If successful, returns 0; otherwise, returns a negative value. -#### Transport::getTransferStatus +#### TransferEngine::getTransferStatus ```cpp enum TaskStatus @@ -244,7 +245,7 @@ Obtains the running status of the `TransferRequest` with `task_id` in `batch_id` - `status`: Output Transfer status; - Return value: If successful, returns 0; otherwise, returns a negative value. -#### Transport::freeBatchID +#### TransferEngine::freeBatchID ```cpp int freeBatchID(BatchID batch_id); @@ -258,9 +259,9 @@ Recycles `BatchID`, and subsequent operations on `submitTransfer` and `getTransf ### Multi-Transport Management The `TransferEngine` class internally manages multiple backend `Transport` classes, and users can load or unload `Transport` for different backends in `TransferEngine`. -#### TransferEngine::installOrGetTransport +#### TransferEngine::installTransport ```cpp -Transport* installOrGetTransport(const std::string& proto, void** args); +Transport* installTransport(const std::string& proto, void** args); ``` Registers `Transport` in `TransferEngine`. If a `Transport` for a certain protocol already exists, it returns that `Transport`. @@ -272,7 +273,7 @@ Registers `Transport` in `TransferEngine`. If a `Transport` for a certain protoc ##### TCP Transfer Mode For TCP transfer mode, there is no need to pass `args` objects when registering the `Transport` object. ```cpp -engine->installOrGetTransport("tcp", nullptr); +engine->installTransport("tcp", nullptr); ``` ##### RDMA Transfer Mode @@ -281,7 +282,7 @@ For RDMA transfer mode, the network card priority marrix must be specified throu void** args = (void**) malloc(2 * sizeof(void*)); args[0] = /* topology matrix */; args[1] = nullptr; -engine->installOrGetTransport("rdma", args); +engine->installTransport("rdma", args); ``` The network card priority marrix is a JSON string indicating the storage medium name and the list of network cards to be used preferentially, as shown in the example below: ```json @@ -302,7 +303,7 @@ For NVMeOF transfer mode, the file path must be specified through `args` during void** args = (void**) malloc(2 * sizeof(void*)); args[0] = /* topology matrix */; args[1] = nullptr; -engine->installOrGetTransport("nvmeof", args); +engine->installTransport("nvmeof", args); ``` #### TransferEngine::uninstallTransport @@ -328,7 +329,7 @@ Registers a space starting at address `addr` with a length of `size` on the loca - `addr`: The starting address of the registration space; - `size`: The length of the registration space; -- `location`: The `device` corresponding to this memory segment, such as `cuda:0` indicating the GPU device, `cpu:0` indicating the CPU socket, by matching with the network card priority order table (see `installOrGetTransport`), the preferred network card is identified. +- `location`: The `device` corresponding to this memory segment, such as `cuda:0` indicating the GPU device, `cpu:0` indicating the CPU socket, by matching with the network card priority order table (see `installTransport`), the preferred network card is identified. - `remote_accessible`: Indicates whether this memory can be accessed by remote nodes. - Return value: If successful, returns 0; otherwise, returns a negative value. @@ -442,20 +443,23 @@ For specific implementation, refer to the demo service implemented in Golang at ### Initialization +TransferEngine needs to initializing by calling the `init` method before further actions: ```cpp -TransferEngine(std::unique_ptr metadata_client); -TransferMetadata(const std::string &metadata_server, const std::string &protocol = "etcd"); -``` - -- Pointer to a `TransferMetadata` object, which abstracts the communication logic between the TransferEngine framework and the metadata server. We currently support `etcd`, `redis` and `http` protocols, while `metadata_server` represents the IP address or hostname of the etcd or redis server, or the base HTTP URI of http server. +TransferEngine(); -For easy exception handling, TransferEngine needs to call the init function for secondary construction after construction: -```cpp -int init(std::string& server_name, std::string& connectable_name, uint64_t rpc_port = 12345); +int init(const std::string &metadata_conn_string, + const std::string &local_server_name, + const std::string &ip_or_host_name, + uint64_t rpc_port = 12345); ``` - -- `server_name`: The local server name, ensuring uniqueness within the cluster. It also serves as the name of the RAM Segment that other nodes refer to the current instance (i.e., Segment Name). -- `connectable_name`: The name used for other clients to connect, which can be a hostname or IP address. +- `metadata_conn_string`: Connecting string of metadata storage servers, i.e., the IP address/hostname of `etcd`/`redis` or the URI of the http service. +The general form is `[proto]://[hostname:port]`. For example, the following metadata server addresses are legal: + - Using `etcd` as a metadata storage service: `“10.0.0.1:2379”` or `“etcd://10.0.0.1:2379”`. + - Using `redis` as a metadata storage service: `“redis://10.0.0.1:6379”` + - Using `http` as a metadata storage service: `“http://10.0.0.1:8080/metadata”` + +- `local_server_name`: The local server name, ensuring uniqueness within the cluster. It also serves as the name of the RAM Segment that other nodes refer to the current instance (i.e., Segment Name). +- `ip_or_host_name`: The name used for other clients to connect, which can be a hostname or IP address. - `rpc_port`: The rpc port used for interaction with other clients. - Return value: If successful, returns 0; if TransferEngine has already been init, returns -1. diff --git a/doc/zh/p2p-store.md b/doc/zh/p2p-store.md index c7b0f2c..ab0029d 100644 --- a/doc/zh/p2p-store.md +++ b/doc/zh/p2p-store.md @@ -53,7 +53,7 @@ func NewP2PStore(metadataUri string, localSegmentName string, nicPriorityMatrix 创建 P2PStore 实例,该实例内部会启动一个 Transfer Engine 服务。 - `metadataUri`:元数据服务器/etcd服务所在主机名或 IP 地址。 - `localSegmentName`:本地的服务器名称(主机名/IP地址:端口号),保证在集群内唯一。 -- `nicPriorityMatrix`:网卡优先级顺序表,参见位于 Transfer Engine API 文档的相关描述(`TransferEngine::installOrGetTransport`)。 +- `nicPriorityMatrix`:网卡优先级顺序表,参见位于 Transfer Engine API 文档的相关描述(`TransferEngine::installTransport`)。 - 返回值:若成功则返回 `P2PStore` 实例指针,否则返回 `error`。 ```go diff --git a/doc/zh/transfer-engine.md b/doc/zh/transfer-engine.md index edc805d..19e404b 100644 --- a/doc/zh/transfer-engine.md +++ b/doc/zh/transfer-engine.md @@ -11,7 +11,7 @@ Mooncake Transfer Engine 是一个围绕 Segment 和 BatchTransfer 两个核心 如上图所示,每个特定的客户端对应一个 TransferEngine,其中不仅包含一个 RAM Segment,还集成了对于多线程多网卡高速传输的管理。RAM Segment 原则上就对应这个 TransferEngine 的全部虚拟地址空间,但实际上仅仅会注册其中的部分区域(被称为一个 Buffer)供外部 (GPUDirect) RDMA Read/Write。每一段 Buffer 可以分别设置权限(对应 RDMA rkey 等)和网卡亲和性(比如基于拓扑优先从哪张卡读写等)。 -Mooncake Transfer Engine 通过 `TransferEngine` 类对外提供接口(位于 `mooncake-transfer-engine/include/transfer_engine.h`),其中对应不同后端的具体的数据传输功能由 `Transport` 类实现,目前支持 `TcpTransport`、`RdmaTransport` 和 `NVMeoFTransport`。 +Mooncake Transfer Engine 通过 `TransferEngine` 类对外提供接口(位于 `mooncake-transfer-engine/include/transfer_engine.h`),其中对应不同后端的具体的数据传输功能在内部由 `Transport` 类实现,包括`TcpTransport`、`RdmaTransport` 和 `NVMeoFTransport`。 ### Segment Segment 表示 Transfer Engine 实现数据传输过程期间可使用的源地址范围及目标地址范围集合。也就是说,所有 BatchTransfer 请求中涉及的本地与远程地址都需要位于合法的 Segment 区间里。Transfer Engine 支持以下两种类型的 Segment。 @@ -87,26 +87,26 @@ Transfer Engine 使用SIEVE算法来管理端点的逐出。如果由于链路 例如,可使用 `mooncake-transfer-engine/example/http-metadata-server` 示例中的 `http` 服务: ```bash # This is 10.0.0.1 - # cd mooncake-transfer-engine/example/http-metadata-server + cd mooncake-transfer-engine/example/http-metadata-server go run . --addr=:8080 ``` 2. **启动目标节点。** ```bash # This is 10.0.0.2 - export MC_GID_INDEX=n ./transfer_engine_bench --mode=target \ - --metadata_server=10.0.0.1:2379 \ + --metadata_server=etcd://10.0.0.1:2379 \ --local_server_name=10.0.0.2:12345 \ --device_name=erdma_0 ``` 各个参数的含义如下: - - 环境变量 `MC_GID_INDEX` 对应参数的默认值为 0,表示由 Transfer Engine 选取一个最可能连通的 GID。 - 如果连接被挂起,用户仍需手工设置改环境变量的值。 + - 环境变量 `MC_GID_INDEX` 对应参数的默认值为 0,表示由 Transfer Engine 选取一个最可能连通的 GID。由于该参数取决于具体的网络环境存在差异,如果连接被挂起,用户需手工设置环境变量的值。环境变量 `NCCL_IB_GID_INDEX` 与此功能等价。 - `--mode=target` 表示启动目标节点。目标节点不发起读写请求,只是被动按发起节点的要求供给或写入数据。 > 注意:实际应用中可不区分目标节点和发起节点,每个节点可以向集群内其他节点自由发起读写请求。 - - `--metadata_server` 为元数据服务器地址(etcd 服务的完整地址)。 - > 如果使用 `http` 作为 `metadata` 服务,需要将 `--metadata_server` 参数改为 `--metadata_server=http://10.0.0.1:8080/metadata`,并且指定 `--metadata_type=http`。 + - `--metadata_server` 为元数据服务器地址,一般形式是 `[proto]://[hostname:port]`。例如,下列元数据服务器地址是合法的: + - 使用 `etcd` 作为元数据存储服务:`"10.0.0.1:2379"` 或 `"etcd://10.0.0.1:2379"` 或 `"etcd://10.0.0.1:2379,10.0.0.2:2379"` + - 使用 `redis` 作为元数据存储服务:`"redis://10.0.0.1:6379"` + - 使用 `http` 作为元数据存储服务:`"http://10.0.0.1:8080/metadata"` - `--local_server_name` 表示本机器地址,大多数情况下无需设置。如果不设置该选项,则该值等同于本机的主机名(即 `hostname(2)` )。集群内的其它节点会使用此地址尝试与该节点进行带外通信,从而建立 RDMA 连接。 > 注意:若带外通信失败则连接无法建立。因此,若有必要需修改集群所有节点的 `/etc/hosts` 文件,使得可以通过主机名定位到正确的节点。 - `--device_name` 表示传输过程使用的 RDMA 网卡名称。 @@ -138,15 +138,15 @@ Transfer Engine 使用SIEVE算法来管理端点的逐出。如果由于链路 > 如果在执行期间发生异常,大多数情况是参数设置不正确所致,建议参考[故障排除文档](troubleshooting.md)先行排查。 ## C/C++ API -Transfer Engine 通过 `TransferEngine` 类对外提供接口(位于 `mooncake-transfer-engine/include/transfer_engine.h`),其中对应不同后端的具体的数据传输功能由 `Transport` 类实现,目前支持 `TcpTransport`,`RdmaTransport` 和 `NVMeoFTransport`。 +Transfer Engine 通过 `TransferEngine` 类统一对外提供接口(位于 `mooncake-transfer-engine/include/transfer_engine.h`),其中对应不同后端的具体的数据传输功能在内部由 `Transport` 类实现,目前支持 `TcpTransport`,`RdmaTransport` 和 `NVMeoFTransport`。 ### 数据传输 -#### Transport::TransferRequest +#### TransferEngine::TransferRequest -Mooncake Transfer Engine 提供的最核心 API 是:通过 `Transport::submitTransfer` 接口提交一组异步的 `Transport::TransferRequest` 任务,并通过 `Transport::getTransferStatus` 接口查询其状态。每个 `Transport::TransferRequest` 规定从本地的起始地址 `source` 开始,读取或写入长度为 `length` 的连续数据空间,到 `target_id` 对应的段、从 `target_offset` 开始的位置。 +Mooncake Transfer Engine 提供的最核心 API 是:通过 `submitTransfer()` 接口提交一组异步的由 `TransferRequest` 结构体表示的任务,并通过 `getTransferStatus()` 接口查询其状态。每个 `TransferRequest` 结构体规定从本地的起始地址 `source` 开始,读取或写入长度为 `length` 的连续数据空间,到 `target_id` 对应的段、从 `target_offset` 开始的位置。 -`Transport::TransferRequest` 结构体定义如下: +`TransferRequest` 结构体定义如下: ```cpp using SegmentID = int32_t; @@ -168,7 +168,7 @@ struct TransferRequest - NVMeOF 空间型,每个文件对应一个 Segment。此时 `openSegment` 接口传入的 Segment 名称等同于文件的唯一标识符。`target_offset` 为目标文件的偏移量。 - `length` 表示传输的数据量。TransferEngine 在内部可能会进一步拆分成多个读写请求。 -#### Transport::allocateBatchID +#### TransferEngine::allocateBatchID ```cpp BatchID allocateBatchID(size_t batch_size); @@ -179,7 +179,7 @@ BatchID allocateBatchID(size_t batch_size); - `batch_size`: 同一 `BatchID` 下最多可提交的 `TransferRequest` 数量; - 返回值:若成功,返回 `BatchID`(非负);否则返回负数值。 -#### Transport::submitTransfer +#### TransferEngine::submitTransfer ```cpp int submitTransfer(BatchID batch_id, const std::vector &entries); @@ -191,7 +191,7 @@ int submitTransfer(BatchID batch_id, const std::vector &entries - `entries`: `TransferRequest` 数组; - 返回值:若成功,返回 0;否则返回负数值。 -#### Transport::getTransferStatus +#### TransferEngine::getTransferStatus ```cpp enum TaskStatus @@ -218,7 +218,7 @@ int getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus &status) - `status`: 输出 Transfer 状态; - 返回值:若成功,返回 0;否则返回负数值。 -#### Transport::freeBatchID +#### TransferEngine::freeBatchID ```cpp int freeBatchID(BatchID batch_id); @@ -232,9 +232,9 @@ int freeBatchID(BatchID batch_id); ### 多 Transport 管理 `TransferEngine` 类内部管理多后端的 `Transport` 类,用户可向 `TransferEngine` 中装载或卸载对不同后端进行传输的 `Transport`。 -#### TransferEngine::installOrGetTransport +#### TransferEngine::installTransport ```cpp -Transport* installOrGetTransport(const std::string& proto, void** args); +Transport* installTransport(const std::string& proto, void** args); ``` 在 `TransferEngine` 中注册 `Transport`。如果某个协议对应的 `Transport` 已存在,则返回该 `Transport`。 @@ -245,7 +245,7 @@ Transport* installOrGetTransport(const std::string& proto, void** args); **TCP 传输模式:** 对于 TCP 传输模式,注册 `Transport` 期间不需要传入 `args` 对象。 ```cpp -engine->installOrGetTransport("tcp", nullptr); +engine->installTransport("tcp", nullptr); ``` **RDMA 传输模式:** @@ -254,7 +254,7 @@ engine->installOrGetTransport("tcp", nullptr); void** args = (void**) malloc(2 * sizeof(void*)); args[0] = /* topology matrix */; args[1] = nullptr; -engine->installOrGetTransport("rdma", args); +engine->installTransport("rdma", args); ``` 网卡优先级顺序是一个 JSON 字符串,表示使用的存储介质名称及优先使用的网卡列表,样例如下: ```json @@ -274,7 +274,7 @@ engine->installOrGetTransport("rdma", args); void** args = (void**) malloc(2 * sizeof(void*)); args[0] = /* topology matrix */; args[1] = nullptr; -engine->installOrGetTransport("nvmeof", args); +engine->installTransport("nvmeof", args); ``` #### TransferEngine::uinstallTransport @@ -299,7 +299,7 @@ int registerLocalMemory(void *addr, size_t size, string location, bool remote_ac - `addr`: 注册空间起始地址; - `size`:注册空间长度; -- `location`: 这一段内存对应的 `device`,比如 `cuda:0` 表示对应 GPU 设备,`cpu:0` 表示对应 CPU socket,通过和网卡优先级顺序表(见`installOrGetTransport`) 匹配,识别优选的网卡。 +- `location`: 这一段内存对应的 `device`,比如 `cuda:0` 表示对应 GPU 设备,`cpu:0` 表示对应 CPU socket,通过和网卡优先级顺序表(见`installTransport`) 匹配,识别优选的网卡。 - `remote_accessible`: 标识这一块内存能否被远端节点访问。 - 返回值:若成功,返回 0;否则返回负数值。 @@ -410,26 +410,28 @@ Value = { 具体实现,可以参考 [mooncake-transfer-engine/example/http-metadata-server](../../mooncake-transfer-engine/example/http-metadata-server) 用 Golang 实现的 demo 服务。 ### 构造函数与初始化 - +TransferEngine 在完成构造后需要调用 `init` 函数进行初始化: ```cpp -TransferEngine(std::unique_ptr metadata_client); -TransferMetadata(const std::string &metadata_server, const std::string &protocol = "etcd"); +TransferEngine(); + +int init(const std::string &metadata_conn_string, + const std::string &local_server_name, + const std::string &ip_or_host_name, + uint64_t rpc_port = 12345); ``` +- metadata_conn_string: 元数据存储服务连接字符串,表示 `etcd`/`redis` 的 IP 地址/主机名,或者 http 服务的 URI。一般形式是 `[proto]://[hostname:port]`。例如,下列元数据服务器地址是合法的: -- TransferMetadata 对象指针,该对象将 TransferEngine 框架与元数据服务器等带外通信逻辑抽取出来,以方便用户将其部署到不同的环境中。 - 目前支持 `etcd`,`redis` 和 `http` 三种元数据服务。`metadata_server` 表示 `etcd`/`redis` 的 IP 地址/主机名,或者 http 服务的 URI。 + - 使用 `etcd` 作为元数据存储服务:`"10.0.0.1:2379"` 或 `"etcd://10.0.0.1:2379"` + - 使用 `redis` 作为元数据存储服务:`"redis://10.0.0.1:6379"` + - 使用 `http` 作为元数据存储服务:`"http://10.0.0.1:8080/metadata"` -为了便于异常处理,TransferEngine 在完成构造后需要调用init函数进行二次构造: -```cpp -int init(std::string& server_name, std::string& connectable_name, uint64_t rpc_port = 12345); -``` -- server_name: 本地的 server name,保证在集群内唯一。它同时作为其他节点引用当前实例所属 RAM Segment 的名称(即 Segment Name) -- connectable_name:用于被其它 client 连接的 name,可为 hostname 或 ip 地址。 -- rpc_port:用于与其它 client 交互的 rpc 端口。- +- local_server_name: 本地的 server name,保证在集群内唯一。它同时作为其他节点引用当前实例所属 RAM Segment 的名称(即 Segment Name) +- ip_or_host_name: 用于被其它 client 连接的 name,可为 hostname 或 ip 地址。 +- rpc_port:当前进程占用于与其它 client 交互的 rpc 端口。 - 返回值:若成功,返回 0;若 TransferEngine 已被 init 过,返回 -1。 ```cpp - ~TransferEngine(); +~TransferEngine(); ``` 回收分配的所有类型资源,同时也会删除掉全局 meta data server 上的信息。 diff --git a/mooncake-integration/vllm/vllm_adaptor.cpp b/mooncake-integration/vllm/vllm_adaptor.cpp index 2f0d922..5334db0 100644 --- a/mooncake-integration/vllm/vllm_adaptor.cpp +++ b/mooncake-integration/vllm/vllm_adaptor.cpp @@ -46,27 +46,44 @@ std::string formatDeviceNames(const std::string &device_names) { return formatted; } +std::pair parseConnectionString( + const std::string &conn_string) { + std::pair result; + std::string proto = "etcd"; + std::string domain; + std::size_t pos = conn_string.find("://"); + + if (pos != std::string::npos) { + proto = conn_string.substr(0, pos); + domain = conn_string.substr(pos + 3); + } else { + domain = conn_string; + } + + result.first = proto; + result.second = domain; + return result; +} + int VLLMAdaptor::initialize(const char *local_hostname, const char *metadata_server, const char *protocol, const char *device_name) { - return initializeExt(local_hostname, metadata_server, protocol, device_name, - "etcd"); + auto conn_string = parseConnectionString(metadata_server); + return initializeExt(local_hostname, conn_string.second.c_str(), protocol, + device_name, conn_string.first.c_str()); } int VLLMAdaptor::initializeExt(const char *local_hostname, const char *metadata_server, const char *protocol, const char *device_name, const char *metadata_type) { - auto metadata_client = - std::make_shared(metadata_server, metadata_type); - if (!metadata_client) return -1; - - engine_ = std::make_unique(metadata_client); - if (!engine_) return -1; + auto conn_string = + std::string(metadata_type) + "://" + std::string(metadata_server); + engine_ = std::make_unique(); auto hostname_port = parseHostNameWithPort(local_hostname); - int ret = engine_->init(local_hostname, hostname_port.first.c_str(), - hostname_port.second); + int ret = engine_->init(conn_string, local_hostname, + hostname_port.first.c_str(), hostname_port.second); if (ret) return -1; xport_ = nullptr; @@ -77,9 +94,9 @@ int VLLMAdaptor::initializeExt(const char *local_hostname, void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; - xport_ = engine_->installOrGetTransport("rdma", args); + xport_ = engine_->installTransport("rdma", args); } else if (strcmp(protocol, "tcp") == 0) { - xport_ = engine_->installOrGetTransport("tcp", nullptr); + xport_ = engine_->installTransport("tcp", nullptr); } else { LOG(ERROR) << "Unsupported protocol"; return -1; @@ -171,7 +188,7 @@ int VLLMAdaptor::transferSync(const char *target_hostname, uintptr_t buffer, handle_map_[target_hostname] = handle; } - auto batch_id = xport_->allocateBatchID(1); + auto batch_id = engine_->allocateBatchID(1); TransferRequest entry; entry.opcode = TransferRequest::READ; entry.length = length; @@ -179,18 +196,18 @@ int VLLMAdaptor::transferSync(const char *target_hostname, uintptr_t buffer, entry.target_id = handle; entry.target_offset = peer_buffer_address; - int ret = xport_->submitTransfer(batch_id, {entry}); + int ret = engine_->submitTransfer(batch_id, {entry}); if (ret < 0) return -1; TransferStatus status; while (true) { - int ret = xport_->getTransferStatus(batch_id, 0, status); + int ret = engine_->getTransferStatus(batch_id, 0, status); LOG_ASSERT(!ret); if (status.s == TransferStatusEnum::COMPLETED) { - xport_->freeBatchID(batch_id); + engine_->freeBatchID(batch_id); return 0; } else if (status.s == TransferStatusEnum::FAILED) { - xport_->freeBatchID(batch_id); + engine_->freeBatchID(batch_id); return -1; } } diff --git a/mooncake-p2p-store/src/p2pstore/go.mod b/mooncake-p2p-store/src/p2pstore/go.mod index 545d359..1e80382 100644 --- a/mooncake-p2p-store/src/p2pstore/go.mod +++ b/mooncake-p2p-store/src/p2pstore/go.mod @@ -1,6 +1,8 @@ module github.com/kvcache-ai/Mooncake/mooncake-p2p-store/src/p2pstore -go 1.20 +go 1.21 + +toolchain go1.23.4 require go.etcd.io/etcd/client/v3 v3.5.15 diff --git a/mooncake-p2p-store/src/p2pstore/transfer_engine.go b/mooncake-p2p-store/src/p2pstore/transfer_engine.go index 1235e97..5776f24 100644 --- a/mooncake-p2p-store/src/p2pstore/transfer_engine.go +++ b/mooncake-p2p-store/src/p2pstore/transfer_engine.go @@ -46,27 +46,21 @@ func parseServerName(serverName string) (host string, port int) { } func NewTransferEngine(metadata_uri string, local_server_name string, nic_priority_matrix string) (*TransferEngine, error) { - native_engine := C.createTransferEngine(C.CString(metadata_uri)) - if native_engine == nil { - return nil, ErrTransferEngine - } - // For simplifiy, local_server_name must be a valid IP address or hostname connectable_name, rpc_port := parseServerName(local_server_name) - ret := C.initTransferEngine(native_engine, - C.CString(local_server_name), - C.CString(connectable_name), - C.uint64_t(rpc_port)) - - if ret < 0 { - C.destroyTransferEngine(native_engine) + + native_engine := C.createTransferEngine(C.CString(metadata_uri), + C.CString(local_server_name), + C.CString(connectable_name), + C.uint64_t(rpc_port)) + if native_engine == nil { return nil, ErrTransferEngine } var args [2]unsafe.Pointer args[0] = unsafe.Pointer(C.CString(nic_priority_matrix)) args[1] = nil - xport := C.installOrGetTransport(native_engine, C.CString("rdma"), &args[0]) + xport := C.installTransport(native_engine, C.CString("rdma"), &args[0]) if xport == nil { C.destroyTransferEngine(native_engine) return nil, ErrTransferEngine @@ -105,7 +99,7 @@ func (engine *TransferEngine) unregisterLocalMemory(addr uintptr) error { } func (engine *TransferEngine) allocateBatchID(batchSize int) (BatchID, error) { - ret := C.allocateBatchID(engine.xport, C.size_t(batchSize)) + ret := C.allocateBatchID(engine.engine, C.size_t(batchSize)) if ret == C.UINT64_MAX { return BatchID(-1), ErrTransferEngine } @@ -144,7 +138,7 @@ func (engine *TransferEngine) submitTransfer(batchID BatchID, requests []Transfe } } - ret := C.submitTransfer(engine.xport, C.batch_id_t(batchID), &requestSlice[0], C.size_t(len(requests))) + ret := C.submitTransfer(engine.engine, C.batch_id_t(batchID), &requestSlice[0], C.size_t(len(requests))) if ret < 0 { return ErrTransferEngine } @@ -153,7 +147,7 @@ func (engine *TransferEngine) submitTransfer(batchID BatchID, requests []Transfe func (engine *TransferEngine) getTransferStatus(batchID BatchID, taskID int) (int, uint64, error) { var status C.transfer_status_t - ret := C.getTransferStatus(engine.xport, C.batch_id_t(batchID), C.size_t(taskID), &status) + ret := C.getTransferStatus(engine.engine, C.batch_id_t(batchID), C.size_t(taskID), &status) if ret < 0 { return -1, 0, ErrTransferEngine } @@ -161,7 +155,7 @@ func (engine *TransferEngine) getTransferStatus(batchID BatchID, taskID int) (in } func (engine *TransferEngine) freeBatchID(batchID BatchID) error { - ret := C.freeBatchID(engine.xport, C.batch_id_t(batchID)) + ret := C.freeBatchID(engine.engine, C.batch_id_t(batchID)) if ret < 0 { return ErrTransferEngine } diff --git a/mooncake-transfer-engine/example/memory_pool.cpp b/mooncake-transfer-engine/example/memory_pool.cpp index f0806aa..2a39ea6 100644 --- a/mooncake-transfer-engine/example/memory_pool.cpp +++ b/mooncake-transfer-engine/example/memory_pool.cpp @@ -73,23 +73,19 @@ std::string loadNicPriorityMatrix() { } int target() { - auto metadata_client = - std::make_shared(FLAGS_metadata_server); - LOG_ASSERT(metadata_client); - auto nic_priority_matrix = loadNicPriorityMatrix(); const size_t dram_buffer_size = 1ull << 30; - auto engine = std::make_unique(metadata_client); + auto engine = std::make_unique(); void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; const std::string &connectable_name = FLAGS_local_server_name; - engine->init(FLAGS_local_server_name.c_str(), connectable_name.c_str(), - 12345); - engine->installOrGetTransport("rdma", args); + engine->init(FLAGS_metadata_server, FLAGS_local_server_name.c_str(), + connectable_name.c_str(), 12345); + engine->installTransport("rdma", args); LOG_ASSERT(engine); diff --git a/mooncake-transfer-engine/example/transfer_engine_bench.cpp b/mooncake-transfer-engine/example/transfer_engine_bench.cpp index e31cf90..b3e6322 100644 --- a/mooncake-transfer-engine/example/transfer_engine_bench.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_bench.cpp @@ -54,7 +54,6 @@ DEFINE_string(mode, "initiator", DEFINE_string(operation, "read", "Operation type: read or write"); DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|tcp"); -DEFINE_string(metadata_type, "etcd", "Metadata type: etcd|redis|http"); DEFINE_string(device_name, "mlx5_2", "Device name to use, valid if protocol=rdma"); @@ -107,7 +106,8 @@ static void freeMemoryPool(void *addr, size_t size) { if (attributes.type == cudaMemoryTypeDevice) { cudaFree(addr); - } else if (attributes.type == cudaMemoryTypeHost || attributes.type == cudaMemoryTypeUnregistered) { + } else if (attributes.type == cudaMemoryTypeHost || + attributes.type == cudaMemoryTypeUnregistered) { numa_free(addr, size); } else { LOG(ERROR) << "Unknown memory type, " << addr << " " << attributes.type; @@ -120,7 +120,7 @@ static void freeMemoryPool(void *addr, size_t size) { volatile bool running = true; std::atomic total_batch_count(0); -int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, +int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, void *addr) { bindToSocket(thread_id % NR_SOCKETS); TransferRequest::OpCode opcode; @@ -133,7 +133,7 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, exit(EXIT_FAILURE); } - auto segment_desc = xport->meta()->getSegmentDescByID(segment_id); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); if (!segment_desc) { LOG(ERROR) << "Unable to get target segment ID, please recheck"; exit(EXIT_FAILURE); @@ -143,7 +143,7 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, size_t batch_count = 0; while (running) { - auto batch_id = xport->allocateBatchID(FLAGS_batch_size); + auto batch_id = engine->allocateBatchID(FLAGS_batch_size); int ret = 0; std::vector requests; for (int i = 0; i < FLAGS_batch_size; ++i) { @@ -159,24 +159,25 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, requests.emplace_back(entry); } - ret = xport->submitTransfer(batch_id, requests); + ret = engine->submitTransfer(batch_id, requests); LOG_ASSERT(!ret); for (int task_id = 0; task_id < FLAGS_batch_size; ++task_id) { bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, task_id, status); + int ret = engine->getTransferStatus(batch_id, task_id, status); LOG_ASSERT(!ret); if (status.s == TransferStatusEnum::COMPLETED) completed = true; else if (status.s == TransferStatusEnum::FAILED) { LOG(INFO) << "FAILED"; completed = true; + exit(EXIT_FAILURE); } } } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); LOG_ASSERT(!ret); batch_count++; } @@ -185,7 +186,7 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, return 0; } -std::string formatDeviceNames(const std::string& device_names) { +std::string formatDeviceNames(const std::string &device_names) { std::stringstream ss(device_names); std::string item; std::vector tokens; @@ -215,22 +216,22 @@ std::string loadNicPriorityMatrix() { } // Build JSON Data auto device_names = formatDeviceNames(FLAGS_device_name); - return "{\"cpu:0\": [[" + device_names + "], []], " - " \"cpu:1\": [[" + device_names + "], []], " - " \"gpu:0\": [[" + device_names + "], []]}"; + return "{\"cpu:0\": [[" + device_names + + "], []], " + " \"cpu:1\": [[" + + device_names + + "], []], " + " \"gpu:0\": [[" + + device_names + "], []]}"; } int initiator() { - auto metadata_client = - std::make_shared(FLAGS_metadata_server, FLAGS_metadata_type); - LOG_ASSERT(metadata_client); - const size_t ram_buffer_size = 1ull << 30; - auto engine = std::make_unique(metadata_client); + auto engine = std::make_unique(); auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - engine->init(FLAGS_local_server_name.c_str(), hostname_port.first.c_str(), - hostname_port.second); + engine->init(FLAGS_metadata_server, FLAGS_local_server_name.c_str(), + hostname_port.first.c_str(), hostname_port.second); Transport *xport = nullptr; if (FLAGS_protocol == "rdma") { @@ -238,9 +239,9 @@ int initiator() { void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; - xport = engine->installOrGetTransport("rdma", args); + xport = engine->installTransport("rdma", args); } else if (FLAGS_protocol == "tcp") { - xport = engine->installOrGetTransport("tcp", nullptr); + xport = engine->installTransport("tcp", nullptr); } else { LOG(ERROR) << "Unsupported protocol"; } @@ -252,8 +253,7 @@ int initiator() { #ifdef USE_CUDA buffer_num = FLAGS_use_vram ? 1 : NR_SOCKETS; - if (FLAGS_use_vram) - LOG(INFO) << "VRAM is used"; + if (FLAGS_use_vram) LOG(INFO) << "VRAM is used"; for (int i = 0; i < buffer_num; ++i) { addr[i] = allocateMemoryPool(ram_buffer_size, i, FLAGS_use_vram); std::string name_prefix = FLAGS_use_vram ? "gpu:" : "cpu:"; @@ -278,7 +278,7 @@ int initiator() { gettimeofday(&start_tv, nullptr); for (int i = 0; i < FLAGS_threads; ++i) - workers[i] = std::thread(initiatorWorker, xport, segment_id, i, + workers[i] = std::thread(initiatorWorker, engine.get(), segment_id, i, addr[i % buffer_num]); sleep(FLAGS_duration); @@ -306,25 +306,21 @@ int initiator() { } int target() { - auto metadata_client = - std::make_shared(FLAGS_metadata_server, FLAGS_metadata_type); - LOG_ASSERT(metadata_client); - const size_t ram_buffer_size = 1ull << 30; - auto engine = std::make_unique(metadata_client); + auto engine = std::make_unique(); auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - engine->init(FLAGS_local_server_name.c_str(), hostname_port.first.c_str(), - hostname_port.second); + engine->init(FLAGS_metadata_server, FLAGS_local_server_name.c_str(), + hostname_port.first.c_str(), hostname_port.second); if (FLAGS_protocol == "rdma") { auto nic_priority_matrix = loadNicPriorityMatrix(); void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; - engine->installOrGetTransport("rdma", args); + engine->installTransport("rdma", args); } else if (FLAGS_protocol == "tcp") { - engine->installOrGetTransport("tcp", nullptr); + engine->installTransport("tcp", nullptr); } else { LOG(ERROR) << "Unsupported protocol"; } diff --git a/mooncake-transfer-engine/include/common.h b/mooncake-transfer-engine/include/common.h index dfbc90e..471229d 100644 --- a/mooncake-transfer-engine/include/common.h +++ b/mooncake-transfer-engine/include/common.h @@ -91,8 +91,8 @@ static inline std::pair parseHostNameWithPort( auto port_str = server_name.substr(pos + 1); int val = std::atoi(port_str.c_str()); if (val <= 0 || val > 65535) - PLOG(WARNING) << "Illegal port number in " << server_name - << ". Use default port " << port << " instead"; + LOG(WARNING) << "Illegal port number in " << server_name + << ". Use default port " << port << " instead"; else port = (uint16_t)val; return std::make_pair(trimmed_server_name, port); diff --git a/mooncake-transfer-engine/include/error.h b/mooncake-transfer-engine/include/error.h index e1978e5..af1e501 100644 --- a/mooncake-transfer-engine/include/error.h +++ b/mooncake-transfer-engine/include/error.h @@ -34,5 +34,6 @@ #define ERR_NUMA (-300) #define ERR_CLOCK (-301) #define ERR_MEMORY (-302) +#define ERR_NOT_IMPLEMENTED (-303) #endif // ERROR_H \ No newline at end of file diff --git a/mooncake-transfer-engine/include/multi_transport.h b/mooncake-transfer-engine/include/multi_transport.h new file mode 100644 index 0000000..42266da --- /dev/null +++ b/mooncake-transfer-engine/include/multi_transport.h @@ -0,0 +1,65 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MULTI_TRANSPORT_H_ +#define MULTI_TRANSPORT_H_ + +#include + +#include "transport/transport.h" + +namespace mooncake { +class MultiTransport { + public: + using BatchID = Transport::BatchID; + using TransferRequest = Transport::TransferRequest; + using TransferStatus = Transport::TransferStatus; + using BatchDesc = Transport::BatchDesc; + + const static BatchID INVALID_BATCH_ID = Transport::INVALID_BATCH_ID; + + MultiTransport(std::shared_ptr metadata, + std::string &local_server_name); + + ~MultiTransport(); + + BatchID allocateBatchID(size_t batch_size); + + int freeBatchID(BatchID batch_id); + + int submitTransfer(BatchID batch_id, + const std::vector &entries); + + int getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status); + + Transport *installTransport(const std::string &proto, void **args); + + Transport *getTransport(const std::string &proto); + + std::vector listTransports(); + + private: + Transport *selectTransport(const TransferRequest &entry); + + private: + std::shared_ptr metadata_; + std::string local_server_name_; + std::map transport_map_; + RWSpinlock batch_desc_lock_; + std::unordered_map> batch_desc_set_; +}; +} // namespace mooncake + +#endif // MULTI_TRANSPORT_H_ \ No newline at end of file diff --git a/mooncake-transfer-engine/include/topology.h b/mooncake-transfer-engine/include/topology.h index e31ba50..60c7cf6 100644 --- a/mooncake-transfer-engine/include/topology.h +++ b/mooncake-transfer-engine/include/topology.h @@ -1,5 +1,101 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOPOLOGY_H +#define TOPOLOGY_H + +#include +#include +#include + +#include +#include +#include +#include +#include #include +#include +#include + +#include "common.h" namespace mooncake { -std::string discoverTopologyMatrix(); -} +struct TopologyEntry { + std::string name; + std::vector preferred_hca; + std::vector avail_hca; + + Json::Value toJson() const { + Json::Value matrix(Json::arrayValue); + Json::Value hca_list(Json::arrayValue); + for (auto &hca : preferred_hca) { + hca_list.append(hca); + } + matrix.append(hca_list); + hca_list.clear(); + for (auto &hca : avail_hca) { + hca_list.append(hca); + } + matrix.append(hca_list); + return matrix; + } +}; + +using TopologyMatrix = + std::unordered_map; + +class Topology { + public: + Topology(); + + ~Topology(); + + bool empty() const; + + void clear(); + + int discover(); + + int parse(const std::string &topology_json); + + int disableDevice(const std::string &device_name); + + std::string toString() const; + + Json::Value toJson() const; + + int selectDevice(const std::string storage_type, int retry_count = 0); + + TopologyMatrix getMatrix() const { return matrix_; } + + const std::vector &getHcaList() const { return hca_list_; } + + private: + int resolve(); + + private: + TopologyMatrix matrix_; + std::vector hca_list_; + + struct ResolvedTopologyEntry { + std::vector preferred_hca; + std::vector avail_hca; + }; + std::unordered_map + resolved_matrix_; +}; + +} // namespace mooncake + +#endif // TOPOLOGY_H \ No newline at end of file diff --git a/mooncake-transfer-engine/include/transfer_engine.h b/mooncake-transfer-engine/include/transfer_engine.h index 6677fd4..4abfdd9 100644 --- a/mooncake-transfer-engine/include/transfer_engine.h +++ b/mooncake-transfer-engine/include/transfer_engine.h @@ -28,28 +28,39 @@ #include #include +#include "multi_transport.h" #include "transfer_metadata.h" #include "transport/transport.h" namespace mooncake { +using TransferRequest = Transport::TransferRequest; +using TransferStatus = Transport::TransferStatus; +using TransferStatusEnum = Transport::TransferStatusEnum; +using SegmentHandle = Transport::SegmentHandle; +using SegmentID = Transport::SegmentID; +using BatchID = Transport::BatchID; +using BufferEntry = Transport::BufferEntry; + class TransferEngine { public: - TransferEngine(std::shared_ptr meta) : metadata_(meta) {} + TransferEngine() : metadata_(nullptr) {} ~TransferEngine() { freeEngine(); } - int init(const char *server_name, const char *connectable_name, + int init(const std::string &metadata_conn_string, + const std::string &local_server_name, + const std::string &ip_or_host_name, uint64_t rpc_port = 12345); int freeEngine(); - Transport *installOrGetTransport(const char *proto, void **args); + Transport *installTransport(const std::string &proto, void **args); - int uninstallTransport(const char *proto); + int uninstallTransport(const std::string &proto); - Transport::SegmentHandle openSegment(const char *segment_name); + SegmentHandle openSegment(const std::string &segment_name); - int closeSegment(Transport::SegmentHandle seg_id); + int closeSegment(SegmentHandle handle); int registerLocalMemory(void *addr, size_t length, const std::string &location, @@ -58,38 +69,48 @@ class TransferEngine { int unregisterLocalMemory(void *addr, bool update_metadata = true); - int registerLocalMemoryBatch( - const std::vector &buffer_list, - const std::string &location); + int registerLocalMemoryBatch(const std::vector &buffer_list, + const std::string &location); int unregisterLocalMemoryBatch(const std::vector &addr_list); + BatchID allocateBatchID(size_t batch_size) { + return multi_transports_->allocateBatchID(batch_size); + } + + int freeBatchID(BatchID batch_id) { + return multi_transports_->freeBatchID(batch_id); + } + + int submitTransfer(BatchID batch_id, + const std::vector &entries) { + return multi_transports_->submitTransfer(batch_id, entries); + } + + int getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) { + return multi_transports_->getTransferStatus(batch_id, task_id, status); + } + int syncSegmentCache() { return metadata_->syncSegmentCache(); } + std::shared_ptr getMetadata() { return metadata_; } + + bool checkOverlap(void *addr, uint64_t length); + private: struct MemoryRegion { void *addr; uint64_t length; - const char *location; + std::string location; bool remote_accessible; }; - Transport *findName(const char *name, size_t n = SIZE_MAX); - - Transport *initTransport(const char *proto); - - std::vector installed_transports_; - std::string local_server_name_; std::shared_ptr metadata_; + std::string local_server_name_; + std::shared_ptr multi_transports_; std::vector local_memory_regions_; }; - -using TransferRequest = Transport::TransferRequest; -using TransferStatus = Transport::TransferStatus; -using TransferStatusEnum = Transport::TransferStatusEnum; -using SegmentID = Transport::SegmentID; -using BatchID = Transport::BatchID; -using BufferEntry = Transport::BufferEntry; } // namespace mooncake #endif \ No newline at end of file diff --git a/mooncake-transfer-engine/include/transfer_engine_c.h b/mooncake-transfer-engine/include/transfer_engine_c.h index 62fb3c1..d669037 100644 --- a/mooncake-transfer-engine/include/transfer_engine_c.h +++ b/mooncake-transfer-engine/include/transfer_engine_c.h @@ -85,13 +85,13 @@ typedef struct segment_desc segment_desc_t; typedef void *transfer_engine_t; typedef void *transport_t; -transfer_engine_t createTransferEngine(const char *metadata_uri); +transfer_engine_t createTransferEngine(const char *metadata_conn_string, + const char *local_server_name, + const char *ip_or_host_name, + uint64_t rpc_port); -int initTransferEngine(transfer_engine_t engine, const char *local_server_name, - const char *connectable_name, uint64_t rpc_port); - -transport_t installOrGetTransport(transfer_engine_t engine, const char *proto, - void **args); +transport_t installTransport(transfer_engine_t engine, const char *proto, + void **args); int uninstallTransport(transfer_engine_t engine, const char *proto); @@ -113,15 +113,15 @@ int registerLocalMemoryBatch(transfer_engine_t engine, int unregisterLocalMemoryBatch(transfer_engine_t engine, void **addr_list, size_t addr_len); -batch_id_t allocateBatchID(transport_t xport, size_t batch_size); +batch_id_t allocateBatchID(transfer_engine_t engine, size_t batch_size); -int submitTransfer(transport_t xport, batch_id_t batch_id, +int submitTransfer(transfer_engine_t engine, batch_id_t batch_id, struct transfer_request *entries, size_t count); -int getTransferStatus(transport_t xport, batch_id_t batch_id, size_t task_id, - struct transfer_status *status); +int getTransferStatus(transfer_engine_t engine, batch_id_t batch_id, + size_t task_id, struct transfer_status *status); -int freeBatchID(transport_t xport, batch_id_t batch_id); +int freeBatchID(transfer_engine_t engine, batch_id_t batch_id); int syncSegmentCache(transfer_engine_t engine); diff --git a/mooncake-transfer-engine/include/transfer_metadata.h b/mooncake-transfer-engine/include/transfer_metadata.h index 059a204..19c6c32 100644 --- a/mooncake-transfer-engine/include/transfer_metadata.h +++ b/mooncake-transfer-engine/include/transfer_metadata.h @@ -29,9 +29,11 @@ #include #include "common.h" +#include "topology.h" namespace mooncake { -struct TransferMetadataImpl; +struct MetadataStoragePlugin; +struct HandShakePlugin; class TransferMetadata { public: @@ -55,14 +57,6 @@ class TransferMetadata { std::unordered_map local_path_map; }; - struct PriorityItem { - std::vector preferred_rnic_list; - std::vector available_rnic_list; - std::vector preferred_rnic_id_list; - std::vector available_rnic_id_list; - }; - - using PriorityMatrix = std::unordered_map; using SegmentID = uint64_t; struct SegmentDesc { @@ -70,7 +64,7 @@ class TransferMetadata { std::string protocol; // this is for rdma std::vector devices; - PriorityMatrix priority_matrix; + Topology topology; std::vector buffers; // this is for nvmeof. std::vector nvmeof_buffers; @@ -90,7 +84,7 @@ class TransferMetadata { }; public: - TransferMetadata(const std::string &metadata_uri, const std::string &protocol = "etcd"); + TransferMetadata(const std::string &conn_string); ~TransferMetadata(); @@ -105,7 +99,8 @@ class TransferMetadata { int updateSegmentDesc(const std::string &segment_name, const SegmentDesc &desc); - std::shared_ptr getSegmentDesc(const std::string &segment_name); + std::shared_ptr getSegmentDesc( + const std::string &segment_name); SegmentID getSegmentID(const std::string &segment_name); @@ -120,7 +115,7 @@ class TransferMetadata { int addLocalSegment(SegmentID segment_id, const std::string &segment_name, std::shared_ptr &&desc); - + int addRpcMetaEntry(const std::string &server_name, RpcMetaDesc &desc); int removeRpcMetaEntry(const std::string &server_name); @@ -138,22 +133,7 @@ class TransferMetadata { const HandShakeDesc &local_desc, HandShakeDesc &peer_desc); - static int parseNicPriorityMatrix(const std::string &nic_priority_matrix, - PriorityMatrix &priority_map, - std::vector &rnic_list); - - private: - int doSendHandshake(struct addrinfo *addr, const HandShakeDesc &local_desc, - HandShakeDesc &peer_desc); - - std::string encode(const HandShakeDesc &desc); - - int decode(const std::string &ser, HandShakeDesc &desc); - private: - std::atomic listener_running_; - std::thread listener_; - OnReceiveHandShake on_receive_handshake_; // local cache RWSpinlock segment_lock_; std::unordered_map> @@ -166,7 +146,8 @@ class TransferMetadata { std::atomic next_segment_id_; - std::shared_ptr impl_; + std::shared_ptr handshake_plugin_; + std::shared_ptr storage_plugin_; }; } // namespace mooncake diff --git a/mooncake-transfer-engine/include/transfer_metadata_plugin.h b/mooncake-transfer-engine/include/transfer_metadata_plugin.h new file mode 100644 index 0000000..53fa1d5 --- /dev/null +++ b/mooncake-transfer-engine/include/transfer_metadata_plugin.h @@ -0,0 +1,57 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TRANSFER_METADATA_PLUGIN +#define TRANSFER_METADATA_PLUGIN + +#include "transfer_metadata.h" + +namespace mooncake { +struct MetadataStoragePlugin { + static std::shared_ptr Create( + const std::string &conn_string); + + MetadataStoragePlugin() {} + virtual ~MetadataStoragePlugin() {} + + virtual bool get(const std::string &key, Json::Value &value) = 0; + virtual bool set(const std::string &key, const Json::Value &value) = 0; + virtual bool remove(const std::string &key) = 0; +}; + +struct HandShakePlugin { + static std::shared_ptr Create( + const std::string &conn_string); + + HandShakePlugin() {} + virtual ~HandShakePlugin() {} + + // When accept a new connection, this function will be called. + // The first param represents peer endpoint's attributes, while + // the second param represents local endpoint's attributes + using OnReceiveCallBack = + std::function; + + virtual int startDaemon(OnReceiveCallBack on_recv_callback, + uint16_t listen_port) = 0; + + // Connect to peer endpoint, and wait for receiving + // peer endpoint's attributes + virtual int send(std::string ip_or_host_name, uint16_t rpc_port, + const Json::Value &local, Json::Value &peer) = 0; +}; + +} // namespace mooncake + +#endif // TRANSFER_METADATA_PLUGIN \ No newline at end of file diff --git a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h index 87cd829..8b798cc 100644 --- a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h +++ b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h @@ -27,6 +27,7 @@ #include #include +#include "topology.h" #include "transfer_metadata.h" #include "transport/transport.h" @@ -74,6 +75,10 @@ class RdmaTransport : public Transport { int submitTransfer(BatchID batch_id, const std::vector &entries) override; + int submitTransferTask( + const std::vector &request_list, + const std::vector &task_list) override; + int getTransferStatus(BatchID batch_id, std::vector &status); @@ -83,8 +88,7 @@ class RdmaTransport : public Transport { SegmentID getSegmentID(const std::string &segment_name); private: - int allocateLocalSegmentID( - TransferMetadata::PriorityMatrix &priority_matrix); + int allocateLocalSegmentID(); public: int onSetupRdmaConnections(const HandShakeDesc &peer_desc, @@ -107,10 +111,9 @@ class RdmaTransport : public Transport { int &buffer_id, int &device_id, int retry_cnt = 0); private: - std::vector device_name_list_; std::vector> context_list_; - std::unordered_map device_name_to_index_map_; std::atomic next_segment_id_; + Topology local_topology_; }; using TransferRequest = Transport::TransferRequest; diff --git a/mooncake-transfer-engine/include/transport/tcp_transport/tcp_transport.h b/mooncake-transfer-engine/include/transport/tcp_transport/tcp_transport.h index 69ef92f..62aba16 100644 --- a/mooncake-transfer-engine/include/transport/tcp_transport/tcp_transport.h +++ b/mooncake-transfer-engine/include/transport/tcp_transport/tcp_transport.h @@ -49,6 +49,9 @@ class TcpTransport : public Transport { int submitTransfer(BatchID batch_id, const std::vector &entries) override; + int submitTransferTask(const std::vector &request_list, + const std::vector &task_list) override; + int getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus &status) override; @@ -58,8 +61,9 @@ class TcpTransport : public Transport { int allocateLocalSegmentID(); - int registerLocalMemory(void *addr, size_t length, const std::string &location, - bool remote_accessible, bool update_metadata); + int registerLocalMemory(void *addr, size_t length, + const std::string &location, bool remote_accessible, + bool update_metadata); int unregisterLocalMemory(void *addr, bool update_metadata = false); diff --git a/mooncake-transfer-engine/include/transport/transport.h b/mooncake-transfer-engine/include/transport/transport.h index 0224f25..c59ee40 100644 --- a/mooncake-transfer-engine/include/transport/transport.h +++ b/mooncake-transfer-engine/include/transport/transport.h @@ -35,6 +35,7 @@ class TransferMetadata; /// failure. class Transport { friend class TransferEngine; + friend class MultiTransport; public: using SegmentID = uint64_t; @@ -163,6 +164,12 @@ class Transport { virtual int submitTransfer(BatchID batch_id, const std::vector &entries) = 0; + virtual int submitTransferTask( + const std::vector &request_list, + const std::vector &task_list) { + return ERR_NOT_IMPLEMENTED; + } + /// @brief Get the status of a submitted transfer. This function shall not /// be called again after completion. /// @return Return 1 on completed (either success or failure); 0 if still in @@ -207,4 +214,4 @@ class Transport { }; } // namespace mooncake -#endif // TRANSPORT_H_ \ No newline at end of file +#endif // TRANSPORT_H_ \ No newline at end of file diff --git a/mooncake-transfer-engine/rust/src/transfer_engine.rs b/mooncake-transfer-engine/rust/src/transfer_engine.rs index 5fde333..c9466fb 100644 --- a/mooncake-transfer-engine/rust/src/transfer_engine.rs +++ b/mooncake-transfer-engine/rust/src/transfer_engine.rs @@ -105,7 +105,7 @@ impl TransferEngine { std::ptr::null_mut(), ]; let xport = - unsafe { bindings::installOrGetTransport(engine, proto_c.as_ptr(), args.as_mut_ptr()) }; + unsafe { bindings::installTransport(engine, proto_c.as_ptr(), args.as_mut_ptr()) }; if xport.is_null() { unsafe { @@ -197,7 +197,7 @@ impl TransferEngine { } pub fn allocate_batch_id(&self, batch_size: usize) -> Result { - let ret = unsafe { bindings::allocateBatchID(self.xport, batch_size) }; + let ret = unsafe { bindings::allocateBatchID(self.engine, batch_size) }; if ret == u64::MAX { bail!("Failed to allocate batch ID") } else { @@ -221,7 +221,7 @@ impl TransferEngine { }) } let ret = unsafe { - bindings::submitTransfer(self.xport, batch_id, requests_c.as_mut_ptr(), requests.len()) + bindings::submitTransfer(self.engine, batch_id, requests_c.as_mut_ptr(), requests.len()) }; if ret < 0 { bail!("Failed to submit transfer") @@ -236,7 +236,7 @@ impl TransferEngine { transferred_bytes: 0, }; let ret = - unsafe { bindings::getTransferStatus(self.xport, batch_id, task_id as usize, &mut status) }; + unsafe { bindings::getTransferStatus(self.engine, batch_id, task_id as usize, &mut status) }; if ret < 0 { bail!("Failed to get transfer status") } else { @@ -245,7 +245,7 @@ impl TransferEngine { } pub fn free_batch_id(&self, batch_id: BatchID) -> Result<()> { - let ret = unsafe { bindings::freeBatchID(self.xport, batch_id) }; + let ret = unsafe { bindings::freeBatchID(self.engine, batch_id) }; if ret < 0 { bail!("Failed to free batch ID") } else { diff --git a/mooncake-transfer-engine/src/multi_transport.cpp b/mooncake-transfer-engine/src/multi_transport.cpp new file mode 100644 index 0000000..524db1a --- /dev/null +++ b/mooncake-transfer-engine/src/multi_transport.cpp @@ -0,0 +1,182 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "multi_transport.h" + +#include "transport/rdma_transport/rdma_transport.h" +#include "transport/tcp_transport/tcp_transport.h" +#include "transport/transport.h" +#ifdef USE_CUDA +#include "transport/nvmeof_transport/nvmeof_transport.h" +#endif + +namespace mooncake { +MultiTransport::MultiTransport(std::shared_ptr metadata, + std::string &local_server_name) + : metadata_(metadata), local_server_name_(local_server_name) { + // ... +} + +MultiTransport::~MultiTransport() { + // ... +} + +MultiTransport::BatchID MultiTransport::allocateBatchID(size_t batch_size) { + auto batch_desc = new BatchDesc(); + if (!batch_desc) return ERR_MEMORY; + batch_desc->id = BatchID(batch_desc); + batch_desc->batch_size = batch_size; + batch_desc->task_list.reserve(batch_size); + batch_desc->context = NULL; +#ifdef CONFIG_USE_BATCH_DESC_SET + batch_desc_lock_.lock(); + batch_desc_set_[batch_desc->id] = batch_desc; + batch_desc_lock_.unlock(); +#endif + return batch_desc->id; +} + +int MultiTransport::freeBatchID(BatchID batch_id) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + const size_t task_count = batch_desc.task_list.size(); + for (size_t task_id = 0; task_id < task_count; task_id++) { + if (!batch_desc.task_list[task_id].is_finished) { + LOG(ERROR) << "BatchID cannot be freed until all tasks are done"; + return ERR_BATCH_BUSY; + } + } + delete &batch_desc; +#ifdef CONFIG_USE_BATCH_DESC_SET + RWSpinlock::WriteGuard guard(batch_desc_lock_); + batch_desc_set_.erase(batch_id); +#endif + return 0; +} + +int MultiTransport::submitTransfer( + BatchID batch_id, const std::vector &entries) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { + LOG(ERROR) << "MultiTransport: Exceed the limitation of batch capacity"; + return ERR_TOO_MANY_REQUESTS; + } + + size_t task_id = batch_desc.task_list.size(); + batch_desc.task_list.resize(task_id + entries.size()); + struct SubmitTasks { + std::vector request_list; + std::vector task_list; + }; + std::unordered_map submit_tasks; + for (auto &request : entries) { + auto transport = selectTransport(request); + if (!transport) return ERR_INVALID_ARGUMENT; + auto &task = batch_desc.task_list[task_id]; + ++task_id; + submit_tasks[transport].request_list.push_back( + (TransferRequest *)&request); + submit_tasks[transport].task_list.push_back(&task); + } + for (auto &entry : submit_tasks) { + int ret = entry.first->submitTransferTask(entry.second.request_list, + entry.second.task_list); + if (ret) { + LOG(ERROR) << "MultiTransport: Failed to submit transfer task to " + << entry.first->getName(); + return ret; + } + } + return 0; +} + +int MultiTransport::getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + const size_t task_count = batch_desc.task_list.size(); + if (task_id >= task_count) return ERR_INVALID_ARGUMENT; + auto &task = batch_desc.task_list[task_id]; + status.transferred_bytes = task.transferred_bytes; + uint64_t success_slice_count = task.success_slice_count; + uint64_t failed_slice_count = task.failed_slice_count; + if (success_slice_count + failed_slice_count == + (uint64_t)task.slices.size()) { + if (failed_slice_count) { + status.s = Transport::TransferStatusEnum::FAILED; + } else { + status.s = Transport::TransferStatusEnum::COMPLETED; + } + task.is_finished = true; + } else { + status.s = Transport::TransferStatusEnum::WAITING; + } + return 0; +} + +Transport *MultiTransport::installTransport(const std::string &proto, + void **args) { + Transport *transport = nullptr; + if (std::string(proto) == "rdma") { + transport = new RdmaTransport(); + } else if (std::string(proto) == "tcp") { + transport = new TcpTransport(); + } +#ifdef USE_CUDA + else if (std::string(proto) == "nvmeof") { + transport = new NVMeoFTransport(); + } +#endif + + if (!transport) { + LOG(ERROR) << "MultiTransport: Failed to initialize transport " + << proto; + return nullptr; + } + + if (transport->install(local_server_name_, metadata_, args)) { + return nullptr; + } + + transport_map_[proto] = transport; + return transport; +} + +Transport *MultiTransport::selectTransport(const TransferRequest &entry) { + if (entry.target_id == LOCAL_SEGMENT_ID && transport_map_.count("local")) + return transport_map_["local"]; + auto target_segment_desc = metadata_->getSegmentDescByID(entry.target_id); + if (!target_segment_desc) { + LOG(ERROR) << "MultiTransport: Incorrect target segment id " + << entry.target_id; + return nullptr; + } + auto proto = target_segment_desc->protocol; + if (!transport_map_.count(proto)) { + LOG(ERROR) << "MultiTransport: Transport " << proto << " not installed"; + return nullptr; + } + return transport_map_[proto]; +} + +Transport *MultiTransport::getTransport(const std::string &proto) { + if (!transport_map_.count(proto)) return nullptr; + return transport_map_[proto]; +} + +std::vector MultiTransport::listTransports() { + std::vector transport_list; + for (auto &entry : transport_map_) transport_list.push_back(entry.second); + return transport_list; +} + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/topology.cpp b/mooncake-transfer-engine/src/topology.cpp index 3827689..41f5627 100644 --- a/mooncake-transfer-engine/src/topology.cpp +++ b/mooncake-transfer-engine/src/topology.cpp @@ -1,9 +1,24 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include #include #include #include #include +#include #include #include #include @@ -21,40 +36,20 @@ #include "topology.h" +namespace mooncake { struct InfinibandDevice { std::string name; std::string pci_bus_id; int numa_node; }; -struct TopologyEntry { - std::string name; - std::vector preferred_hca; - std::vector avail_hca; - - Json::Value to_json() { - Json::Value matrix(Json::arrayValue); - Json::Value hca_list(Json::arrayValue); - for (auto &hca : preferred_hca) { - hca_list.append(hca); - } - matrix.append(hca_list); - hca_list.clear(); - for (auto &hca : avail_hca) { - hca_list.append(hca); - } - matrix.append(hca_list); - return matrix; - } -}; - -static std::vector list_infiniband_devices() { +static std::vector listInfiniBandDevices() { DIR *dir = opendir("/sys/class/infiniband"); struct dirent *entry; std::vector devices; if (dir == NULL) { - LOG(WARNING) << "failed to list /sys/class/infiniband"; + LOG(WARNING) << "Failed to list /sys/class/infiniband"; return {}; } while ((entry = readdir(dir))) { @@ -64,7 +59,7 @@ static std::vector list_infiniband_devices() { std::string device_name = entry->d_name; - char path[PATH_MAX]; + char path[PATH_MAX + 32]; char resolved_path[PATH_MAX]; // Get the PCI bus id for the infiniband device. Note that // "/sys/class/infiniband/mlx5_X/" is a symlink to @@ -72,7 +67,7 @@ static std::vector list_infiniband_devices() { snprintf(path, sizeof(path), "/sys/class/infiniband/%s/../..", entry->d_name); if (realpath(path, resolved_path) == NULL) { - LOG(ERROR) << "realpath: " << strerror(errno); + PLOG(ERROR) << "Failed to parse realpath"; continue; } std::string pci_bus_id = basename(resolved_path); @@ -89,14 +84,14 @@ static std::vector list_infiniband_devices() { return devices; } -static std::vector discover_cpu_topology( +static std::vector discoverCpuTopology( const std::vector &all_hca) { DIR *dir = opendir("/sys/devices/system/node"); struct dirent *entry; std::vector topology; if (dir == NULL) { - LOG(WARNING) << "failed to list /sys/devices/system/node"; + LOG(WARNING) << "Failed to list /sys/devices/system/node"; return {}; } while ((entry = readdir(dir))) { @@ -127,7 +122,7 @@ static std::vector discover_cpu_topology( #ifdef USE_CUDA -static int get_pci_distance(const char *bus1, const char *bus2) { +static int getPciDistance(const char *bus1, const char *bus2) { char buf[PATH_MAX]; char path1[PATH_MAX]; char path2[PATH_MAX]; @@ -157,7 +152,7 @@ static int get_pci_distance(const char *bus1, const char *bus2) { return distance; } -static std::vector discover_cuda_topology( +static std::vector discoverCudaTopology( const std::vector &all_hca) { std::vector topology; int device_count; @@ -170,15 +165,14 @@ static std::vector discover_cuda_topology( cudaSuccess) { continue; } - for (char *ch = pci_bus_id; (*ch = tolower(*ch)); ch++) - ; + for (char *ch = pci_bus_id; (*ch = tolower(*ch)); ch++); std::vector preferred_hca; std::vector avail_hca; for (const auto &hca : all_hca) { // FIXME: currently we only identify the NICs connected to the same // PCIe switch/RC with GPU as preferred. - if (get_pci_distance(hca.pci_bus_id.c_str(), pci_bus_id) == 0) { + if (getPciDistance(hca.pci_bus_id.c_str(), pci_bus_id) == 0) { preferred_hca.push_back(hca.name); } else { avail_hca.push_back(hca.name); @@ -194,19 +188,125 @@ static std::vector discover_cuda_topology( #endif // USE_CUDA -namespace mooncake { -// TODO: add black/white lists for devices. -std::string discoverTopologyMatrix() { - auto all_hca = list_infiniband_devices(); - Json::Value value(Json::objectValue); - for (auto &ent : discover_cpu_topology(all_hca)) { - value[ent.name] = ent.to_json(); +Topology::Topology() {} + +Topology::~Topology() {} + +bool Topology::empty() const { return matrix_.empty(); } + +void Topology::clear() { + matrix_.clear(); + hca_list_.clear(); + resolved_matrix_.clear(); +} + +int Topology::discover() { + matrix_.clear(); + auto all_hca = listInfiniBandDevices(); + for (auto &ent : discoverCpuTopology(all_hca)) { + matrix_[ent.name] = ent; } #ifdef USE_CUDA - for (auto &ent : discover_cuda_topology(all_hca)) { - value[ent.name] = ent.to_json(); + for (auto &ent : discoverCudaTopology(all_hca)) { + matrix_[ent.name] = ent; } #endif + return resolve(); +} + +int Topology::parse(const std::string &topology_json) { + std::set rnic_set; + Json::Value root; + Json::Reader reader; + + if (topology_json.empty() || !reader.parse(topology_json, root)) { + LOG(ERROR) << "Topology: malformed json format"; + return ERR_MALFORMED_JSON; + } + + matrix_.clear(); + for (const auto &key : root.getMemberNames()) { + const Json::Value &value = root[key]; + if (value.isArray() && value.size() == 2) { + TopologyEntry topo_entry; + topo_entry.name = key; + for (const auto &array : value[0]) { + auto device_name = array.asString(); + topo_entry.preferred_hca.push_back(device_name); + } + for (const auto &array : value[1]) { + auto device_name = array.asString(); + topo_entry.avail_hca.push_back(device_name); + } + matrix_[key] = topo_entry; + } else { + LOG(ERROR) << "Topology: malformed json format"; + return ERR_MALFORMED_JSON; + } + } + + return resolve(); +} + +std::string Topology::toString() const { + Json::Value value(Json::objectValue); + for (auto &entry : matrix_) { + value[entry.first] = entry.second.toJson(); + } return value.toStyledString(); } + +Json::Value Topology::toJson() const { + Json::Value root; + Json::Reader reader; + reader.parse(toString(), root); + return root; +} + +int Topology::selectDevice(const std::string storage_type, int retry_count) { + if (!resolved_matrix_.count(storage_type)) return ERR_DEVICE_NOT_FOUND; + auto &entry = resolved_matrix_[storage_type]; + if (retry_count == 0) { + int rand_value = SimpleRandom::Get().next(); + if (!entry.preferred_hca.empty()) + return entry.preferred_hca[rand_value % entry.preferred_hca.size()]; + else + return entry.avail_hca[rand_value % entry.avail_hca.size()]; + } else { + size_t index = (retry_count - 1) % + (entry.preferred_hca.size() + entry.avail_hca.size()); + if (index < entry.preferred_hca.size()) + return entry.preferred_hca[index]; + else { + index -= entry.preferred_hca.size(); + return entry.avail_hca[index]; + } + } + return 0; +} + +int Topology::resolve() { + std::map hca_id_map; + int next_hca_map_index = 0; + for (auto &entry : matrix_) { + for (auto &hca : entry.second.preferred_hca) { + if (!hca_id_map.count(hca)) { + hca_list_.push_back(hca); + hca_id_map[hca] = next_hca_map_index; + next_hca_map_index++; + } + resolved_matrix_[entry.first].preferred_hca.push_back( + hca_id_map[hca]); + } + for (auto &hca : entry.second.avail_hca) { + if (!hca_id_map.count(hca)) { + hca_list_.push_back(hca); + hca_id_map[hca] = next_hca_map_index; + next_hca_map_index++; + } + resolved_matrix_[entry.first].avail_hca.push_back(hca_id_map[hca]); + } + } + return 0; +} } // namespace mooncake diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 8b88351..6b3768b 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -14,115 +14,100 @@ #include "transfer_engine.h" -#include "transport/rdma_transport/rdma_transport.h" -#include "transport/tcp_transport/tcp_transport.h" #include "transport/transport.h" -#ifdef USE_CUDA -#include "transport/nvmeof_transport/nvmeof_transport.h" -#endif namespace mooncake { -int TransferEngine::init(const char *server_name, const char *connectable_name, +int TransferEngine::init(const std::string &metadata_conn_string, + const std::string &local_server_name, + const std::string &ip_or_host_name, uint64_t rpc_port) { - local_server_name_ = server_name; - assert(metadata_); + local_server_name_ = local_server_name; + metadata_ = std::make_shared(metadata_conn_string); + multi_transports_ = + std::make_shared(metadata_, local_server_name_); TransferMetadata::RpcMetaDesc desc; - desc.ip_or_host_name = connectable_name; + desc.ip_or_host_name = ip_or_host_name; desc.rpc_port = rpc_port; - return metadata_->addRpcMetaEntry(server_name, desc); + return metadata_->addRpcMetaEntry(local_server_name_, desc); } int TransferEngine::freeEngine() { - while (!installed_transports_.empty()) { - auto proto = installed_transports_.back()->getName(); - if (uninstallTransport(proto) < 0) - LOG(ERROR) << "Failed to uninstall transport " << proto; + if (metadata_) { + metadata_->removeRpcMetaEntry(local_server_name_); + metadata_.reset(); } - metadata_->removeRpcMetaEntry(local_server_name_); return 0; } -Transport *TransferEngine::installOrGetTransport(const char *proto, - void **args) { - Transport *xport = initTransport(proto); - if (!xport) { - LOG(ERROR) << "Failed to initialize transport " << proto; - return nullptr; +Transport *TransferEngine::installTransport(const std::string &proto, + void **args) { + Transport *transport = multi_transports_->getTransport(proto); + if (transport) { + LOG(INFO) << "Transport " << proto << " already installed"; + return transport; } - - if (xport->install(local_server_name_, metadata_, args) < 0) goto fail; - - installed_transports_.emplace_back(xport); - for (const auto &mem : local_memory_regions_) - if (xport->registerLocalMemory(mem.addr, mem.length, mem.location, - mem.remote_accessible) < 0) - goto fail; - - return xport; - -fail: - delete xport; - return NULL; -} - -int TransferEngine::uninstallTransport(const char *proto) { - for (auto it = installed_transports_.begin(); - it != installed_transports_.end(); ++it) { - if (strcmp((*it)->getName(), proto) == 0) { - delete *it; - installed_transports_.erase(it); - return 0; - } + transport = multi_transports_->installTransport(proto, args); + if (!transport) return nullptr; + for (auto &entry : local_memory_regions_) { + int ret = transport->registerLocalMemory( + entry.addr, entry.length, entry.location, entry.remote_accessible); + if (ret < 0) return nullptr; } - return ERR_INVALID_ARGUMENT; + return transport; } -Transport::SegmentHandle TransferEngine::openSegment(const char *segment_name) { - if (!segment_name) - return ERR_INVALID_ARGUMENT; +int TransferEngine::uninstallTransport(const std::string &proto) { return 0; } + +Transport::SegmentHandle TransferEngine::openSegment( + const std::string &segment_name) { + if (segment_name.empty()) return ERR_INVALID_ARGUMENT; std::string trimmed_segment_name = segment_name; while (!trimmed_segment_name.empty() && trimmed_segment_name[0] == '/') trimmed_segment_name.erase(0, 1); - if (trimmed_segment_name.empty()) - return ERR_INVALID_ARGUMENT; + if (trimmed_segment_name.empty()) return ERR_INVALID_ARGUMENT; return metadata_->getSegmentID(trimmed_segment_name); } -int TransferEngine::closeSegment(Transport::SegmentHandle seg_id) { - // Not used - return 0; +int TransferEngine::closeSegment(Transport::SegmentHandle handle) { return 0; } + +bool TransferEngine::checkOverlap(void *addr, uint64_t length) { + for (auto &local_memory_region : local_memory_regions_) { + if (overlap(addr, length, local_memory_region.addr, + local_memory_region.length)) { + return true; + } + } + return false; } int TransferEngine::registerLocalMemory(void *addr, size_t length, const std::string &location, bool remote_accessible, bool update_metadata) { - for (auto &local_memory_region : local_memory_regions_) { - if (overlap(addr, length, local_memory_region.addr, - local_memory_region.length)) { - LOG(ERROR) << "Memory region overlap"; - return ERR_ADDRESS_OVERLAPPED; - } + if (checkOverlap(addr, length)) { + LOG(ERROR) + << "Transfer Engine does not support overlapped memory region"; + return ERR_ADDRESS_OVERLAPPED; } - for (auto &xport : installed_transports_) { - int ret = xport->registerLocalMemory( + for (auto transport : multi_transports_->listTransports()) { + int ret = transport->registerLocalMemory( addr, length, location, remote_accessible, update_metadata); if (ret < 0) return ret; } local_memory_regions_.push_back( - {addr, length, location.c_str(), remote_accessible}); + {addr, length, location, remote_accessible}); return 0; } int TransferEngine::unregisterLocalMemory(void *addr, bool update_metadata) { + for (auto &transport : multi_transports_->listTransports()) { + int ret = transport->unregisterLocalMemory(addr, update_metadata); + if (ret) return ret; + } for (auto it = local_memory_regions_.begin(); it != local_memory_regions_.end(); ++it) { if (it->addr == addr) { - for (auto &xport : installed_transports_) { - if (xport->unregisterLocalMemory(addr, update_metadata) < 0) - return ERR_MEMORY; - } local_memory_regions_.erase(it); break; } @@ -132,43 +117,39 @@ int TransferEngine::unregisterLocalMemory(void *addr, bool update_metadata) { int TransferEngine::registerLocalMemoryBatch( const std::vector &buffer_list, const std::string &location) { - for (auto &xport : installed_transports_) { - int ret = xport->registerLocalMemoryBatch(buffer_list, location); + for (auto &buffer : buffer_list) { + if (checkOverlap(buffer.addr, buffer.length)) { + LOG(ERROR) + << "Transfer Engine does not support overlapped memory region"; + return ERR_ADDRESS_OVERLAPPED; + } + } + for (auto transport : multi_transports_->listTransports()) { + int ret = transport->registerLocalMemoryBatch(buffer_list, location); if (ret < 0) return ret; } + for (auto &buffer : buffer_list) { + local_memory_regions_.push_back( + {buffer.addr, buffer.length, location, true}); + } return 0; } int TransferEngine::unregisterLocalMemoryBatch( const std::vector &addr_list) { - for (auto &xport : installed_transports_) { - int ret = xport->unregisterLocalMemoryBatch(addr_list); + for (auto transport : multi_transports_->listTransports()) { + int ret = transport->unregisterLocalMemoryBatch(addr_list); if (ret < 0) return ret; } - return 0; -} - -Transport *TransferEngine::findName(const char *name, size_t n) { - for (const auto &xport : installed_transports_) { - if (strncmp(xport->getName(), name, n) == 0) return xport; - } - return nullptr; -} - -Transport *TransferEngine::initTransport(const char *proto) { - if (std::string(proto) == "rdma") { - return new RdmaTransport(); - } else if (std::string(proto) == "tcp") { - return new TcpTransport(); - } -#ifdef USE_CUDA - else if (std::string(proto) == "nvmeof") { - return new NVMeoFTransport(); - } -#endif - else { - LOG(ERROR) << "Unsupported Transport Protocol: " << proto; - return NULL; + for (auto &addr : addr_list) { + for (auto it = local_memory_regions_.begin(); + it != local_memory_regions_.end(); ++it) { + if (it->addr == addr) { + local_memory_regions_.erase(it); + break; + } + } } + return 0; } } // namespace mooncake diff --git a/mooncake-transfer-engine/src/transfer_engine_c.cpp b/mooncake-transfer-engine/src/transfer_engine_c.cpp index e22cd31..95f9920 100644 --- a/mooncake-transfer-engine/src/transfer_engine_c.cpp +++ b/mooncake-transfer-engine/src/transfer_engine_c.cpp @@ -22,22 +22,24 @@ using namespace mooncake; -transfer_engine_t createTransferEngine(const char *metadata_uri) { - auto metadata_client = std::make_shared(metadata_uri); - TransferEngine *native = new TransferEngine(metadata_client); +transfer_engine_t createTransferEngine(const char *metadata_conn_string, + const char *local_server_name, + const char *ip_or_host_name, + uint64_t rpc_port) { + TransferEngine *native = new TransferEngine(); + int ret = native->init(metadata_conn_string, local_server_name, + ip_or_host_name, rpc_port); + if (ret) { + delete native; + return nullptr; + } return (transfer_engine_t)native; } -int initTransferEngine(transfer_engine_t engine, const char *local_server_name, - const char *connectable_name, uint64_t rpc_port) { - TransferEngine *native = (TransferEngine *)engine; - return native->init(local_server_name, connectable_name, rpc_port); -} - -transport_t installOrGetTransport(transfer_engine_t engine, const char *proto, - void **args) { +transport_t installTransport(transfer_engine_t engine, const char *proto, + void **args) { TransferEngine *native = (TransferEngine *)engine; - return (transport_t)native->installOrGetTransport(proto, args); + return (transport_t)native->installTransport(proto, args); } int uninstallTransport(transfer_engine_t engine, const char *proto) { @@ -95,14 +97,14 @@ int unregisterLocalMemoryBatch(transfer_engine_t engine, void **addr_list, return native->unregisterLocalMemoryBatch(native_addr_list); } -batch_id_t allocateBatchID(transport_t xport, size_t batch_size) { - Transport *native = (Transport *)xport; +batch_id_t allocateBatchID(transfer_engine_t engine, size_t batch_size) { + TransferEngine *native = (TransferEngine *)engine; return (batch_id_t)native->allocateBatchID(batch_size); } -int submitTransfer(transport_t xport, batch_id_t batch_id, +int submitTransfer(transfer_engine_t engine, batch_id_t batch_id, struct transfer_request *entries, size_t count) { - Transport *native = (Transport *)xport; + TransferEngine *native = (TransferEngine *)engine; std::vector native_entries; native_entries.resize(count); for (size_t index = 0; index < count; index++) { @@ -116,9 +118,9 @@ int submitTransfer(transport_t xport, batch_id_t batch_id, return native->submitTransfer((Transport::BatchID)batch_id, native_entries); } -int getTransferStatus(transport_t xport, batch_id_t batch_id, size_t task_id, - struct transfer_status *status) { - Transport *native = (Transport *)xport; +int getTransferStatus(transfer_engine_t engine, batch_id_t batch_id, + size_t task_id, struct transfer_status *status) { + TransferEngine *native = (TransferEngine *)engine; Transport::TransferStatus native_status; int rc = native->getTransferStatus((Transport::BatchID)batch_id, task_id, native_status); @@ -129,8 +131,8 @@ int getTransferStatus(transport_t xport, batch_id_t batch_id, size_t task_id, return rc; } -int freeBatchID(transport_t xport, batch_id_t batch_id) { - Transport *native = (Transport *)xport; +int freeBatchID(transfer_engine_t engine, batch_id_t batch_id) { + TransferEngine *native = (TransferEngine *)engine; return native->freeBatchID(batch_id); } diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index 0b5b968..eafe88f 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -14,19 +14,7 @@ #include "transfer_metadata.h" -#include -#include #include -#include -#include - -#ifdef USE_REDIS -#include -#endif - -#ifdef USE_HTTP -#include -#endif #include #include @@ -34,331 +22,62 @@ #include "common.h" #include "config.h" #include "error.h" +#include "transfer_metadata_plugin.h" namespace mooncake { -const static std::string kRpcMetaPrefix = "mooncake/rpc_meta/"; +const static std::string kCommonKeyPrefix = "mooncake/"; +const static std::string kRpcMetaPrefix = kCommonKeyPrefix + "rpc_meta/"; + // mooncake/segments/[...] static inline std::string getFullMetadataKey(const std::string &segment_name) { - const static std::string keyPrefix = "mooncake/"; auto pos = segment_name.find("/"); if (pos == segment_name.npos) - return keyPrefix + "ram/" + segment_name; + return kCommonKeyPrefix + "ram/" + segment_name; else - return keyPrefix + segment_name; + return kCommonKeyPrefix + segment_name; } -struct TransferMetadataImpl { - TransferMetadataImpl() {} - virtual ~TransferMetadataImpl() {} - virtual bool get(const std::string &key, Json::Value &value) = 0; - virtual bool set(const std::string &key, const Json::Value &value) = 0; - virtual bool remove(const std::string &key) = 0; -}; - -#ifdef USE_REDIS -struct TransferMetadataImpl4Redis : public TransferMetadataImpl { - TransferMetadataImpl4Redis(const std::string &metadata_uri) - : client_(nullptr), metadata_uri_(metadata_uri) { - auto hostname_port = parseHostNameWithPort(metadata_uri); - client_ = redisConnect(hostname_port.first.c_str(), hostname_port.second); - if (!client_ || client_->err) { - LOG(ERROR) << "redis error: " << client_->errstr; - } - } - - virtual ~TransferMetadataImpl4Redis() {} - - virtual bool get(const std::string &key, Json::Value &value) { - Json::Reader reader; - redisReply *resp = (redisReply *) redisCommand(client_, "GET %s", key.c_str()); - if (!resp) { - LOG(ERROR) << "Error from redis client uri: " << metadata_uri_; - return false; - } - auto json_file = std::string(resp->str); - freeReplyObject(resp); - if (!reader.parse(json_file, value)) return false; - if (globalConfig().verbose) - LOG(INFO) << "Get segment desc, key=" << key - << ", value=" << json_file; - return true; - } - - virtual bool set(const std::string &key, const Json::Value &value) { - Json::FastWriter writer; - const std::string json_file = writer.write(value); - if (globalConfig().verbose) - LOG(INFO) << "Put segment desc, key=" << key - << ", value=" << json_file; - - redisReply *resp = (redisReply *) redisCommand(client_,"SET %s %s", key.c_str(), json_file.c_str()); - if (!resp) { - LOG(ERROR) << "Error from redis client uri: " << metadata_uri_; - return false; - } - freeReplyObject(resp); - return true; - } - - virtual bool remove(const std::string &key) { - redisReply *resp = (redisReply *) redisCommand(client_,"DEL %s", key.c_str()); - if (!resp) { - LOG(ERROR) << "Error from redis client uri: " << metadata_uri_; - return false; - } - freeReplyObject(resp); - return true; - } - - redisContext *client_; - const std::string metadata_uri_; -}; -#endif // USE_REDIS - -#ifdef USE_HTTP -struct TransferMetadataImpl4Http: public TransferMetadataImpl { - TransferMetadataImpl4Http(const std::string &metadata_uri) - : client_(nullptr), metadata_uri_(metadata_uri) { - curl_global_init(CURL_GLOBAL_ALL); - client_ = curl_easy_init(); - if (!client_) { - LOG(ERROR) << "Cannot allocate CURL objects"; - exit(EXIT_FAILURE); - } - } - - virtual ~TransferMetadataImpl4Http() { - curl_easy_cleanup(client_); - curl_global_cleanup(); - } - - static size_t writeCallback(void* contents, size_t size, size_t nmemb, std::string* userp) { - userp->append(static_cast(contents), size * nmemb); - return size * nmemb; - } - - std::string encodeUrl(const std::string &key) { - char *newkey = curl_easy_escape(client_, key.c_str(), key.size()); - std::string encodedKey(newkey); - std::string url = metadata_uri_ + "?key=" + encodedKey; - curl_free(newkey); - return url; - } - - virtual bool get(const std::string &key, Json::Value &value) { - curl_easy_reset(client_); - curl_easy_setopt(client_, CURLOPT_TIMEOUT_MS, 3000); // 3s timeout - - std::string url = encodeUrl(key); - curl_easy_setopt(client_, CURLOPT_URL, url.c_str()); - curl_easy_setopt(client_, CURLOPT_WRITEFUNCTION, writeCallback); - - // get response body - std::string readBuffer; - curl_easy_setopt(client_, CURLOPT_WRITEDATA, &readBuffer); - CURLcode res = curl_easy_perform(client_); - if (res != CURLE_OK) { - LOG(ERROR) << "Error from http client, GET " << url - << " error: " << curl_easy_strerror(res); - return false; - } - - // Get the HTTP response code - long responseCode; - curl_easy_getinfo(client_, CURLINFO_RESPONSE_CODE, &responseCode); - if (responseCode != 200) { - LOG(ERROR) << "Unexpected code in http response, GET " << url - << " response code: " << responseCode - << " response body: " << readBuffer; - return false; - } - - if (globalConfig().verbose) - LOG(INFO) << "Get segment desc, key=" << key - << ", value=" << readBuffer; - - Json::Reader reader; - if (!reader.parse(readBuffer, value)) return false; - return true; - } - - virtual bool set(const std::string &key, const Json::Value &value) { - curl_easy_reset(client_); - curl_easy_setopt(client_, CURLOPT_TIMEOUT_MS, 3000); // 3s timeout - - Json::FastWriter writer; - const std::string json_file = writer.write(value); - if (globalConfig().verbose) - LOG(INFO) << "Put segment desc, key=" << key - << ", value=" << json_file; - - std::string url = encodeUrl(key); - curl_easy_setopt(client_, CURLOPT_URL, url.c_str()); - curl_easy_setopt(client_, CURLOPT_WRITEFUNCTION, writeCallback); - curl_easy_setopt(client_, CURLOPT_POSTFIELDS, json_file.c_str()); - curl_easy_setopt(client_, CURLOPT_POSTFIELDSIZE, json_file.size()); - curl_easy_setopt(client_, CURLOPT_CUSTOMREQUEST, "PUT"); - - // get response body - std::string readBuffer; - curl_easy_setopt(client_, CURLOPT_WRITEDATA, &readBuffer); - - // set content-type to application/json - struct curl_slist *headers = NULL; - headers = curl_slist_append(headers, "Content-Type: application/json"); - curl_easy_setopt(client_, CURLOPT_HTTPHEADER, headers); - CURLcode res = curl_easy_perform(client_); - curl_slist_free_all(headers); // Free headers - if (res != CURLE_OK) { - LOG(ERROR) << "Error from http client, PUT " << url - << " error: " << curl_easy_strerror(res); - return false; - } - - // Get the HTTP response code - long responseCode; - curl_easy_getinfo(client_, CURLINFO_RESPONSE_CODE, &responseCode); - if (responseCode != 200) { - LOG(ERROR) << "Unexpected code in http response, PUT " << url - << " response code: " << responseCode - << " response body: " << readBuffer; - return false; - } - - return true; +struct TransferHandshakeUtil { + static Json::Value encode(const TransferMetadata::HandShakeDesc &desc) { + Json::Value root; + root["local_nic_path"] = desc.local_nic_path; + root["peer_nic_path"] = desc.peer_nic_path; + Json::Value qpNums(Json::arrayValue); + for (const auto &qp : desc.qp_num) qpNums.append(qp); + root["qp_num"] = qpNums; + root["reply_msg"] = desc.reply_msg; + return root; } - virtual bool remove(const std::string &key) { - curl_easy_reset(client_); - curl_easy_setopt(client_, CURLOPT_TIMEOUT_MS, 3000); // 3s timeout - - if (globalConfig().verbose) - LOG(INFO) << "Remove segment desc, key=" << key; - - std::string url = encodeUrl(key); - curl_easy_setopt(client_, CURLOPT_URL, url.c_str()); - curl_easy_setopt(client_, CURLOPT_WRITEFUNCTION, writeCallback); - curl_easy_setopt(client_, CURLOPT_CUSTOMREQUEST, "DELETE"); - - // get response body - std::string readBuffer; - curl_easy_setopt(client_, CURLOPT_WRITEDATA, &readBuffer); - CURLcode res = curl_easy_perform(client_); - if (res != CURLE_OK) { - LOG(ERROR) << "Error from http client, DELETE " << url - << " error: " << curl_easy_strerror(res); - return false; - } - - // Get the HTTP response code - long responseCode; - curl_easy_getinfo(client_, CURLINFO_RESPONSE_CODE, &responseCode); - if (responseCode != 200) { - LOG(ERROR) << "Unexpected code in http response, DELETE " << url - << " response code: " << responseCode - << " response body: " << readBuffer; - return false; - } - return true; - } - - CURL *client_; - const std::string metadata_uri_; -}; -#endif // USE_HTTP - -struct TransferMetadataImpl4Etcd : public TransferMetadataImpl { - TransferMetadataImpl4Etcd(const std::string &metadata_uri) - : client_(metadata_uri), metadata_uri_(metadata_uri) {} - - virtual ~TransferMetadataImpl4Etcd() {} - - virtual bool get(const std::string &key, Json::Value &value) { + static int decode(Json::Value root, TransferMetadata::HandShakeDesc &desc) { Json::Reader reader; - auto resp = client_.get(key); - if (!resp.is_ok()) { - LOG(ERROR) << "Error from etcd client, etcd uri: " << metadata_uri_ - << " error: " << resp.error_code() - << " message: " << resp.error_message(); - return false; + desc.local_nic_path = root["local_nic_path"].asString(); + desc.peer_nic_path = root["peer_nic_path"].asString(); + for (const auto &qp : root["qp_num"]) + desc.qp_num.push_back(qp.asUInt()); + desc.reply_msg = root["reply_msg"].asString(); + if (globalConfig().verbose) { + LOG(INFO) << "TransferHandshakeUtil::decode: local_nic_path " + << desc.local_nic_path << " peer_nic_path " + << desc.peer_nic_path << " qp_num count " + << desc.qp_num.size(); } - auto json_file = resp.value().as_string(); - if (!reader.parse(json_file, value)) return false; - if (globalConfig().verbose) - LOG(INFO) << "Get segment desc, key=" << key - << ", value=" << json_file; - return true; - } - - virtual bool set(const std::string &key, const Json::Value &value) { - Json::FastWriter writer; - const std::string json_file = writer.write(value); - if (globalConfig().verbose) - LOG(INFO) << "Put segment desc, key=" << key - << ", value=" << json_file; - auto resp = client_.put(key, json_file); - if (!resp.is_ok()) { - LOG(ERROR) << "Error from etcd client, etcd uri: " << metadata_uri_ - << " error: " << resp.error_code() - << " message: " << resp.error_message(); - return false; - } - return true; - } - - virtual bool remove(const std::string &key) { - auto resp = client_.rm(key); - if (!resp.is_ok()) { - LOG(ERROR) << "Error from etcd client, etcd uri: " << metadata_uri_ - << " error: " << resp.error_code() - << " message: " << resp.error_message(); - return false; - } - return true; + return 0; } - - etcd::SyncClient client_; - const std::string metadata_uri_; }; -TransferMetadata::TransferMetadata(const std::string &metadata_uri, const std::string &protocol) - : listener_running_(false) { - if (protocol == "etcd") { - impl_ = std::make_shared(metadata_uri); - if (!impl_) { - LOG(ERROR) << "Cannot allocate TransferMetadataImpl objects"; - exit(EXIT_FAILURE); - } -#ifdef USE_REDIS - } else if (protocol == "redis") { - impl_ = std::make_shared(metadata_uri); - if (!impl_) { - LOG(ERROR) << "Cannot allocate TransferMetadataImpl objects"; - exit(EXIT_FAILURE); - } -#endif // USE_REDIS -#ifdef USE_HTTP - } else if (protocol == "http") { - impl_ = std::make_shared(metadata_uri); - if (!impl_) { - LOG(ERROR) << "Cannot allocate TransferMetadataImpl objects"; - exit(EXIT_FAILURE); - } -#endif // USE_HTTP - } else { - LOG(ERROR) << "Unsupported metdata protocol " << protocol; - exit(EXIT_FAILURE); +TransferMetadata::TransferMetadata(const std::string &conn_string) { + handshake_plugin_ = HandShakePlugin::Create(conn_string); + storage_plugin_ = MetadataStoragePlugin::Create(conn_string); + if (!handshake_plugin_ || !storage_plugin_) { + LOG(ERROR) << "Unable to create metadata plugins with conn string " + << conn_string; } next_segment_id_.store(1); } -TransferMetadata::~TransferMetadata() { - if (listener_running_) { - listener_running_ = false; - listener_.join(); - } -} +TransferMetadata::~TransferMetadata() { handshake_plugin_.reset(); } int TransferMetadata::updateSegmentDesc(const std::string &segment_name, const SegmentDesc &desc) { @@ -392,21 +111,7 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, buffersJSON.append(bufferJSON); } segmentJSON["buffers"] = buffersJSON; - - Json::Value priorityMatrixJSON; - for (auto &entry : desc.priority_matrix) { - Json::Value priorityItemJSON(Json::arrayValue); - Json::Value preferredRnicListJSON(Json::arrayValue); - for (auto &device_name : entry.second.preferred_rnic_list) - preferredRnicListJSON.append(device_name); - priorityItemJSON.append(preferredRnicListJSON); - Json::Value availableRnicListJSON(Json::arrayValue); - for (auto &device_name : entry.second.available_rnic_list) - availableRnicListJSON.append(device_name); - priorityItemJSON.append(availableRnicListJSON); - priorityMatrixJSON[entry.first] = priorityItemJSON; - } - segmentJSON["priority_matrix"] = priorityMatrixJSON; + segmentJSON["priority_matrix"] = desc.topology.toJson(); } else if (segmentJSON["protocol"] == "tcp") { Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { @@ -418,13 +123,14 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, } segmentJSON["buffers"] = buffersJSON; } else { - LOG(FATAL) - << "For NVMeoF, the transfer engine should not modify the metadata"; + LOG(ERROR) << "Unsupported segment descriptor for register, name " + << desc.name << " protocol " << desc.protocol; return ERR_METADATA; } - if (!impl_->set(getFullMetadataKey(segment_name), segmentJSON)) { - LOG(ERROR) << "Failed to put segment description: " << segment_name; + if (!storage_plugin_->set(getFullMetadataKey(segment_name), segmentJSON)) { + LOG(ERROR) << "Failed to register segment descriptor, name " + << desc.name << " protocol " << desc.protocol; return ERR_METADATA; } @@ -432,8 +138,9 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, } int TransferMetadata::removeSegmentDesc(const std::string &segment_name) { - if (!impl_->remove(getFullMetadataKey(segment_name))) { - LOG(ERROR) << "Failed to remove segment description: " << segment_name; + if (!storage_plugin_->remove(getFullMetadataKey(segment_name))) { + LOG(ERROR) << "Failed to unregister segment descriptor, name " + << segment_name; return ERR_METADATA; } return 0; @@ -442,16 +149,13 @@ int TransferMetadata::removeSegmentDesc(const std::string &segment_name) { std::shared_ptr TransferMetadata::getSegmentDesc( const std::string &segment_name) { Json::Value segmentJSON; - if (!impl_->get(getFullMetadataKey(segment_name), segmentJSON)) { - LOG(ERROR) << "Failed to get segment description: " << segment_name; + if (!storage_plugin_->get(getFullMetadataKey(segment_name), segmentJSON)) { + LOG(WARNING) << "Failed to retrieve segment descriptor, name " + << segment_name; return nullptr; } auto desc = std::make_shared(); - if (!desc) { - LOG(ERROR) << "Failed to allocate SegmentDesc object"; - return nullptr; - } desc->name = segmentJSON["name"].asString(); desc->protocol = segmentJSON["protocol"].asString(); @@ -461,6 +165,11 @@ std::shared_ptr TransferMetadata::getSegmentDesc( device.name = deviceJSON["name"].asString(); device.lid = deviceJSON["lid"].asUInt(); device.gid = deviceJSON["gid"].asString(); + if (device.name.empty() || device.gid.empty()) { + LOG(WARNING) << "Corrupted segment descriptor, name " + << segment_name << " protocol " << desc->protocol; + return nullptr; + } desc->devices.push_back(device); } @@ -473,42 +182,21 @@ std::shared_ptr TransferMetadata::getSegmentDesc( buffer.rkey.push_back(rkeyJSON.asUInt()); for (const auto &lkeyJSON : bufferJSON["lkey"]) buffer.lkey.push_back(lkeyJSON.asUInt()); + if (buffer.name.empty() || !buffer.addr || !buffer.length || + buffer.rkey.empty() || + buffer.rkey.size() != buffer.lkey.size()) { + LOG(WARNING) << "Corrupted segment descriptor, name " + << segment_name << " protocol " << desc->protocol; + return nullptr; + } desc->buffers.push_back(buffer); } - auto priorityMatrixJSON = segmentJSON["priority_matrix"]; - for (const auto &key : priorityMatrixJSON.getMemberNames()) { - const Json::Value &value = priorityMatrixJSON[key]; - if (value.isArray() && value.size() == 2) { - PriorityItem item; - for (const auto &array : value[0]) { - auto device_name = array.asString(); - item.preferred_rnic_list.push_back(device_name); - int device_index = 0; - for (auto &entry : desc->devices) { - if (entry.name == device_name) { - item.preferred_rnic_id_list.push_back(device_index); - break; - } - device_index++; - } - LOG_ASSERT(device_index != (int)desc->devices.size()); - } - for (const auto &array : value[1]) { - auto device_name = array.asString(); - item.available_rnic_list.push_back(device_name); - int device_index = 0; - for (auto &entry : desc->devices) { - if (entry.name == device_name) { - item.available_rnic_id_list.push_back(device_index); - break; - } - device_index++; - } - LOG_ASSERT(device_index != (int)desc->devices.size()); - } - desc->priority_matrix[key] = item; - } + int ret = desc->topology.parse( + segmentJSON["priority_matrix"].toStyledString()); + if (ret) { + LOG(WARNING) << "Corrupted segment descriptor, name " + << segment_name << " protocol " << desc->protocol; } } else if (desc->protocol == "tcp") { for (const auto &bufferJSON : segmentJSON["buffers"]) { @@ -516,6 +204,11 @@ std::shared_ptr TransferMetadata::getSegmentDesc( buffer.name = bufferJSON["name"].asString(); buffer.addr = bufferJSON["addr"].asUInt64(); buffer.length = bufferJSON["length"].asUInt64(); + if (buffer.name.empty() || !buffer.addr || !buffer.length) { + LOG(WARNING) << "Corrupted segment descriptor, name " + << segment_name << " protocol " << desc->protocol; + return nullptr; + } desc->buffers.push_back(buffer); } } else if (desc->protocol == "nvmeof") { @@ -529,12 +222,18 @@ std::shared_ptr TransferMetadata::getSegmentDesc( } desc->nvmeof_buffers.push_back(buffer); } + } else { + LOG(ERROR) << "Unsupported segment descriptor, name " << segment_name + << " protocol " << desc->protocol; + return nullptr; } + return desc; } int TransferMetadata::syncSegmentCache() { RWSpinlock::WriteGuard guard(segment_lock_); + LOG(INFO) << "Invalidate segment descriptor cache"; for (auto &entry : segment_id_to_desc_map_) { if (entry.first == LOCAL_SEGMENT_ID) continue; auto segment_desc = getSegmentDesc(entry.second->name); @@ -598,7 +297,6 @@ TransferMetadata::SegmentID TransferMetadata::getSegmentID( auto segment_desc = this->getSegmentDesc(segment_name); if (!segment_desc) return -1; SegmentID id = next_segment_id_.fetch_add(1); - // LOG(INFO) << "put " << id; segment_id_to_desc_map_[id] = segment_desc; segment_name_to_id_map_[segment_name] = id; return id; @@ -624,10 +322,6 @@ int TransferMetadata::addLocalMemoryBuffer(const BufferDesc &buffer_desc, { RWSpinlock::WriteGuard guard(segment_lock_); auto new_segment_desc = std::make_shared(); - if (!new_segment_desc) { - LOG(ERROR) << "Failed to allocate segment description"; - return ERR_MEMORY; - } auto &segment_desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; *new_segment_desc = *segment_desc; segment_desc = new_segment_desc; @@ -643,10 +337,6 @@ int TransferMetadata::removeLocalMemoryBuffer(void *addr, { RWSpinlock::WriteGuard guard(segment_lock_); auto new_segment_desc = std::make_shared(); - if (!new_segment_desc) { - LOG(ERROR) << "Failed to allocate segment description"; - return ERR_MEMORY; - } auto &segment_desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; *new_segment_desc = *segment_desc; segment_desc = new_segment_desc; @@ -671,8 +361,8 @@ int TransferMetadata::addRpcMetaEntry(const std::string &server_name, Json::Value rpcMetaJSON; rpcMetaJSON["ip_or_host_name"] = desc.ip_or_host_name; rpcMetaJSON["rpc_port"] = static_cast(desc.rpc_port); - if (!impl_->set(kRpcMetaPrefix + server_name, rpcMetaJSON)) { - LOG(ERROR) << "Failed to insert rpc meta of " << server_name; + if (!storage_plugin_->set(kRpcMetaPrefix + server_name, rpcMetaJSON)) { + LOG(ERROR) << "Failed to set location of " << server_name; return ERR_METADATA; } local_rpc_meta_ = desc; @@ -680,8 +370,8 @@ int TransferMetadata::addRpcMetaEntry(const std::string &server_name, } int TransferMetadata::removeRpcMetaEntry(const std::string &server_name) { - if (!impl_->remove(kRpcMetaPrefix + server_name)) { - LOG(ERROR) << "Failed to remove rpc meta of " << server_name; + if (!storage_plugin_->remove(kRpcMetaPrefix + server_name)) { + LOG(ERROR) << "Failed to remove location of " << server_name; return ERR_METADATA; } return 0; @@ -698,8 +388,8 @@ int TransferMetadata::getRpcMetaEntry(const std::string &server_name, } RWSpinlock::WriteGuard guard(rpc_meta_lock_); Json::Value rpcMetaJSON; - if (!impl_->get(kRpcMetaPrefix + server_name, rpcMetaJSON)) { - LOG(ERROR) << "Failed to get rpc meta of " << server_name; + if (!storage_plugin_->get(kRpcMetaPrefix + server_name, rpcMetaJSON)) { + LOG(ERROR) << "Failed to find location of " << server_name; return ERR_METADATA; } desc.ip_or_host_name = rpcMetaJSON["ip_or_host_name"].asString(); @@ -708,324 +398,39 @@ int TransferMetadata::getRpcMetaEntry(const std::string &server_name, return 0; } -std::string TransferMetadata::encode(const HandShakeDesc &desc) { - Json::Value root; - root["local_nic_path"] = desc.local_nic_path; - root["peer_nic_path"] = desc.peer_nic_path; - Json::Value qpNums(Json::arrayValue); - for (const auto &qp : desc.qp_num) qpNums.append(qp); - root["qp_num"] = qpNums; - root["reply_msg"] = desc.reply_msg; - Json::FastWriter writer; - auto serialized = writer.write(root); - if (globalConfig().verbose) - LOG(INFO) << "Send Endpoint Handshake Info: " << serialized; - return serialized; -} - -int TransferMetadata::decode(const std::string &serialized, - HandShakeDesc &desc) { - Json::Value root; - Json::Reader reader; - - if (serialized.empty() || !reader.parse(serialized, root)) - return ERR_MALFORMED_JSON; - - if (globalConfig().verbose) - LOG(INFO) << "Receive Endpoint Handshake Info: " << serialized; - desc.local_nic_path = root["local_nic_path"].asString(); - desc.peer_nic_path = root["peer_nic_path"].asString(); - for (const auto &qp : root["qp_num"]) desc.qp_num.push_back(qp.asUInt()); - desc.reply_msg = root["reply_msg"].asString(); - - return 0; -} - -static inline const std::string toString(struct sockaddr *addr) { - if (addr->sa_family == AF_INET) { - struct sockaddr_in *sock_addr = (struct sockaddr_in *)addr; - char ip[INET_ADDRSTRLEN]; - if (inet_ntop(addr->sa_family, &(sock_addr->sin_addr), ip, - INET_ADDRSTRLEN) != NULL) - return ip; - } else if (addr->sa_family == AF_INET6) { - struct sockaddr_in6 *sock_addr = (struct sockaddr_in6 *)addr; - char ip[INET6_ADDRSTRLEN]; - if (inet_ntop(addr->sa_family, &(sock_addr->sin6_addr), ip, - INET6_ADDRSTRLEN) != NULL) - return ip; - } - LOG(ERROR) << "Invalid address, cannot convert to string"; - return ""; -} - int TransferMetadata::startHandshakeDaemon( OnReceiveHandShake on_receive_handshake, uint16_t listen_port) { - sockaddr_in bind_address; - int on = 1, listen_fd = -1; - memset(&bind_address, 0, sizeof(sockaddr_in)); - bind_address.sin_family = AF_INET; - bind_address.sin_port = htons(listen_port); - bind_address.sin_addr.s_addr = INADDR_ANY; - - listen_fd = socket(AF_INET, SOCK_STREAM, 0); - if (listen_fd < 0) { - PLOG(ERROR) << "Failed to create socket"; - return ERR_SOCKET; - } - - struct timeval timeout; - timeout.tv_sec = 1; - timeout.tv_usec = 0; - if (setsockopt(listen_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, - sizeof(timeout))) { - PLOG(ERROR) << "Failed to set socket timeout"; - close(listen_fd); - return ERR_SOCKET; - } - - if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) { - PLOG(ERROR) << "Failed to set address reusable"; - close(listen_fd); - return ERR_SOCKET; - } - - if (bind(listen_fd, (sockaddr *)&bind_address, sizeof(sockaddr_in)) < 0) { - PLOG(ERROR) << "Failed to bind address, rpc port: " << listen_port; - close(listen_fd); - return ERR_SOCKET; - } - - if (listen(listen_fd, 5)) { - PLOG(ERROR) << "Failed to listen"; - close(listen_fd); - return ERR_SOCKET; - } - - listener_running_ = true; - listener_ = std::thread([this, listen_fd, on_receive_handshake]() { - while (listener_running_) { - sockaddr_in addr; - socklen_t addr_len = sizeof(sockaddr_in); - int conn_fd = accept(listen_fd, (sockaddr *)&addr, &addr_len); - if (conn_fd < 0) { - if (errno != EWOULDBLOCK) - PLOG(ERROR) << "Failed to accept socket connection"; - continue; - } - - if (addr.sin_family != AF_INET && addr.sin_family != AF_INET6) { - LOG(ERROR) - << "Unsupported socket type, should be AF_INET or AF_INET6"; - close(conn_fd); - continue; - } - - struct timeval timeout; - timeout.tv_sec = 60; - timeout.tv_usec = 0; - if (setsockopt(conn_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, - sizeof(timeout))) { - PLOG(ERROR) << "Failed to set socket timeout"; - close(conn_fd); - continue; - } - - if (globalConfig().verbose) - LOG(INFO) << "New connection: " - << toString((struct sockaddr *)&addr) << ":" - << ntohs(addr.sin_port); - + return handshake_plugin_->startDaemon( + [on_receive_handshake](const Json::Value &local, + Json::Value &peer) -> int { HandShakeDesc local_desc, peer_desc; - int ret = decode(readString(conn_fd), peer_desc); - if (ret) { - PLOG(ERROR) << "Failed to receive handshake message: malformed " - "json format, check tcp connection"; - close(conn_fd); - continue; - } - - on_receive_handshake(peer_desc, local_desc); - ret = writeString(conn_fd, encode(local_desc)); - if (ret) { - PLOG(ERROR) << "Failed to send handshake message: malformed " - "json format, check tcp connection"; - close(conn_fd); - continue; - } - - close(conn_fd); - } - return; - }); - - return 0; + TransferHandshakeUtil::decode(local, local_desc); + int ret = on_receive_handshake(local_desc, peer_desc); + if (ret) return ret; + peer = TransferHandshakeUtil::encode(peer_desc); + return 0; + }, + listen_port); } int TransferMetadata::sendHandshake(const std::string &peer_server_name, const HandShakeDesc &local_desc, HandShakeDesc &peer_desc) { - struct addrinfo hints; - struct addrinfo *result, *rp; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_INET; - hints.ai_socktype = SOCK_STREAM; - - RpcMetaDesc desc; - if (getRpcMetaEntry(peer_server_name, desc)) { - PLOG(ERROR) << "Cannot find rpc meta entry for " << peer_server_name; + RpcMetaDesc peer_location; + if (getRpcMetaEntry(peer_server_name, peer_location)) { return ERR_METADATA; } - - char service[16]; - sprintf(service, "%u", desc.rpc_port); - if (getaddrinfo(desc.ip_or_host_name.c_str(), service, &hints, &result)) { - PLOG(ERROR) - << "Failed to get IP address of peer server " << peer_server_name - << ", check DNS and /etc/hosts, or use IPv4 address instead"; - return ERR_DNS; - } - - int ret = 0; - for (rp = result; rp; rp = rp->ai_next) { - ret = doSendHandshake(rp, local_desc, peer_desc); - if (ret == 0) { - freeaddrinfo(result); - return 0; - } - if (ret == ERR_MALFORMED_JSON) { - return ret; - } - } - - freeaddrinfo(result); - return ret; -} - -int TransferMetadata::doSendHandshake(struct addrinfo *addr, - const HandShakeDesc &local_desc, - HandShakeDesc &peer_desc) { - if (globalConfig().verbose) - LOG(INFO) << "Try connecting " << toString(addr->ai_addr); - - int on = 1; - int conn_fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); - if (conn_fd == -1) { - PLOG(ERROR) << "Failed to create socket"; - return ERR_SOCKET; - } - if (setsockopt(conn_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) { - PLOG(ERROR) << "Failed to set address reusable"; - close(conn_fd); - return ERR_SOCKET; - } - - struct timeval timeout; - timeout.tv_sec = 60; - timeout.tv_usec = 0; - if (setsockopt(conn_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, - sizeof(timeout))) { - PLOG(ERROR) << "Failed to set socket timeout"; - close(conn_fd); - return ERR_SOCKET; - } - - if (connect(conn_fd, addr->ai_addr, addr->ai_addrlen)) { - PLOG(ERROR) << "Failed to connect " << toString(addr->ai_addr) - << " via socket"; - close(conn_fd); - return ERR_SOCKET; - } - - int ret = writeString(conn_fd, encode(local_desc)); - if (ret) { - LOG(ERROR) << "Failed to send handshake message: malformed json " - "format, check tcp connection"; - close(conn_fd); - return ret; - } - - ret = decode(readString(conn_fd), peer_desc); - if (ret) { - LOG(ERROR) << "Failed to receive handshake message: malformed json " - "format, check tcp connection"; - close(conn_fd); - return ret; - } - + auto local = TransferHandshakeUtil::encode(local_desc); + Json::Value peer; + int ret = handshake_plugin_->send(peer_location.ip_or_host_name, + peer_location.rpc_port, local, peer); + if (ret) return ret; + TransferHandshakeUtil::decode(peer, peer_desc); if (!peer_desc.reply_msg.empty()) { - LOG(ERROR) << "Handshake request is rejected by peer endpoint " - << peer_desc.local_nic_path - << ", message: " << peer_desc.reply_msg - << ". Please check peer's configuration."; - close(conn_fd); - return ERR_REJECT_HANDSHAKE; - } - - close(conn_fd); - return 0; -} - -int TransferMetadata::parseNicPriorityMatrix( - const std::string &nic_priority_matrix, PriorityMatrix &priority_map, - std::vector &rnic_list) { - std::set rnic_set; - Json::Value root; - Json::Reader reader; - - if (nic_priority_matrix.empty() || - !reader.parse(nic_priority_matrix, root)) { - LOG(ERROR) - << "Malformed format of NIC priority matrix: illegal JSON format"; - return ERR_MALFORMED_JSON; - } - - if (!root.isObject()) { - LOG(ERROR) - << "Malformed format of NIC priority matrix: root is not an object"; - return ERR_MALFORMED_JSON; - } - - priority_map.clear(); - for (const auto &key : root.getMemberNames()) { - const Json::Value &value = root[key]; - if (value.isArray() && value.size() == 2) { - PriorityItem item; - for (const auto &array : value[0]) { - auto device_name = array.asString(); - item.preferred_rnic_list.push_back(device_name); - auto iter = rnic_set.find(device_name); - if (iter == rnic_set.end()) { - item.preferred_rnic_id_list.push_back(rnic_set.size()); - rnic_set.insert(device_name); - } else { - item.preferred_rnic_id_list.push_back( - std::distance(rnic_set.begin(), iter)); - } - } - for (const auto &array : value[1]) { - auto device_name = array.asString(); - item.available_rnic_list.push_back(device_name); - auto iter = rnic_set.find(device_name); - if (iter == rnic_set.end()) { - item.available_rnic_id_list.push_back(rnic_set.size()); - rnic_set.insert(device_name); - } else { - item.available_rnic_id_list.push_back( - std::distance(rnic_set.begin(), iter)); - } - } - priority_map[key] = item; - } else { - LOG(ERROR) - << "Malformed format of NIC priority matrix: format error"; - return ERR_MALFORMED_JSON; - } + LOG(ERROR) << "Handshake rejected by " << peer_server_name << ": " + << peer_desc.reply_msg; + return ERR_METADATA; } - - rnic_list.clear(); - for (auto &entry : rnic_set) rnic_list.push_back(entry); - return 0; } diff --git a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp new file mode 100644 index 0000000..df677df --- /dev/null +++ b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp @@ -0,0 +1,594 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "transfer_metadata_plugin.h" + +#include +#include +#include +#include +#include + +#ifdef USE_REDIS +#include +#endif + +#ifdef USE_HTTP +#include +#endif + +#include +#include + +#include "common.h" +#include "config.h" +#include "error.h" + +namespace mooncake { +#ifdef USE_REDIS +struct RedisStoragePlugin : public MetadataStoragePlugin { + RedisStoragePlugin(const std::string &metadata_uri) + : client_(nullptr), metadata_uri_(metadata_uri) { + auto hostname_port = parseHostNameWithPort(metadata_uri); + client_ = + redisConnect(hostname_port.first.c_str(), hostname_port.second); + if (!client_ || client_->err) { + LOG(ERROR) << "RedisStoragePlugin: unable to connect " + << metadata_uri_ << ": " << client_->errstr; + client_ = nullptr; + } + } + + virtual ~RedisStoragePlugin() {} + + virtual bool get(const std::string &key, Json::Value &value) { + Json::Reader reader; + redisReply *resp = + (redisReply *)redisCommand(client_, "GET %s", key.c_str()); + if (!resp) { + LOG(ERROR) << "RedisStoragePlugin: unable to get " << key + << " from " << metadata_uri_; + return false; + } + auto json_file = std::string(resp->str); + freeReplyObject(resp); + if (!reader.parse(json_file, value)) return false; + if (globalConfig().verbose) + LOG(INFO) << "RedisStoragePlugin: get: key=" << key + << ", value=" << json_file; + return true; + } + + virtual bool set(const std::string &key, const Json::Value &value) { + Json::FastWriter writer; + const std::string json_file = writer.write(value); + if (globalConfig().verbose) + LOG(INFO) << "RedisStoragePlugin: set: key=" << key + << ", value=" << json_file; + redisReply *resp = (redisReply *)redisCommand( + client_, "SET %s %s", key.c_str(), json_file.c_str()); + if (!resp) { + LOG(ERROR) << "RedisStoragePlugin: unable to put " << key + << " from " << metadata_uri_; + return false; + } + freeReplyObject(resp); + return true; + } + + virtual bool remove(const std::string &key) { + redisReply *resp = + (redisReply *)redisCommand(client_, "DEL %s", key.c_str()); + if (!resp) { + LOG(ERROR) << "RedisStoragePlugin: unable to remove " << key + << " from " << metadata_uri_; + return false; + } + freeReplyObject(resp); + return true; + } + + redisContext *client_; + const std::string metadata_uri_; +}; +#endif // USE_REDIS + +#ifdef USE_HTTP +struct HTTPStoragePlugin : public MetadataStoragePlugin { + HTTPStoragePlugin(const std::string &metadata_uri) + : client_(nullptr), metadata_uri_(metadata_uri) { + curl_global_init(CURL_GLOBAL_ALL); + client_ = curl_easy_init(); + if (!client_) { + LOG(ERROR) << "Cannot allocate CURL objects"; + exit(EXIT_FAILURE); + } + } + + virtual ~HTTPStoragePlugin() { + curl_easy_cleanup(client_); + curl_global_cleanup(); + } + + static size_t writeCallback(void *contents, size_t size, size_t nmemb, + std::string *userp) { + userp->append(static_cast(contents), size * nmemb); + return size * nmemb; + } + + std::string encodeUrl(const std::string &key) { + char *newkey = curl_easy_escape(client_, key.c_str(), key.size()); + std::string encodedKey(newkey); + std::string url = metadata_uri_ + "?key=" + encodedKey; + curl_free(newkey); + return url; + } + + virtual bool get(const std::string &key, Json::Value &value) { + curl_easy_reset(client_); + curl_easy_setopt(client_, CURLOPT_TIMEOUT_MS, 3000); // 3s timeout + + std::string url = encodeUrl(key); + curl_easy_setopt(client_, CURLOPT_URL, url.c_str()); + curl_easy_setopt(client_, CURLOPT_WRITEFUNCTION, writeCallback); + + // get response body + std::string readBuffer; + curl_easy_setopt(client_, CURLOPT_WRITEDATA, &readBuffer); + CURLcode res = curl_easy_perform(client_); + if (res != CURLE_OK) { + LOG(ERROR) << "Error from http client, GET " << url + << " error: " << curl_easy_strerror(res); + return false; + } + + // Get the HTTP response code + long responseCode; + curl_easy_getinfo(client_, CURLINFO_RESPONSE_CODE, &responseCode); + if (responseCode != 200) { + LOG(ERROR) << "Unexpected code in http response, GET " << url + << " response code: " << responseCode + << " response body: " << readBuffer; + return false; + } + + if (globalConfig().verbose) + LOG(INFO) << "Get segment desc, key=" << key + << ", value=" << readBuffer; + + Json::Reader reader; + if (!reader.parse(readBuffer, value)) return false; + return true; + } + + virtual bool set(const std::string &key, const Json::Value &value) { + curl_easy_reset(client_); + curl_easy_setopt(client_, CURLOPT_TIMEOUT_MS, 3000); // 3s timeout + + Json::FastWriter writer; + const std::string json_file = writer.write(value); + if (globalConfig().verbose) + LOG(INFO) << "Put segment desc, key=" << key + << ", value=" << json_file; + + std::string url = encodeUrl(key); + curl_easy_setopt(client_, CURLOPT_URL, url.c_str()); + curl_easy_setopt(client_, CURLOPT_WRITEFUNCTION, writeCallback); + curl_easy_setopt(client_, CURLOPT_POSTFIELDS, json_file.c_str()); + curl_easy_setopt(client_, CURLOPT_POSTFIELDSIZE, json_file.size()); + curl_easy_setopt(client_, CURLOPT_CUSTOMREQUEST, "PUT"); + + // get response body + std::string readBuffer; + curl_easy_setopt(client_, CURLOPT_WRITEDATA, &readBuffer); + + // set content-type to application/json + struct curl_slist *headers = NULL; + headers = curl_slist_append(headers, "Content-Type: application/json"); + curl_easy_setopt(client_, CURLOPT_HTTPHEADER, headers); + CURLcode res = curl_easy_perform(client_); + curl_slist_free_all(headers); // Free headers + if (res != CURLE_OK) { + LOG(ERROR) << "Error from http client, PUT " << url + << " error: " << curl_easy_strerror(res); + return false; + } + + // Get the HTTP response code + long responseCode; + curl_easy_getinfo(client_, CURLINFO_RESPONSE_CODE, &responseCode); + if (responseCode != 200) { + LOG(ERROR) << "Unexpected code in http response, PUT " << url + << " response code: " << responseCode + << " response body: " << readBuffer; + return false; + } + + return true; + } + + virtual bool remove(const std::string &key) { + curl_easy_reset(client_); + curl_easy_setopt(client_, CURLOPT_TIMEOUT_MS, 3000); // 3s timeout + + if (globalConfig().verbose) + LOG(INFO) << "Remove segment desc, key=" << key; + + std::string url = encodeUrl(key); + curl_easy_setopt(client_, CURLOPT_URL, url.c_str()); + curl_easy_setopt(client_, CURLOPT_WRITEFUNCTION, writeCallback); + curl_easy_setopt(client_, CURLOPT_CUSTOMREQUEST, "DELETE"); + + // get response body + std::string readBuffer; + curl_easy_setopt(client_, CURLOPT_WRITEDATA, &readBuffer); + CURLcode res = curl_easy_perform(client_); + if (res != CURLE_OK) { + LOG(ERROR) << "Error from http client, DELETE " << url + << " error: " << curl_easy_strerror(res); + return false; + } + + // Get the HTTP response code + long responseCode; + curl_easy_getinfo(client_, CURLINFO_RESPONSE_CODE, &responseCode); + if (responseCode != 200) { + LOG(ERROR) << "Unexpected code in http response, DELETE " << url + << " response code: " << responseCode + << " response body: " << readBuffer; + return false; + } + return true; + } + + CURL *client_; + const std::string metadata_uri_; +}; +#endif // USE_HTTP + +struct EtcdStoragePlugin : public MetadataStoragePlugin { + EtcdStoragePlugin(const std::string &metadata_uri) + : client_(metadata_uri), metadata_uri_(metadata_uri) {} + + virtual ~EtcdStoragePlugin() {} + + virtual bool get(const std::string &key, Json::Value &value) { + Json::Reader reader; + auto resp = client_.get(key); + if (!resp.is_ok()) { + LOG(ERROR) << "EtcdStoragePlugin: unable to get " << key << " from " + << metadata_uri_ << ": " << resp.error_message(); + return false; + } + auto json_file = resp.value().as_string(); + if (!reader.parse(json_file, value)) return false; + if (globalConfig().verbose) + LOG(INFO) << "EtcdStoragePlugin: get: key=" << key + << ", value=" << json_file; + return true; + } + + virtual bool set(const std::string &key, const Json::Value &value) { + Json::FastWriter writer; + const std::string json_file = writer.write(value); + if (globalConfig().verbose) + LOG(INFO) << "EtcdStoragePlugin: set: key=" << key + << ", value=" << json_file; + auto resp = client_.put(key, json_file); + if (!resp.is_ok()) { + LOG(ERROR) << "EtcdStoragePlugin: unable to set " << key << " from " + << metadata_uri_ << ": " << resp.error_message(); + return false; + } + return true; + } + + virtual bool remove(const std::string &key) { + auto resp = client_.rm(key); + if (!resp.is_ok()) { + LOG(ERROR) << "EtcdStoragePlugin: unable to delete " << key + << " from " << metadata_uri_ << ": " + << resp.error_message(); + return false; + } + return true; + } + + etcd::SyncClient client_; + const std::string metadata_uri_; +}; + +std::pair parseConnectionString( + const std::string &conn_string) { + std::pair result; + std::string proto = "etcd"; + std::string domain; + std::size_t pos = conn_string.find("://"); + + if (pos != std::string::npos) { + proto = conn_string.substr(0, pos); + domain = conn_string.substr(pos + 3); + } else { + domain = conn_string; + } + + result.first = proto; + result.second = domain; + return result; +} + +std::shared_ptr MetadataStoragePlugin::Create( + const std::string &conn_string) { + auto parsed_conn_string = parseConnectionString(conn_string); + if (parsed_conn_string.first == "etcd") { + return std::make_shared(parsed_conn_string.second); +#ifdef USE_REDIS + } else if (parsed_conn_string.first == "redis") { + return std::make_shared(parsed_conn_string.second); +#endif // USE_REDIS +#ifdef USE_HTTP + } else if (parsed_conn_string.first == "http" || + parsed_conn_string.first == "https") { + return std::make_shared( + conn_string); // including prefix +#endif // USE_HTTP + } else { + LOG(FATAL) << "Unable to find metadata storage plugin " + << parsed_conn_string.first; + return nullptr; + } +} + +static inline const std::string getNetworkAddress(struct sockaddr *addr) { + if (addr->sa_family == AF_INET) { + struct sockaddr_in *sock_addr = (struct sockaddr_in *)addr; + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(sock_addr->sin_addr), ip, + INET_ADDRSTRLEN) != NULL) + return std::string(ip) + ":" + + std::to_string(ntohs(sock_addr->sin_port)); + } else if (addr->sa_family == AF_INET6) { + struct sockaddr_in6 *sock_addr = (struct sockaddr_in6 *)addr; + char ip[INET6_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(sock_addr->sin6_addr), ip, + INET6_ADDRSTRLEN) != NULL) + return std::string(ip) + ":" + + std::to_string(ntohs(sock_addr->sin6_port)); + } + PLOG(ERROR) << "Failed to parse socket address"; + return ""; +} + +struct SocketHandShakePlugin : public HandShakePlugin { + SocketHandShakePlugin() : listener_running_(false) {} + + virtual ~SocketHandShakePlugin() { + if (listener_running_) { + listener_running_ = false; + listener_.join(); + } + } + + virtual int startDaemon(OnReceiveCallBack on_recv_callback, + uint16_t listen_port) { + sockaddr_in bind_address; + int on = 1, listen_fd = -1; + memset(&bind_address, 0, sizeof(sockaddr_in)); + bind_address.sin_family = AF_INET; + bind_address.sin_port = htons(listen_port); + bind_address.sin_addr.s_addr = INADDR_ANY; + + listen_fd = socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd < 0) { + PLOG(ERROR) << "SocketHandShakePlugin: socket()"; + return ERR_SOCKET; + } + + struct timeval timeout; + timeout.tv_sec = 1; + timeout.tv_usec = 0; + if (setsockopt(listen_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, + sizeof(timeout))) { + PLOG(ERROR) << "SocketHandShakePlugin: setsockopt(SO_RCVTIMEO)"; + close(listen_fd); + return ERR_SOCKET; + } + + if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) { + PLOG(ERROR) << "SocketHandShakePlugin: setsockopt(SO_REUSEADDR)"; + close(listen_fd); + return ERR_SOCKET; + } + + if (bind(listen_fd, (sockaddr *)&bind_address, sizeof(sockaddr_in)) < + 0) { + PLOG(ERROR) << "SocketHandShakePlugin: bind (port " << listen_port + << ")"; + close(listen_fd); + return ERR_SOCKET; + } + + if (listen(listen_fd, 5)) { + PLOG(ERROR) << "SocketHandShakePlugin: listen()"; + close(listen_fd); + return ERR_SOCKET; + } + + listener_running_ = true; + listener_ = std::thread([this, listen_fd, on_recv_callback]() { + while (listener_running_) { + sockaddr_in addr; + socklen_t addr_len = sizeof(sockaddr_in); + int conn_fd = accept(listen_fd, (sockaddr *)&addr, &addr_len); + if (conn_fd < 0) { + if (errno != EWOULDBLOCK) + PLOG(ERROR) << "SocketHandShakePlugin: accept()"; + continue; + } + + if (addr.sin_family != AF_INET && addr.sin_family != AF_INET6) { + LOG(ERROR) << "SocketHandShakePlugin: unsupported socket " + "type, should be AF_INET or AF_INET6"; + close(conn_fd); + continue; + } + + struct timeval timeout; + timeout.tv_sec = 60; + timeout.tv_usec = 0; + if (setsockopt(conn_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, + sizeof(timeout))) { + PLOG(ERROR) + << "SocketHandShakePlugin: setsockopt(SO_RCVTIMEO)"; + close(conn_fd); + continue; + } + + auto peer_hostname = + getNetworkAddress((struct sockaddr *)&addr); + if (globalConfig().verbose) + LOG(INFO) << "SocketHandShakePlugin: new connection: " + << peer_hostname.c_str(); + + Json::Value local, peer; + Json::Reader reader; + if (!reader.parse(readString(conn_fd), peer)) { + LOG(ERROR) << "SocketHandShakePlugin: failed to receive " + "handshake message: " + "malformed json format, check tcp connection"; + close(conn_fd); + continue; + } + + on_recv_callback(peer, local); + int ret = writeString(conn_fd, Json::FastWriter{}.write(local)); + if (ret) { + LOG(ERROR) << "SocketHandShakePlugin: failed to send " + "handshake message: " + "malformed json format, check tcp connection"; + close(conn_fd); + continue; + } + + close(conn_fd); + } + return; + }); + + return 0; + } + + virtual int send(std::string ip_or_host_name, uint16_t rpc_port, + const Json::Value &local, Json::Value &peer) { + struct addrinfo hints; + struct addrinfo *result, *rp; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + + char service[16]; + sprintf(service, "%u", rpc_port); + if (getaddrinfo(ip_or_host_name.c_str(), service, &hints, &result)) { + PLOG(ERROR) + << "SocketHandShakePlugin: failed to get IP address of peer " + "server " + << ip_or_host_name << ":" << rpc_port + << ", check DNS and /etc/hosts, or use IPv4 address instead"; + return ERR_DNS; + } + + int ret = 0; + for (rp = result; rp; rp = rp->ai_next) { + ret = doSend(rp, local, peer); + if (ret == 0) { + freeaddrinfo(result); + return 0; + } + if (ret == ERR_MALFORMED_JSON) { + return ret; + } + } + + freeaddrinfo(result); + return ret; + } + + int doSend(struct addrinfo *addr, const Json::Value &local, + Json::Value &peer) { + if (globalConfig().verbose) + LOG(INFO) << "SocketHandShakePlugin: connecting " + << getNetworkAddress(addr->ai_addr); + + int on = 1; + int conn_fd = + socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (conn_fd == -1) { + PLOG(ERROR) << "SocketHandShakePlugin: socket()"; + return ERR_SOCKET; + } + if (setsockopt(conn_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))) { + PLOG(ERROR) << "SocketHandShakePlugin: setsockopt(SO_REUSEADDR)"; + close(conn_fd); + return ERR_SOCKET; + } + + struct timeval timeout; + timeout.tv_sec = 60; + timeout.tv_usec = 0; + if (setsockopt(conn_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, + sizeof(timeout))) { + PLOG(ERROR) << "SocketHandShakePlugin: setsockopt(SO_RCVTIMEO)"; + close(conn_fd); + return ERR_SOCKET; + } + + if (connect(conn_fd, addr->ai_addr, addr->ai_addrlen)) { + PLOG(ERROR) << "SocketHandShakePlugin: connect()" + << getNetworkAddress(addr->ai_addr); + close(conn_fd); + return ERR_SOCKET; + } + + int ret = writeString(conn_fd, Json::FastWriter{}.write(local)); + if (ret) { + LOG(ERROR) + << "SocketHandShakePlugin: failed to send handshake message: " + "malformed json format, check tcp connection"; + close(conn_fd); + return ret; + } + + Json::Reader reader; + if (!reader.parse(readString(conn_fd), peer)) { + LOG(ERROR) << "SocketHandShakePlugin: failed to receive handshake " + "message: " + "malformed json format, check tcp connection"; + close(conn_fd); + return ERR_MALFORMED_JSON; + } + + close(conn_fd); + return 0; + } + + std::atomic listener_running_; + std::thread listener_; +}; + +std::shared_ptr HandShakePlugin::Create( + const std::string &conn_string) { + return std::make_shared(); +} + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/endpoint_store.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/endpoint_store.cpp index d2160bb..34819b5 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/endpoint_store.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/endpoint_store.cpp @@ -45,7 +45,7 @@ std::shared_ptr FIFOEndpointStore::insertEndpoint( } auto endpoint = std::make_shared(*context); if (!endpoint) { - PLOG(ERROR) << "Failed to allocate memory for RdmaEndPoint"; + LOG(ERROR) << "Failed to allocate memory for RdmaEndPoint"; return nullptr; } auto &config = globalConfig(); @@ -131,7 +131,7 @@ std::shared_ptr SIEVEEndpointStore::insertEndpoint( } auto endpoint = std::make_shared(*context); if (!endpoint) { - PLOG(ERROR) << "Failed to allocate memory for RdmaEndPoint"; + LOG(ERROR) << "Failed to allocate memory for RdmaEndPoint"; return nullptr; } auto &config = globalConfig(); diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_context.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_context.cpp index 763b92e..68951fb 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_context.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_context.cpp @@ -33,10 +33,9 @@ namespace mooncake { static int isNullGid(union ibv_gid *gid) { for (int i = 0; i < 16; ++i) { - if (gid->raw[i] != 0) - return 0; + if (gid->raw[i] != 0) return 0; } - return 1; + return 1; } RdmaContext::RdmaContext(RdmaTransport &engine, const std::string &device_name) @@ -49,8 +48,10 @@ RdmaContext::RdmaContext(RdmaTransport &engine, const std::string &device_name) active_(true) { static std::once_flag g_once_flag; auto fork_init = []() { - if (ibv_fork_init()) - PLOG(ERROR) << "RDMA context setup failed: fork compatibility"; + int ret = ibv_fork_init(); + if (ret) + LOG(ERROR) << "RDMA context setup failed: fork compatibility: " + << strerror(ret); }; std::call_once(g_once_flag, fork_init); } @@ -64,12 +65,12 @@ int RdmaContext::construct(size_t num_cq_list, size_t num_comp_channels, int max_endpoints) { endpoint_store_ = std::make_shared(max_endpoints); if (!endpoint_store_) { - PLOG(ERROR) << "RDMA context setup failed: endpoint store"; + LOG(ERROR) << "RDMA context setup failed: endpoint store"; return ERR_MEMORY; } if (openRdmaDevice(device_name_, port, gid_index)) { - PLOG(ERROR) << "RDMA context setup failed: open device"; + LOG(ERROR) << "RDMA context setup failed: open device"; return ERR_CONTEXT; } @@ -100,14 +101,14 @@ int RdmaContext::construct(size_t num_cq_list, size_t num_comp_channels, } if (joinNonblockingPollList(event_fd_, context_->async_fd)) { - PLOG(ERROR) + LOG(ERROR) << "RDMA context setup failed: register event file descriptor"; return ERR_CONTEXT; } for (size_t i = 0; i < num_comp_channel_; ++i) if (joinNonblockingPollList(event_fd_, comp_channel_[i]->fd)) { - PLOG(ERROR) + LOG(ERROR) << "RDMA context setup failed: register event file descriptor"; return ERR_CONTEXT; } @@ -124,7 +125,7 @@ int RdmaContext::construct(size_t num_cq_list, size_t num_comp_channels, worker_pool_ = std::make_shared(*this, socketId()); if (!worker_pool_) { - PLOG(ERROR) << "RDMA context setup failed: worker pool"; + LOG(ERROR) << "RDMA context setup failed: worker pool"; return ERR_MEMORY; } @@ -154,21 +155,24 @@ int RdmaContext::deconstruct() { endpoint_store_->destroyQPs(); for (auto &entry : memory_region_list_) { - if (ibv_dereg_mr(entry)) { - PLOG(ERROR) << "Fail to unregister memory region"; + int ret = ibv_dereg_mr(entry); + if (ret) { + LOG(ERROR) << "Fail to unregister memory region: " << strerror(ret); } } memory_region_list_.clear(); for (size_t i = 0; i < cq_list_.size(); ++i) { - if (ibv_destroy_cq(cq_list_[i])) { - PLOG(ERROR) << "Fail to destroy completion queue"; + int ret = ibv_destroy_cq(cq_list_[i]); + if (ret) { + PLOG(ERROR) << "Fail to destroy completion queue: " + << strerror(ret); } } cq_list_.clear(); if (event_fd_ >= 0) { - if (close(event_fd_)) PLOG(ERROR) << "Fail to close epoll fd"; + if (close(event_fd_)) LOG(ERROR) << "Fail to close epoll fd"; event_fd_ = -1; } @@ -176,20 +180,20 @@ int RdmaContext::deconstruct() { for (size_t i = 0; i < num_comp_channel_; ++i) if (comp_channel_[i]) if (ibv_destroy_comp_channel(comp_channel_[i])) - PLOG(ERROR) << "Fail to destroy completion channel"; + LOG(ERROR) << "Fail to destroy completion channel"; delete[] comp_channel_; comp_channel_ = nullptr; } if (pd_) { if (ibv_dealloc_pd(pd_)) - PLOG(ERROR) << "Fail to deallocate protection domain"; + LOG(ERROR) << "Fail to deallocate protection domain"; pd_ = nullptr; } if (context_) { if (ibv_close_device(context_)) - PLOG(ERROR) << "Fail to close device context"; + LOG(ERROR) << "Fail to close device context"; context_ = nullptr; } @@ -253,7 +257,7 @@ int RdmaContext::unregisterMemoryRegion(void *addr) { if ((*iter)->addr <= addr && addr < (char *)((*iter)->addr) + (*iter)->length) { if (ibv_dereg_mr(*iter)) { - PLOG(ERROR) << "Fail to unregister memory " << addr; + LOG(ERROR) << "Fail to unregister memory " << addr; return ERR_CONTEXT; } memory_region_list_.erase(iter); @@ -347,34 +351,32 @@ int RdmaContext::compVector() { return (next_comp_vector_index_++) % context_->num_comp_vectors; } -static inline int ipv6_addr_v4mapped(const struct in6_addr *a) -{ - return ((a->s6_addr32[0] | a->s6_addr32[1]) | - (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL || - /* IPv4 encoded multicast addresses */ - (a->s6_addr32[0] == htonl(0xff0e0000) && - ((a->s6_addr32[1] | - (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL)); +static inline int ipv6_addr_v4mapped(const struct in6_addr *a) { + return ((a->s6_addr32[0] | a->s6_addr32[1]) | + (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL || + /* IPv4 encoded multicast addresses */ + (a->s6_addr32[0] == htonl(0xff0e0000) && + ((a->s6_addr32[1] | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL)); } -int RdmaContext::getBestGidIndex(struct ibv_context *context, ibv_port_attr &port_attr, uint8_t port) { - int gid_index = 0, i; - union ibv_gid temp_gid, temp_gid_rival; - int is_ipv4, is_ipv4_rival; +int RdmaContext::getBestGidIndex(struct ibv_context *context, + ibv_port_attr &port_attr, uint8_t port) { + int gid_index = 0, i; + union ibv_gid temp_gid, temp_gid_rival; + int is_ipv4, is_ipv4_rival; - if (ibv_query_gid(context, port, gid_index, &temp_gid)) - return -1; + if (ibv_query_gid(context, port, gid_index, &temp_gid)) return -1; is_ipv4 = ipv6_addr_v4mapped((struct in6_addr *)temp_gid.raw); - for (i = 1; i < port_attr.gid_tbl_len; i++) { - if (ibv_query_gid(context, port, i, &temp_gid_rival)) { - return -1; - } - is_ipv4_rival = ipv6_addr_v4mapped((struct in6_addr *)temp_gid_rival.raw); - if (is_ipv4_rival && !is_ipv4) - gid_index = i; - } - return gid_index; + for (i = 1; i < port_attr.gid_tbl_len; i++) { + if (ibv_query_gid(context, port, i, &temp_gid_rival)) { + return -1; + } + is_ipv4_rival = + ipv6_addr_v4mapped((struct in6_addr *)temp_gid_rival.raw); + if (is_ipv4_rival && !is_ipv4) gid_index = i; + } + return gid_index; } int RdmaContext::openRdmaDevice(const std::string &device_name, uint8_t port, @@ -383,7 +385,7 @@ int RdmaContext::openRdmaDevice(const std::string &device_name, uint8_t port, struct ibv_context *context = nullptr; struct ibv_device **devices = ibv_get_device_list(&num_devices); if (!devices || num_devices <= 0) { - PLOG(ERROR) << "ibv_get_device_list failed"; + LOG(ERROR) << "ibv_get_device_list failed"; return ERR_DEVICE_NOT_FOUND; } @@ -392,17 +394,18 @@ int RdmaContext::openRdmaDevice(const std::string &device_name, uint8_t port, context = ibv_open_device(devices[i]); if (!context) { - PLOG(ERROR) << "Failed to open device " << device_name; + LOG(ERROR) << "ibv_open_device(" << device_name << ") failed"; ibv_free_device_list(devices); return ERR_CONTEXT; } ibv_port_attr attr; - if (ibv_query_port(context, port, &attr)) { - PLOG(WARNING) << "Fail to query port " << port << " on " - << device_name; + int ret = ibv_query_port(context, port, &attr); + if (ret) { + LOG(WARNING) << "Fail to query port " << port << " on " + << device_name << ": " << strerror(ret); if (ibv_close_device(context)) { - PLOG(ERROR) << "Fail to close device " << device_name; + LOG(ERROR) << "ibv_close_device(" << device_name << ") failed"; } ibv_free_device_list(devices); return ERR_CONTEXT; @@ -411,46 +414,49 @@ int RdmaContext::openRdmaDevice(const std::string &device_name, uint8_t port, if (attr.state != IBV_PORT_ACTIVE) { LOG(WARNING) << "Device " << device_name << " port not active"; if (ibv_close_device(context)) { - PLOG(ERROR) << "Fail to close device " << device_name; + LOG(ERROR) << "ibv_close_device(" << device_name << ") failed"; } ibv_free_device_list(devices); return ERR_CONTEXT; } ibv_device_attr device_attr; - if (ibv_query_device(context, &device_attr)) { - PLOG(WARNING) << "Fail to query attributes on " << device_name; + ret = ibv_query_device(context, &device_attr); + if (ret) { + LOG(WARNING) << "Fail to query attributes on " << device_name + << ": " << strerror(ret); if (ibv_close_device(context)) { - PLOG(ERROR) << "Fail to close device " << device_name; + LOG(ERROR) << "ibv_close_device(" << device_name << ") failed"; } ibv_free_device_list(devices); return ERR_CONTEXT; } ibv_port_attr port_attr; - if (ibv_query_port(context, port, &port_attr)) { - PLOG(WARNING) << "Fail to query port attributes on " - << device_name << ":" << port; + ret = ibv_query_port(context, port, &port_attr); + if (ret) { + LOG(WARNING) << "Fail to query port attributes on " << device_name + << "/" << port << ": " << strerror(ret); if (ibv_close_device(context)) { - PLOG(ERROR) << "Fail to close device " << device_name; + LOG(ERROR) << "ibv_close_device(" << device_name << ") failed"; } ibv_free_device_list(devices); return ERR_CONTEXT; } updateGlobalConfig(device_attr); - if (gid_index == 0) - { + if (gid_index == 0) { int ret = getBestGidIndex(context, port_attr, port); - if (ret >= 0) - gid_index = ret; + if (ret >= 0) gid_index = ret; } - if (ibv_query_gid(context, port, gid_index, &gid_)) { - PLOG(WARNING) << "Device " << device_name << " GID " << gid_index - << " not available"; + ret = ibv_query_gid(context, port, gid_index, &gid_); + if (ret) { + LOG(WARNING) << "Fail to query GID " << gid_index + << " attributes on " << device_name << "/" << port + << ": " << strerror(ret); if (ibv_close_device(context)) { - PLOG(ERROR) << "Fail to close device " << device_name; + LOG(ERROR) << "ibv_close_device(" << device_name << ") failed"; } ibv_free_device_list(devices); return ERR_CONTEXT; @@ -458,14 +464,15 @@ int RdmaContext::openRdmaDevice(const std::string &device_name, uint8_t port, #ifndef CONFIG_SKIP_NULL_GID_CHECK if (isNullGid(&gid_)) { - LOG(WARNING) << "GID is NULL, please check your GID index by specifying MC_GID_INDEX"; + LOG(WARNING) << "GID is NULL, please check your GID index by " + "specifying MC_GID_INDEX"; if (ibv_close_device(context)) { - PLOG(ERROR) << "Fail to close device " << device_name; + LOG(ERROR) << "ibv_close_device(" << device_name << ") failed"; } ibv_free_device_list(devices); return ERR_CONTEXT; } -#endif // CONFIG_SKIP_NULL_GID_CHECK +#endif // CONFIG_SKIP_NULL_GID_CHECK context_ = context; port_ = port; @@ -511,7 +518,7 @@ int RdmaContext::joinNonblockingPollList(int event_fd, int data_fd) { int RdmaContext::poll(int num_entries, ibv_wc *wc, int cq_index) { int nr_poll = ibv_poll_cq(cq_list_[cq_index], num_entries, wc); if (nr_poll < 0) { - PLOG(ERROR) << "Failed to poll CQ #" << cq_index << " of device " + PLOG(ERROR) << "Failed to poll CQ " << cq_index << " of device " << device_name_; return ERR_CONTEXT; } diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp index 87b0b1b..3863e7c 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp @@ -37,7 +37,7 @@ int RdmaEndPoint::construct(ibv_cq *cq, size_t num_qp_list, size_t max_sge_per_wr, size_t max_wr_depth, size_t max_inline_bytes) { if (status_.load(std::memory_order_relaxed) != INITIALIZING) { - PLOG(ERROR) << "Endpoint has already been constructed"; + LOG(ERROR) << "Endpoint has already been constructed"; return ERR_ENDPOINT; } @@ -46,7 +46,7 @@ int RdmaEndPoint::construct(ibv_cq *cq, size_t num_qp_list, max_wr_depth_ = (int)max_wr_depth; wr_depth_list_ = new volatile int[num_qp_list]; if (!wr_depth_list_) { - PLOG(ERROR) << "Failed to allocate memory for work request depth list"; + LOG(ERROR) << "Failed to allocate memory for work request depth list"; return ERR_MEMORY; } for (size_t i = 0; i < num_qp_list; ++i) { @@ -74,11 +74,11 @@ int RdmaEndPoint::construct(ibv_cq *cq, size_t num_qp_list, int RdmaEndPoint::deconstruct() { for (size_t i = 0; i < qp_list_.size(); ++i) { if (wr_depth_list_[i] != 0) - PLOG(WARNING) + LOG(WARNING) << "Outstanding work requests found, CQ will not be generated"; if (ibv_destroy_qp(qp_list_[i])) { - PLOG(ERROR) << "Failed to destroy QP"; + LOG(ERROR) << "Failed to destroy QP"; return ERR_ENDPOINT; } } @@ -92,7 +92,7 @@ int RdmaEndPoint::destroyQP() { return deconstruct(); } void RdmaEndPoint::setPeerNicPath(const std::string &peer_nic_path) { RWSpinlock::WriteGuard guard(lock_); if (connected()) { - LOG(WARNING) << "Previous connection is discarded"; + LOG(WARNING) << "Previous connection will be discarded"; disconnectUnlocked(); } peer_nic_path_ = peer_nic_path; @@ -101,7 +101,7 @@ void RdmaEndPoint::setPeerNicPath(const std::string &peer_nic_path) { int RdmaEndPoint::setupConnectionsByActive() { RWSpinlock::WriteGuard guard(lock_); if (connected()) { - // LOG(WARNING) << "Connection already connected"; + LOG(INFO) << "Connection has been established"; return 0; } HandShakeDesc local_desc, peer_desc; @@ -118,10 +118,7 @@ int RdmaEndPoint::setupConnectionsByActive() { int rc = context_.engine().sendHandshake(peer_server_name, local_desc, peer_desc); - if (rc) { - LOG(ERROR) << "Failed to exchange handshake description"; - return rc; - } + if (rc) return rc; if (peer_desc.local_nic_path != peer_nic_path_ || peer_desc.peer_nic_path != local_desc.local_nic_path) { @@ -190,15 +187,15 @@ void RdmaEndPoint::disconnect() { void RdmaEndPoint::disconnectUnlocked() { for (size_t i = 0; i < qp_list_.size(); ++i) { if (wr_depth_list_[i] != 0) - PLOG(WARNING) - << "Outstanding work requests found, CQ will not be generated"; + LOG(WARNING) << "Outstanding work requests will be dropped"; } ibv_qp_attr attr; memset(&attr, 0, sizeof(attr)); attr.qp_state = IBV_QPS_RESET; for (size_t i = 0; i < qp_list_.size(); ++i) { - if (ibv_modify_qp(qp_list_[i], &attr, IBV_QP_STATE)) - PLOG(ERROR) << "Failed to modity QP to RESET"; + int ret = ibv_modify_qp(qp_list_[i], &attr, IBV_QP_STATE); + if (ret) + LOG(ERROR) << "Failed to modity QP to RESET: " << strerror(ret); } peer_nic_path_.clear(); for (size_t i = 0; i < qp_list_.size(); ++i) wr_depth_list_[i] = 0; @@ -264,7 +261,7 @@ int RdmaEndPoint::submitPostSend( __sync_fetch_and_add(&wr_depth_list_[qp_index], wr_count); int rc = ibv_post_send(qp_list_[qp_index], wr_list, &bad_wr); if (rc) { - PLOG(ERROR) << "ibv_post_send failed"; + LOG(ERROR) << "ibv_post_send: " << strerror(rc); while (bad_wr) { int i = bad_wr - wr_list; failed_slice_list.push_back(slice_list[i]); @@ -291,7 +288,7 @@ int RdmaEndPoint::doSetupConnection(const std::string &peer_gid, std::string message = "QP count mismatch in peer and local endpoints, check " "MC_MAX_EP_PER_CTX"; - LOG(ERROR) << message; + LOG(ERROR) << "[Handshake] " << message; if (reply_msg) *reply_msg = message; return ERR_INVALID_ARGUMENT; } @@ -317,9 +314,11 @@ int RdmaEndPoint::doSetupConnection(int qp_index, const std::string &peer_gid, ibv_qp_attr attr; memset(&attr, 0, sizeof(attr)); attr.qp_state = IBV_QPS_RESET; - if (ibv_modify_qp(qp, &attr, IBV_QP_STATE)) { - std::string message = "Failed to modity QP to RESET"; - PLOG(ERROR) << message; + int ret = ibv_modify_qp(qp, &attr, IBV_QP_STATE); + if (ret) { + std::string message = "Failed to modity QP to RESET: "; + message += strerror(ret); + LOG(ERROR) << "[Handshake] " << message; if (reply_msg) *reply_msg = message; return ERR_ENDPOINT; } @@ -331,12 +330,14 @@ int RdmaEndPoint::doSetupConnection(int qp_index, const std::string &peer_gid, attr.pkey_index = 0; attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC; - if (ibv_modify_qp(qp, &attr, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | - IBV_QP_ACCESS_FLAGS)) { + ret = ibv_modify_qp( + qp, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS); + if (ret) { std::string message = - "Failed to modity QP to INIT, check local context port num"; - PLOG(ERROR) << message; + "Failed to modity QP to INIT, check local context port num: "; + message += strerror(ret); + LOG(ERROR) << "[Handshake] " << message; if (reply_msg) *reply_msg = message; return ERR_ENDPOINT; } @@ -369,13 +370,16 @@ int RdmaEndPoint::doSetupConnection(int qp_index, const std::string &peer_gid, attr.rq_psn = 0; attr.max_dest_rd_atomic = 16; attr.min_rnr_timer = 12; // 12 in previous implementation - if (ibv_modify_qp(qp, &attr, - IBV_QP_STATE | IBV_QP_PATH_MTU | IBV_QP_MIN_RNR_TIMER | - IBV_QP_AV | IBV_QP_MAX_DEST_RD_ATOMIC | - IBV_QP_DEST_QPN | IBV_QP_RQ_PSN)) { + ret = ibv_modify_qp(qp, &attr, + IBV_QP_STATE | IBV_QP_PATH_MTU | IBV_QP_MIN_RNR_TIMER | + IBV_QP_AV | IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN); + if (ret) { std::string message = - "Failed to modity QP to RTR, check mtu, gid, peer lid, peer qp num"; - PLOG(ERROR) << message; + "Failed to modity QP to RTR, check mtu, gid, peer lid, peer qp " + "num: "; + message += strerror(ret); + LOG(ERROR) << "[Handshake] " << message; if (reply_msg) *reply_msg = message; return ERR_ENDPOINT; } @@ -388,13 +392,14 @@ int RdmaEndPoint::doSetupConnection(int qp_index, const std::string &peer_gid, attr.rnr_retry = 7; // or 7,RNR error attr.sq_psn = 0; attr.max_rd_atomic = 16; - - if (ibv_modify_qp(qp, &attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | - IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC)) { - std::string message = "Failed to modity QP to RTS"; - PLOG(ERROR) << message; + ret = ibv_modify_qp(qp, &attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC); + if (ret) { + std::string message = "Failed to modity QP to RTS: "; + message += strerror(ret); + LOG(ERROR) << "[Handshake] " << message; if (reply_msg) *reply_msg = message; return ERR_ENDPOINT; } diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp index c9efe57..81165eb 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp @@ -25,6 +25,7 @@ #include "common.h" #include "config.h" +#include "topology.h" #include "transport/rdma_transport/rdma_context.h" #include "transport/rdma_transport/rdma_endpoint.h" @@ -46,34 +47,26 @@ int RdmaTransport::install(std::string &local_server_name, void **args) { const std::string nic_priority_matrix = static_cast(args[0]); bool dry_run = args[1] ? *static_cast(args[1]) : false; - TransferMetadata::PriorityMatrix local_priority_matrix; if (dry_run) return 0; metadata_ = meta; local_server_name_ = local_server_name; - int ret = metadata_->parseNicPriorityMatrix( - nic_priority_matrix, local_priority_matrix, device_name_list_); + int ret = local_topology_.parse(nic_priority_matrix); if (ret) { - LOG(ERROR) << "Transfer engine cannot be initialized: cannot parse " - "NIC priority matrix: " + LOG(ERROR) << "RdmaTransport: incorrect NIC priority matrix: " << nic_priority_matrix; return ret; } - int device_index = 0; - for (auto &device_name : device_name_list_) - device_name_to_index_map_[device_name] = device_index++; - ret = initializeRdmaResources(); if (ret) { - LOG(ERROR) << "Transfer engine cannot be initialized: cannot " - "initialize RDMA resources"; + LOG(ERROR) << "RdmaTransport: cannot initialize RDMA resources"; return ret; } - ret = allocateLocalSegmentID(local_priority_matrix); + ret = allocateLocalSegmentID(); if (ret) { LOG(ERROR) << "Transfer engine cannot be initialized: cannot " "allocate local segment"; @@ -82,16 +75,13 @@ int RdmaTransport::install(std::string &local_server_name, ret = startHandshakeDaemon(local_server_name); if (ret) { - LOG(ERROR) << "Transfer engine cannot be initialized: cannot start " - "handshake daemon"; + LOG(ERROR) << "RdmaTransport: cannot start handshake daemon"; return ret; } ret = metadata_->updateLocalSegmentDesc(); if (ret) { - LOG(ERROR) << "Transfer engine cannot be initialized: cannot " - "publish segments. Check the connectivity between this " - "server and etcd servers"; + LOG(ERROR) << "RdmaTransport: cannot publish segments"; return ret; } @@ -132,8 +122,7 @@ int RdmaTransport::unregisterLocalMemory(void *addr, bool update_metadata) { return 0; } -int RdmaTransport::allocateLocalSegmentID( - TransferMetadata::PriorityMatrix &priority_matrix) { +int RdmaTransport::allocateLocalSegmentID() { auto desc = std::make_shared(); if (!desc) return ERR_MEMORY; desc->name = local_server_name_; @@ -145,7 +134,7 @@ int RdmaTransport::allocateLocalSegmentID( device_desc.gid = entry->gid(); desc->devices.push_back(device_desc); } - desc->priority_matrix = priority_matrix; + desc->topology = local_topology_; metadata_->addLocalSegment(LOCAL_SEGMENT_ID, local_server_name_, std::move(desc)); return 0; @@ -165,7 +154,7 @@ int RdmaTransport::registerLocalMemoryBatch( for (size_t i = 0; i < buffer_list.size(); ++i) { if (results[i].get()) { - LOG(WARNING) << "Failed to register memory: addr " + LOG(WARNING) << "RdmaTransport: Failed to register memory: addr " << buffer_list[i].addr << " length " << buffer_list[i].length; } @@ -186,7 +175,7 @@ int RdmaTransport::unregisterLocalMemoryBatch( for (size_t i = 0; i < addr_list.size(); ++i) { if (results[i].get()) - LOG(WARNING) << "Failed to unregister memory: addr " + LOG(WARNING) << "RdmaTransport: Failed to unregister memory: addr " << addr_list[i]; } @@ -197,7 +186,8 @@ int RdmaTransport::submitTransfer(BatchID batch_id, const std::vector &entries) { auto &batch_desc = *((BatchDesc *)(batch_id)); if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { - LOG(ERROR) << "Exceed the limitation of current batch's capacity"; + LOG(ERROR) << "RdmaTransport: Exceed the limitation of current batch's " + "capacity"; return ERR_TOO_MANY_REQUESTS; } @@ -241,8 +231,61 @@ int RdmaTransport::submitTransfer(BatchID batch_id, break; } if (device_id < 0) { - LOG(ERROR) << "Address not registered by any device(s) " - << slice->source_addr; + LOG(ERROR) + << "RdmaTransport: Address not registered by any device(s) " + << slice->source_addr; + return ERR_ADDRESS_NOT_REGISTERED; + } + } + } + for (auto &entry : slices_to_post) + entry.first->submitPostSend(entry.second); + return 0; +} + +int RdmaTransport::submitTransferTask( + const std::vector &request_list, + const std::vector &task_list) { + std::unordered_map, std::vector> + slices_to_post; + auto local_segment_desc = metadata_->getSegmentDescByID(LOCAL_SEGMENT_ID); + const size_t kBlockSize = globalConfig().slice_size; + const int kMaxRetryCount = globalConfig().retry_cnt; + for (size_t index = 0; index < request_list.size(); ++index) { + auto &request = *request_list[index]; + auto &task = *task_list[index]; + for (uint64_t offset = 0; offset < request.length; + offset += kBlockSize) { + auto slice = new Slice(); + slice->source_addr = (char *)request.source + offset; + slice->length = std::min(request.length - offset, kBlockSize); + slice->opcode = request.opcode; + slice->rdma.dest_addr = request.target_offset + offset; + slice->rdma.retry_cnt = 0; + slice->rdma.max_retry_cnt = kMaxRetryCount; + slice->task = &task; + slice->target_id = request.target_id; + slice->status = Slice::PENDING; + + int buffer_id = -1, device_id = -1, retry_cnt = 0; + while (retry_cnt < kMaxRetryCount) { + if (selectDevice(local_segment_desc.get(), + (uint64_t)slice->source_addr, slice->length, + buffer_id, device_id, retry_cnt++)) + continue; + auto &context = context_list_[device_id]; + if (!context->active()) continue; + slice->rdma.source_lkey = + local_segment_desc->buffers[buffer_id].lkey[device_id]; + slices_to_post[context].push_back(slice); + task.total_bytes += slice->length; + task.slices.push_back(slice); + break; + } + if (device_id < 0) { + LOG(ERROR) + << "RdmaTransport: Address not registered by any device(s) " + << slice->source_addr; return ERR_ADDRESS_NOT_REGISTERED; } } @@ -307,23 +350,35 @@ int RdmaTransport::onSetupRdmaConnections(const HandShakeDesc &peer_desc, HandShakeDesc &local_desc) { auto local_nic_name = getNicNameFromNicPath(peer_desc.peer_nic_path); if (local_nic_name.empty()) return ERR_INVALID_ARGUMENT; - auto context = context_list_[device_name_to_index_map_[local_nic_name]]; + + std::shared_ptr context; + int index = 0; + for (auto &entry : local_topology_.getHcaList()) { + if (entry == local_nic_name) { + context = context_list_[index]; + break; + } + index++; + } + if (!context) return ERR_INVALID_ARGUMENT; + #ifdef CONFIG_ERDMA if (context->deleteEndpoint(peer_desc.local_nic_path)) return ERR_ENDPOINT; #endif + auto endpoint = context->endpoint(peer_desc.local_nic_path); if (!endpoint) return ERR_ENDPOINT; return endpoint->setupConnectionsByPassive(peer_desc, local_desc); } int RdmaTransport::initializeRdmaResources() { - if (device_name_list_.empty()) { - LOG(ERROR) << "No available RNIC!"; + if (local_topology_.empty()) { + LOG(ERROR) << "RdmaTransport: No available RNIC"; return ERR_DEVICE_NOT_FOUND; } std::vector device_speed_list; - for (auto &device_name : device_name_list_) { + for (auto &device_name : local_topology_.getHcaList()) { auto context = std::make_shared(*this, device_name); if (!context) return ERR_MEMORY; @@ -355,33 +410,8 @@ int RdmaTransport::selectDevice(SegmentDesc *desc, uint64_t offset, if (buffer_desc.addr > offset || offset + length > buffer_desc.addr + buffer_desc.length) continue; - - auto &priority = desc->priority_matrix[buffer_desc.name]; - size_t preferred_rnic_list_len = priority.preferred_rnic_list.size(); - size_t available_rnic_list_len = priority.available_rnic_list.size(); - size_t rnic_list_len = - preferred_rnic_list_len + available_rnic_list_len; - if (rnic_list_len == 0) return ERR_DEVICE_NOT_FOUND; - - if (retry_count == 0) { - int rand_value = SimpleRandom::Get().next(); - if (preferred_rnic_list_len) - device_id = - priority.preferred_rnic_id_list[rand_value % - preferred_rnic_list_len]; - else - device_id = - priority.available_rnic_id_list[rand_value % - available_rnic_list_len]; - } else { - size_t index = (retry_count - 1) % rnic_list_len; - if (index < preferred_rnic_list_len) - device_id = priority.preferred_rnic_id_list[index]; - else - device_id = priority.available_rnic_id_list[index - preferred_rnic_list_len]; - } - - return 0; + device_id = desc->topology.selectDevice(buffer_desc.name, retry_count); + if (device_id >= 0) return device_id; } return ERR_ADDRESS_NOT_REGISTERED; diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp index 101c95b..f48d7f9 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp @@ -25,7 +25,7 @@ #ifdef USE_CUDA #include -#endif // USE_CUDA +#endif // USE_CUDA // Experimental: Per-thread SegmentDesc & EndPoint Caches // #define CONFIG_CACHE_SEGMENT_DESC @@ -82,6 +82,8 @@ int WorkerPool::submitPostSend( context_.engine().meta()->getSegmentDescByID(target_id); if (!segment_desc_map[target_id]) { segment_desc_map.clear(); + LOG(ERROR) << "Cannot get target segment description #" + << target_id; return ERR_INVALID_ARGUMENT; } } @@ -90,12 +92,14 @@ int WorkerPool::submitPostSend( std::unordered_map> segment_desc_map; for (auto &slice : slice_list) { - assert(slice); auto target_id = slice->target_id; if (!segment_desc_map.count(target_id)) segment_desc_map[target_id] = context_.engine().meta()->getSegmentDescByID(target_id); - if (!segment_desc_map[target_id]) return ERR_INVALID_ARGUMENT; + if (!segment_desc_map[target_id]) { + LOG(ERROR) << "Cannot get target segment #" << target_id; + return ERR_INVALID_ARGUMENT; + } } #endif // CONFIG_CACHE_SEGMENT_DESC @@ -107,11 +111,13 @@ int WorkerPool::submitPostSend( if (RdmaTransport::selectDevice(peer_segment_desc.get(), slice->rdma.dest_addr, slice->length, buffer_id, device_id)) { + LOG(WARNING) << "Reselect remote NIC for address " + << (void *)slice->rdma.dest_addr << " on segment #" + << slice->target_id; peer_segment_desc = context_.engine().meta()->getSegmentDescByID( slice->target_id, true); if (!peer_segment_desc) { - LOG(ERROR) << "Cannot get segment description for slice: " - << (void *)slice; + LOG(ERROR) << "Cannot get target segment #" << slice->target_id; slice->markFailed(); continue; } @@ -119,7 +125,8 @@ int WorkerPool::submitPostSend( peer_segment_desc.get(), slice->rdma.dest_addr, slice->length, buffer_id, device_id)) { LOG(ERROR) << "Failed to select remote NIC for address " - << (void *)slice->rdma.dest_addr; + << (void *)slice->rdma.dest_addr << " on segment #" + << slice->target_id; slice->markFailed(); continue; } @@ -313,9 +320,10 @@ void WorkerPool::redispatch(std::vector &slice_list, segment_desc_map; for (auto &slice : slice_list) { auto target_id = slice->target_id; - if (!segment_desc_map.count(target_id)) + if (!segment_desc_map.count(target_id)) { segment_desc_map[target_id] = context_.engine().meta()->getSegmentDescByID(target_id, true); + } } for (auto &slice : slice_list) { @@ -325,7 +333,8 @@ void WorkerPool::redispatch(std::vector &slice_list, } else { auto &peer_segment_desc = segment_desc_map[slice->target_id]; int buffer_id, device_id; - if (RdmaTransport::selectDevice(peer_segment_desc.get(), + if (!peer_segment_desc || + RdmaTransport::selectDevice(peer_segment_desc.get(), slice->rdma.dest_addr, slice->length, buffer_id, device_id, slice->rdma.retry_cnt)) { @@ -374,19 +383,19 @@ void WorkerPool::transferWorker(int thread_id) { int WorkerPool::doProcessContextEvents() { ibv_async_event event; if (ibv_get_async_event(context_.context(), &event) < 0) return ERR_CONTEXT; - LOG(INFO) << "Received context async event: " - << ibv_event_type_str(event.event_type) << " for context " - << context_.deviceName(); + LOG(WARNING) << "Worker: Received context async event " + << ibv_event_type_str(event.event_type) << " for context " + << context_.deviceName(); if (event.event_type == IBV_EVENT_DEVICE_FATAL || event.event_type == IBV_EVENT_CQ_ERR || event.event_type == IBV_EVENT_WQ_FATAL || event.event_type == IBV_EVENT_PORT_ERR || event.event_type == IBV_EVENT_LID_CHANGE) { context_.set_active(false); - LOG(INFO) << "Context " << context_.deviceName() << " is inactive"; + LOG(INFO) << "Worker: Context " << context_.deviceName() << " is now inactive"; } else if (event.event_type == IBV_EVENT_PORT_ACTIVE) { context_.set_active(true); - LOG(INFO) << "Context " << context_.deviceName() << " is active"; + LOG(INFO) << "Worker: Context " << context_.deviceName() << " is now active"; } ibv_ack_async_event(&event); return 0; @@ -398,15 +407,12 @@ void WorkerPool::monitorWorker() { struct epoll_event event; int num_events = epoll_wait(context_.eventFd(), &event, 1, 100); if (num_events < 0) { - PLOG(ERROR) << "Failed to call epoll wait"; + PLOG(ERROR) << "Worker: epoll_wait()"; continue; } if (num_events == 0) continue; - LOG(ERROR) << "Received event, fd: " << event.data.fd - << ", events: " << event.events; - if (!(event.events & EPOLLIN)) continue; if (event.data.fd == context_.context()->async_fd) diff --git a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp index 442a652..bce4e1d 100644 --- a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp +++ b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp @@ -208,17 +208,14 @@ int TcpTransport::install(std::string &local_server_name, int ret = allocateLocalSegmentID(); if (ret) { - LOG(ERROR) << "*** Transfer engine cannot be initialized: cannot " - "allocate local segment"; + LOG(ERROR) << "TcpTransport: cannot allocate local segment"; return -1; } ret = metadata_->updateLocalSegmentDesc(); if (ret) { - LOG(ERROR) << "*** Transfer engine cannot be initialized: cannot " - "publish segments"; - LOG(ERROR) << "*** Check the connectivity between this server and " - "metadata server (etcd server)"; + LOG(ERROR) << "TcpTransport: cannot publish segments, " + "check the availability of metadata storage"; return -1; } @@ -295,7 +292,8 @@ int TcpTransport::submitTransfer(BatchID batch_id, const std::vector &entries) { auto &batch_desc = *((BatchDesc *)(batch_id)); if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { - LOG(ERROR) << "Exceed the limitation of current batch's capacity"; + LOG(ERROR) << "TcpTransport: Exceed the limitation of current batch's " + "capacity"; return ERR_TOO_MANY_REQUESTS; } @@ -321,13 +319,34 @@ int TcpTransport::submitTransfer(BatchID batch_id, return 0; } +int TcpTransport::submitTransferTask( + const std::vector &request_list, + const std::vector &task_list) { + for (size_t index = 0; index < request_list.size(); ++index) { + auto &request = *request_list[index]; + auto &task = *task_list[index]; + task.total_bytes = request.length; + auto slice = new Slice(); + slice->source_addr = (char *)request.source; + slice->length = request.length; + slice->opcode = request.opcode; + slice->tcp.dest_addr = request.target_offset; + slice->task = &task; + slice->target_id = request.target_id; + slice->status = Slice::PENDING; + task.slices.push_back(slice); + startTransfer(slice); + } + return 0; +} + void TcpTransport::worker() { while (running_) { try { context_->doAccept(); context_->io_context.run(); } catch (std::exception &e) { - LOG(ERROR) << "Exception: " << e.what(); + LOG(ERROR) << "TcpTransport: exception: " << e.what(); } } } @@ -348,10 +367,9 @@ void TcpTransport::startTransfer(Slice *slice) { return; } - auto endpoint_iterator = - resolver.resolve(boost::asio::ip::tcp::v4(), - meta_entry.ip_or_host_name, - std::to_string(meta_entry.rpc_port)); + auto endpoint_iterator = resolver.resolve( + boost::asio::ip::tcp::v4(), meta_entry.ip_or_host_name, + std::to_string(meta_entry.rpc_port)); boost::asio::connect(socket, endpoint_iterator); auto session = std::make_shared(std::move(socket)); session->on_finalize_ = [slice](TransferStatusEnum status) { @@ -363,7 +381,7 @@ void TcpTransport::startTransfer(Slice *slice) { session->initiate(slice->source_addr, slice->tcp.dest_addr, slice->length, slice->opcode); } catch (std::exception &e) { - LOG(ERROR) << "ASIO Exception: " << e.what(); + LOG(ERROR) << "TcpTransport: ASIO exception: " << e.what(); slice->markFailed(); } } diff --git a/mooncake-transfer-engine/tests/rdma_transport_test.cpp b/mooncake-transfer-engine/tests/rdma_transport_test.cpp index e74e2f8..a2c7a29 100644 --- a/mooncake-transfer-engine/tests/rdma_transport_test.cpp +++ b/mooncake-transfer-engine/tests/rdma_transport_test.cpp @@ -121,10 +121,10 @@ static void freeMemoryPool(void *addr, size_t size) { #endif } -int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, +int initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, void *addr) { bindToSocket(0); - auto segment_desc = xport->meta()->getSegmentDescByID(segment_id); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; const size_t kDataLength = 4096000; { @@ -134,7 +134,7 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, LOG(INFO) << "Write Data: " << std::string((char *)(addr), 16) << "..."; - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; @@ -143,12 +143,12 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); LOG_ASSERT(!ret); if (status.s == TransferStatusEnum::COMPLETED) completed = true; @@ -157,13 +157,13 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, completed = true; } } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); LOG_ASSERT(!ret); } { LOG(INFO) << "Stage 2: Read Data"; - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; @@ -172,12 +172,12 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); LOG_ASSERT(!ret); if (status.s == TransferStatusEnum::COMPLETED) completed = true; @@ -186,7 +186,7 @@ int initiatorWorker(Transport *xport, SegmentID segment_id, int thread_id, completed = true; } } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); LOG_ASSERT(!ret); } @@ -239,16 +239,12 @@ std::string loadNicPriorityMatrix() { } int initiator() { - auto metadata_client = - std::make_shared(FLAGS_metadata_server); - LOG_ASSERT(metadata_client); - const size_t ram_buffer_size = 1ull << 30; - auto engine = std::make_unique(metadata_client); + auto engine = std::make_unique(); auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - engine->init(FLAGS_local_server_name.c_str(), hostname_port.first.c_str(), - hostname_port.second); + engine->init(FLAGS_metadata_server, FLAGS_local_server_name.c_str(), + hostname_port.first.c_str(), hostname_port.second); Transport *xport = nullptr; if (FLAGS_protocol == "rdma") { @@ -256,11 +252,11 @@ int initiator() { void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; - xport = engine->installOrGetTransport("rdma", args); + xport = engine->installTransport("rdma", args); } else if (FLAGS_protocol == "tcp") { - xport = engine->installOrGetTransport("tcp", nullptr); + xport = engine->installTransport("tcp", nullptr); } else if (FLAGS_protocol == "nvmeof") { - xport = engine->installOrGetTransport("nvmeof", nullptr); + xport = engine->installTransport("nvmeof", nullptr); } else { LOG(ERROR) << "Unsupported protocol"; } @@ -280,7 +276,7 @@ int initiator() { #endif auto segment_id = engine->openSegment(FLAGS_segment_id.c_str()); - std::thread workers(initiatorWorker, xport, segment_id, 0, addr); + std::thread workers(initiatorWorker, engine.get(), segment_id, 0, addr); workers.join(); engine->unregisterLocalMemory(addr); freeMemoryPool(addr, ram_buffer_size); @@ -288,27 +284,23 @@ int initiator() { } int target() { - auto metadata_client = - std::make_shared(FLAGS_metadata_server); - LOG_ASSERT(metadata_client); - const size_t ram_buffer_size = 1ull << 30; - auto engine = std::make_unique(metadata_client); + auto engine = std::make_unique(); auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - engine->init(FLAGS_local_server_name.c_str(), hostname_port.first.c_str(), - hostname_port.second); + engine->init(FLAGS_metadata_server, FLAGS_local_server_name.c_str(), + hostname_port.first.c_str(), hostname_port.second); if (FLAGS_protocol == "rdma") { auto nic_priority_matrix = loadNicPriorityMatrix(); void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; - engine->installOrGetTransport("rdma", args); + engine->installTransport("rdma", args); } else if (FLAGS_protocol == "tcp") { - engine->installOrGetTransport("tcp", nullptr); + engine->installTransport("tcp", nullptr); } else if (FLAGS_protocol == "nvmeof") { - engine->installOrGetTransport("nvmeof", nullptr); + engine->installTransport("nvmeof", nullptr); } else { LOG(ERROR) << "Unsupported protocol"; } diff --git a/mooncake-transfer-engine/tests/rdma_transport_test2.cpp b/mooncake-transfer-engine/tests/rdma_transport_test2.cpp index be91694..617fc52 100644 --- a/mooncake-transfer-engine/tests/rdma_transport_test2.cpp +++ b/mooncake-transfer-engine/tests/rdma_transport_test2.cpp @@ -121,12 +121,9 @@ class RDMATransportTest : public ::testing::Test { LOG(INFO) << "HERE \n"; google::InitGoogleLogging("RDMATransportTest"); FLAGS_logtostderr = 1; - metadata_client = - std::make_shared(FLAGS_metadata_server); - LOG_ASSERT(metadata_client); - engine = std::make_unique(metadata_client); + engine = std::make_unique(); hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - engine->init(FLAGS_local_server_name.c_str(), + engine->init(FLAGS_metadata_server, FLAGS_local_server_name.c_str(), hostname_port.first.c_str(), hostname_port.second + offset++); xport = nullptr; @@ -134,14 +131,14 @@ class RDMATransportTest : public ::testing::Test { args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; - xport = engine->installOrGetTransport("rdma", args); + xport = engine->installTransport("rdma", args); ASSERT_NE(xport, nullptr); addr = allocateMemoryPool(ram_buffer_size, 0, false); int rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0"); ASSERT_EQ(rc, 0); segment_id = engine->openSegment(FLAGS_segment_id.c_str()); bindToSocket(0); - segment_desc = xport->meta()->getSegmentDescByID(segment_id); + segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); remote_base = (uint64_t)segment_desc->buffers[0].addr; } @@ -158,7 +155,7 @@ TEST_F(RDMATransportTest, MultiWrite) { while (times--) { for (size_t offset = 0; offset < kDataLength; ++offset) *((char *)(addr) + offset) = 'a' + lrand48() % 26; - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; entry.opcode = TransferRequest::WRITE; @@ -166,12 +163,12 @@ TEST_F(RDMATransportTest, MultiWrite) { entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); ASSERT_EQ(ret, 0); if (status.s == TransferStatusEnum::COMPLETED) completed = true; @@ -180,7 +177,7 @@ TEST_F(RDMATransportTest, MultiWrite) { completed = true; } } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); ASSERT_EQ(ret, 0); } } @@ -192,7 +189,7 @@ TEST_F(RDMATransportTest, MultipleRead) { for (size_t offset = 0; offset < kDataLength; ++offset) *((char *)(addr) + offset) = 'a' + lrand48() % 26; - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; entry.opcode = TransferRequest::WRITE; @@ -200,12 +197,12 @@ TEST_F(RDMATransportTest, MultipleRead) { entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); ASSERT_EQ(ret, 0); if (status.s == TransferStatusEnum::COMPLETED) completed = true; @@ -214,12 +211,12 @@ TEST_F(RDMATransportTest, MultipleRead) { completed = true; } } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); ASSERT_EQ(ret, 0); } times = 10; while (times--) { - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; entry.opcode = TransferRequest::READ; @@ -227,12 +224,12 @@ TEST_F(RDMATransportTest, MultipleRead) { entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); ASSERT_EQ(ret, 0); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); ASSERT_EQ(ret, 0); if (status.s == TransferStatusEnum::COMPLETED) completed = true; @@ -240,7 +237,7 @@ TEST_F(RDMATransportTest, MultipleRead) { completed = true; } } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); ASSERT_EQ(ret, 0); ret = memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, kDataLength); diff --git a/mooncake-transfer-engine/tests/tcp_transport_test.cpp b/mooncake-transfer-engine/tests/tcp_transport_test.cpp index 5c13262..71c827c 100644 --- a/mooncake-transfer-engine/tests/tcp_transport_test.cpp +++ b/mooncake-transfer-engine/tests/tcp_transport_test.cpp @@ -72,16 +72,12 @@ static void *allocateMemoryPool(size_t size, int socket_id, } TEST_F(TCPTransportTest, GetTcpTest) { - auto metadata_client = std::make_shared("127.0.0.1:2379"); - LOG_ASSERT(metadata_client); - - auto engine = std::make_unique(metadata_client); - + auto engine = std::make_unique(); auto hostname_port = parseHostNameWithPort("127.0.0.2:12345"); - engine->init("127.0.0.2:12345", hostname_port.first.c_str(), - hostname_port.second); + engine->init("127.0.0.1:2379", "127.0.0.2:12345", + hostname_port.first.c_str(), hostname_port.second); Transport *xport = nullptr; - xport = engine->installOrGetTransport("tcp", nullptr); + xport = engine->installTransport("tcp", nullptr); LOG_ASSERT(xport != nullptr); } @@ -89,16 +85,12 @@ TEST_F(TCPTransportTest, Writetest) { const size_t kDataLength = 4096000; void *addr = nullptr; const size_t ram_buffer_size = 1ull << 30; - auto metadata_client = std::make_shared("127.0.0.1:2379"); - LOG_ASSERT(metadata_client); - - auto engine = std::make_unique(metadata_client); - + auto engine = std::make_unique(); auto hostname_port = parseHostNameWithPort("127.0.0.2:12345"); - engine->init("127.0.0.2:12345", hostname_port.first.c_str(), - hostname_port.second); + engine->init("127.0.0.1:2379", "127.0.0.2:12345", + hostname_port.first.c_str(), hostname_port.second); Transport *xport = nullptr; - xport = engine->installOrGetTransport("tcp", nullptr); + xport = engine->installTransport("tcp", nullptr); LOG_ASSERT(xport != nullptr); addr = allocateMemoryPool(ram_buffer_size, 0, false); @@ -107,28 +99,28 @@ TEST_F(TCPTransportTest, Writetest) { for (size_t offset = 0; offset < kDataLength; ++offset) *((char *)(addr) + offset) = 'a' + lrand48() % 26; - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; auto segment_id = engine->openSegment("127.0.0.2:12345"); TransferRequest entry; - auto segment_desc = xport->meta()->getSegmentDescByID(segment_id); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; entry.opcode = TransferRequest::WRITE; entry.length = kDataLength; entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); ASSERT_EQ(ret, 0); LOG_ASSERT(status.s != TransferStatusEnum::FAILED); if (status.s == TransferStatusEnum::COMPLETED) completed = true; } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); ASSERT_EQ(ret, 0); } @@ -136,16 +128,12 @@ TEST_F(TCPTransportTest, WriteAndReadtest) { const size_t kDataLength = 4096000; void *addr = nullptr; const size_t ram_buffer_size = 1ull << 30; - auto metadata_client = std::make_shared("127.0.0.1:2379"); - LOG_ASSERT(metadata_client); - - auto engine = std::make_unique(metadata_client); - + auto engine = std::make_unique(); auto hostname_port = parseHostNameWithPort("127.0.0.2:12345"); - engine->init("127.0.0.2:12345", hostname_port.first.c_str(), - hostname_port.second); + engine->init("127.0.0.1:2379", "127.0.0.2:12345", + hostname_port.first.c_str(), hostname_port.second); Transport *xport = nullptr; - xport = engine->installOrGetTransport("tcp", nullptr); + xport = engine->installTransport("tcp", nullptr); LOG_ASSERT(xport != nullptr); addr = allocateMemoryPool(ram_buffer_size, 0, false); @@ -155,10 +143,10 @@ TEST_F(TCPTransportTest, WriteAndReadtest) { *((char *)(addr) + offset) = 'a' + lrand48() % 26; auto segment_id = engine->openSegment("127.0.0.2:12345"); - auto segment_desc = xport->meta()->getSegmentDescByID(segment_id); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; { - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; entry.opcode = TransferRequest::WRITE; @@ -166,22 +154,22 @@ TEST_F(TCPTransportTest, WriteAndReadtest) { entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); ASSERT_EQ(ret, 0); LOG_ASSERT(status.s != TransferStatusEnum::FAILED); if (status.s == TransferStatusEnum::COMPLETED) completed = true; } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); ASSERT_EQ(ret, 0); } { - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; @@ -190,17 +178,17 @@ TEST_F(TCPTransportTest, WriteAndReadtest) { entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); LOG_ASSERT(!ret); if (status.s == TransferStatusEnum::COMPLETED) completed = true; LOG_ASSERT(status.s != TransferStatusEnum::FAILED); } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); LOG_ASSERT(!ret); } LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, @@ -211,16 +199,12 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { const size_t kDataLength = 4096000; void *addr = nullptr; const size_t ram_buffer_size = 1ull << 30; - auto metadata_client = std::make_shared("127.0.0.1:2379"); - LOG_ASSERT(metadata_client); - - auto engine = std::make_unique(metadata_client); - + auto engine = std::make_unique(); auto hostname_port = parseHostNameWithPort("127.0.0.2:12345"); - engine->init("127.0.0.2:12345", hostname_port.first.c_str(), - hostname_port.second); + engine->init("127.0.0.1:2379", "127.0.0.2:12345", + hostname_port.first.c_str(), hostname_port.second); Transport *xport = nullptr; - xport = engine->installOrGetTransport("tcp", nullptr); + xport = engine->installTransport("tcp", nullptr); LOG_ASSERT(xport != nullptr); addr = allocateMemoryPool(ram_buffer_size, 0, false); @@ -230,11 +214,11 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { *((char *)(addr) + offset) = 'a' + lrand48() % 26; auto segment_id = engine->openSegment("127.0.0.2:12345"); - auto segment_desc = xport->meta()->getSegmentDescByID(segment_id); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; { - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; entry.opcode = TransferRequest::WRITE; @@ -242,22 +226,22 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); ASSERT_EQ(ret, 0); LOG_ASSERT(status.s != TransferStatusEnum::FAILED); if (status.s == TransferStatusEnum::COMPLETED) completed = true; } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); ASSERT_EQ(ret, 0); } { - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; entry.opcode = TransferRequest::READ; @@ -265,17 +249,17 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); LOG_ASSERT(!ret); if (status.s == TransferStatusEnum::COMPLETED) completed = true; LOG_ASSERT(status.s != TransferStatusEnum::FAILED); } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); LOG_ASSERT(!ret); } LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, @@ -284,7 +268,7 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { for (size_t offset = 0; offset < kDataLength; ++offset) *((char *)(addr) + offset) = 'a' + lrand48() % 26; { - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; entry.opcode = TransferRequest::WRITE; @@ -292,22 +276,22 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { entry.source = (uint8_t *)(addr); entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); ASSERT_EQ(ret, 0); LOG_ASSERT(status.s != TransferStatusEnum::FAILED); if (status.s == TransferStatusEnum::COMPLETED) completed = true; } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); ASSERT_EQ(ret, 0); } { - auto batch_id = xport->allocateBatchID(1); + auto batch_id = engine->allocateBatchID(1); int ret = 0; TransferRequest entry; entry.opcode = TransferRequest::READ; @@ -315,17 +299,17 @@ TEST_F(TCPTransportTest, WriteAndRead2test) { entry.source = (uint8_t *)(addr) + kDataLength; entry.target_id = segment_id; entry.target_offset = remote_base; - ret = xport->submitTransfer(batch_id, {entry}); + ret = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(!ret); bool completed = false; TransferStatus status; while (!completed) { - int ret = xport->getTransferStatus(batch_id, 0, status); + int ret = engine->getTransferStatus(batch_id, 0, status); LOG_ASSERT(!ret); if (status.s == TransferStatusEnum::COMPLETED) completed = true; LOG_ASSERT(status.s != TransferStatusEnum::FAILED); } - ret = xport->freeBatchID(batch_id); + ret = engine->freeBatchID(batch_id); LOG_ASSERT(!ret); } LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, diff --git a/mooncake-transfer-engine/tests/topology_test.cpp b/mooncake-transfer-engine/tests/topology_test.cpp index ec5c342..4e3aaf3 100644 --- a/mooncake-transfer-engine/tests/topology_test.cpp +++ b/mooncake-transfer-engine/tests/topology_test.cpp @@ -7,11 +7,13 @@ TEST(ToplogyTest, GetTopologyMatrix) { - std::string topo = mooncake::discoverTopologyMatrix(); - LOG(INFO) << topo; - mooncake::TransferMetadata::PriorityMatrix matrix; - std::vector rnic_list; - mooncake::TransferMetadata::parseNicPriorityMatrix(topo, matrix, rnic_list); + mooncake::Topology topology; + topology.discover(); + std::string json_str = topology.toString(); + LOG(INFO) << json_str; + topology.clear(); + topology.parse(json_str); + ASSERT_EQ(topology.toString(), json_str); } int main(int argc, char **argv)