From 7495480921c0cfde441f2b589bd3793508f1ae03 Mon Sep 17 00:00:00 2001 From: Seth Zegelstein Date: Fri, 10 Oct 2025 21:50:48 +0000 Subject: [PATCH 1/4] transport: Deprecate enforce_cst CUDA 11.3 released cuFlushGPUDirectRDMAWrites API which takes the place of the host transport enforce_cst api. NVSHMEM no longer supports CUDA 11, so these legacy API's can be removed. Signed-off-by: Seth Zegelstein --- src/host/proxy/proxy.cpp | 50 +++--------- .../internal/host_transport/transport.h | 1 - src/modules/transport/ibdevx/ibdevx.cpp | 41 ---------- src/modules/transport/ibgda/ibgda.cpp | 1 - src/modules/transport/ibrc/ibrc.cpp | 1 - src/modules/transport/libfabric/libfabric.cpp | 78 ------------------- src/modules/transport/libfabric/libfabric.h | 5 -- src/modules/transport/ucx/ucx.cpp | 62 --------------- 8 files changed, 11 insertions(+), 228 deletions(-) diff --git a/src/host/proxy/proxy.cpp b/src/host/proxy/proxy.cpp index 74270bf..e4925bf 100644 --- a/src/host/proxy/proxy.cpp +++ b/src/host/proxy/proxy.cpp @@ -686,51 +686,23 @@ int process_channel_amo(proxy_state_t *state, proxy_channel_t *ch, int *is_proce } void enforce_cst(proxy_state_t *proxy_state) { -#if defined(NVSHMEM_X86_64) - nvshmemi_state_t *state = proxy_state->nvshmemi_state; -#endif - int status = 0; if (nvshmemi_options.BYPASS_FLUSH) return; - if (proxy_state->is_consistency_api_supported) { - if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering && - CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) { - status = - CUPFN(nvshmemi_cuda_syms, - cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX, - CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER)); - /** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable - consistent access of data on any GPU (and not just self GPU) with - wait_until, quiet, barrier, etc. **/ - if (status != CUDA_SUCCESS) { - NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n"); - } - } - return; - } -#if defined(NVSHMEM_PPC64LE) - status = cudaEventRecord(proxy_state->cuev, proxy_state->stream); - if (unlikely(status != CUDA_SUCCESS)) { - NVSHMEMI_ERROR_EXIT("cuEventRecord() failed in the proxy thread \n"); - } -#elif defined(NVSHMEM_X86_64) - for (int i = 0; i < state->num_initialized_transports; i++) { - if (!((state->transport_bitmap) & (1 << i))) continue; - struct nvshmem_transport *tcurr = state->transports[i]; - if (!tcurr->host_ops.enforce_cst) continue; - - // assuming the transport is connected - IB RC - if (tcurr->attr & NVSHMEM_TRANSPORT_ATTR_CONNECTED) { - status = tcurr->host_ops.enforce_cst(tcurr); - if (status) { - NVSHMEMI_ERROR_PRINT("aborting due to error in progress_cst \n"); - exit(-1); - } + if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering && + CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) { + status = + CUPFN(nvshmemi_cuda_syms, + cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX, + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER)); + /** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable + consistent access of data on any GPU (and not just self GPU) with + wait_until, quiet, barrier, etc. **/ + if (status != CUDA_SUCCESS) { + NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n"); } } -#endif } inline void quiet_ack_channels(proxy_state_t *proxy_state) { diff --git a/src/include/internal/host_transport/transport.h b/src/include/internal/host_transport/transport.h index f3fc7c1..f36b995 100644 --- a/src/include/internal/host_transport/transport.h +++ b/src/include/internal/host_transport/transport.h @@ -148,7 +148,6 @@ struct nvshmem_transport_host_ops { fence_handle fence; quiet_handle quiet; put_signal_handle put_signal; - int (*enforce_cst)(struct nvshmem_transport *transport); int (*enforce_cst_at_target)(struct nvshmem_transport *transport); int (*add_device_remote_mem_handles)(struct nvshmem_transport *transport, int transport_stride, nvshmem_mem_handle_t *mem_handles, uint64_t heap_offset, diff --git a/src/modules/transport/ibdevx/ibdevx.cpp b/src/modules/transport/ibdevx/ibdevx.cpp index 4a76fde..02e217d 100644 --- a/src/modules/transport/ibdevx/ibdevx.cpp +++ b/src/modules/transport/ibdevx/ibdevx.cpp @@ -1440,46 +1440,6 @@ int nvshmemt_ibdevx_amo(struct nvshmem_transport *tcurr, int pe, void *curetptr, return status; } -int nvshmemt_ibdevx_enforce_cst_at_target(struct nvshmem_transport *tcurr) { - nvshmemt_ib_common_state_t ibdevx_state = (nvshmemt_ib_common_state_t)tcurr->state; - struct ibdevx_ep *ep = (struct ibdevx_ep *)ibdevx_state->cst_ep; - struct ibdevx_rw_wqe *wqe; - - int status = 0; - - uintptr_t wqe_bb_idx_64 = ep->wqe_bb_idx; - uint32_t wqe_bb_idx_32 = ep->wqe_bb_idx; - size_t wqe_size; - - wqe = (struct ibdevx_rw_wqe *)((char *)ep->wq_buf + - ((wqe_bb_idx_64 % get_ibdevx_qp_depth(ibdevx_state)) - << NVSHMEMT_IBDEVX_WQE_BB_SHIFT)); - wqe_size = sizeof(struct ibdevx_rw_wqe); - memset(wqe, 0, sizeof(struct ibdevx_rw_wqe)); - - wqe->ctrl.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; - wqe->ctrl.qpn_ds = - htobe32((uint32_t)(wqe_size / NVSHMEMT_IBDEVX_MLX5_SEND_WQE_DS) | ep->qpid << 8); - wqe->ctrl.opmod_idx_opcode = htobe32(MLX5_OPCODE_RDMA_READ | (wqe_bb_idx_32 << 8)); - - wqe->raddr.raddr = htobe64((uintptr_t)local_dummy_mr.mr->addr); - wqe->raddr.rkey = htobe32(local_dummy_mr.rkey); - - wqe->data.data_seg.byte_count = htobe32((uint32_t)4); - wqe->data.data_seg.lkey = htobe32(local_dummy_mr.lkey); - wqe->data.data_seg.addr = htobe64((uintptr_t)local_dummy_mr.mr->addr); - - assert(wqe_size <= MLX5_SEND_WQE_BB); - ep->wqe_bb_idx++; - nvshmemt_ibdevx_post_send(ep, (void *)wqe, 1); - - status = nvshmemt_ib_common_check_poll_avail(tcurr, ep, NVSHMEMT_IB_COMMON_WAIT_ALL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "check_poll failed \n"); - -out: - return status; -} - // Using common fence and quiet functions from transport_ib_common int nvshmemt_ibdevx_ep_create(struct ibdevx_ep **ep, int devid, nvshmem_transport_t t, @@ -1932,7 +1892,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_ibdevx_finalize; transport->host_ops.show_info = nvshmemt_ibdevx_show_info; transport->host_ops.progress = nvshmemt_ibdevx_progress; - transport->host_ops.enforce_cst = nvshmemt_ibdevx_enforce_cst_at_target; transport->host_ops.put_signal = nvshmemt_put_signal; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp index 4c69cce..15eb8d0 100644 --- a/src/modules/transport/ibgda/ibgda.cpp +++ b/src/modules/transport/ibgda/ibgda.cpp @@ -4915,7 +4915,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.amo = NULL; transport->host_ops.fence = NULL; transport->host_ops.quiet = NULL; - transport->host_ops.enforce_cst = NULL; transport->host_ops.add_device_remote_mem_handles = nvshmemt_ibgda_add_device_remote_mem_handles; transport->host_ops.put_signal = NULL; diff --git a/src/modules/transport/ibrc/ibrc.cpp b/src/modules/transport/ibrc/ibrc.cpp index bcea10f..9cb02a3 100644 --- a/src/modules/transport/ibrc/ibrc.cpp +++ b/src/modules/transport/ibrc/ibrc.cpp @@ -1802,7 +1802,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.progress = nvshmemt_ibrc_progress; transport->host_ops.put_signal = nvshmemt_put_signal; - transport->host_ops.enforce_cst = nvshmemt_ibrc_enforce_cst_at_target; #if !defined(NVSHMEM_PPC64LE) && !defined(NVSHMEM_AARCH64) if (!use_gdrcopy) #endif diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 927c079..001a1ad 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -1063,70 +1063,6 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v return status; } -static int nvshmemt_libfabric_enforce_cst(struct nvshmem_transport *tcurr) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; - uint64_t num_retries = 0; - int status; - int target_ep; - int mype = tcurr->my_pe; - -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - if (libfabric_state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { - int temp; - nvshmemt_libfabric_memhandle_info_t *mem_handle_info; - - mem_handle_info = - (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx( - libfabric_state->cache, 0); - if (!mem_handle_info) { - goto skip; - } - gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr, - sizeof(int)); - } - } - -skip: -#endif - - target_ep = mype * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - do { - struct fi_msg_rma msg; - struct iovec l_iov; - struct fi_rma_iov r_iov; - void *desc = libfabric_state->local_mr_desc[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - uint64_t flags = 0; - - memset(&msg, 0, sizeof(struct fi_msg_rma)); - memset(&l_iov, 0, sizeof(struct iovec)); - memset(&r_iov, 0, sizeof(struct fi_rma_iov)); - - l_iov.iov_base = libfabric_state->local_mem_ptr; - l_iov.iov_len = 8; - - r_iov.addr = 0; // Zero offset - r_iov.len = 8; - r_iov.key = libfabric_state->local_mr_key[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - - msg.msg_iov = &l_iov; - msg.desc = &desc; - msg.iov_count = 1; - msg.rma_iov = &r_iov; - msg.rma_iov_count = 1; - msg.context = NULL; - msg.data = 0; - - if (libfabric_state->prov_info->caps & FI_FENCE) flags |= FI_FENCE; - - status = - fi_readmsg(libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX].endpoint, &msg, flags); - } while (try_again(tcurr, &status, &num_retries)); - - libfabric_state->eps[target_ep].submitted_ops++; - return status; -} - static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handle, nvshmem_transport_t t) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)t->state; @@ -1168,9 +1104,6 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl max_reg = 1; for (int i = 0; i < max_reg; i++) { - if (libfabric_state->local_mr[i] == fabric_handle->hdls[i].mr) - libfabric_state->local_mr[i] = NULL; - int status = fi_close(&fabric_handle->hdls[i].mr->fid); if (status) { NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %d (%d): %s\n", i, status, @@ -1350,15 +1283,6 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v } while (curr_ptr < (char *)buf + length); } - if (libfabric_state->local_mr[0] == NULL && !local_only) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - libfabric_state->local_mr[i] = fabric_handle->hdls[i].mr; - libfabric_state->local_mr_key[i] = fabric_handle->hdls[i].key; - libfabric_state->local_mr_desc[i] = fabric_handle->hdls[i].local_desc; - } - libfabric_state->local_mem_ptr = buf; - } - out: if (status) { if (handle_info) { @@ -2092,8 +2016,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_libfabric_finalize; transport->host_ops.show_info = nvshmemt_libfabric_show_info; transport->host_ops.progress = nvshmemt_libfabric_progress; - transport->host_ops.enforce_cst = nvshmemt_libfabric_enforce_cst; - transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; transport->is_successfully_initialized = true; diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index f2b5931..17cb7c7 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -278,11 +278,6 @@ typedef struct { struct fid_domain *domain; struct fid_av *addresses[NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS]; nvshmemt_libfabric_endpoint_t *eps; - /* local_mr is used only for consistency ops. */ - struct fid_mr *local_mr[2]; - uint64_t local_mr_key[2]; - void *local_mr_desc[2]; - void *local_mem_ptr; nvshmemt_libfabric_domain_name_t *domain_names; int num_domains; nvshmemt_libfabric_provider provider; diff --git a/src/modules/transport/ucx/ucx.cpp b/src/modules/transport/ucx/ucx.cpp index 271ed69..4959d0b 100644 --- a/src/modules/transport/ucx/ucx.cpp +++ b/src/modules/transport/ucx/ucx.cpp @@ -1180,67 +1180,6 @@ int nvshmemt_ucx_finalize(nvshmem_transport_t transport) { return 0; } -int nvshmemt_ucx_enforce_cst_at_target(struct nvshmem_transport *tcurr) { - transport_ucx_state_t *ucx_state = (transport_ucx_state_t *)tcurr->state; - nvshmemt_ucx_mem_handle_info_t *mem_handle_info; - - mem_handle_info = - (nvshmemt_ucx_mem_handle_info_t *)nvshmemt_mem_handle_cache_get_by_idx(ucx_state->cache, 0); - - if (!mem_handle_info) return 0; -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - int temp; - gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr, - sizeof(int)); - return 0; - } -#endif - int mype = tcurr->my_pe; - int ep_index = (ucx_state->ep_count * mype + ucx_state->proxy_ep_idx); - ucp_ep_h ep = ucx_state->endpoints[ep_index]; - ucp_request_param_t param; - ucs_status_ptr_t ucs_ptr_rc = NULL; - ucs_status_t ucs_rc; - nvshmemt_ucx_mem_handle_t *mem_handle; - ucp_rkey_h rkey; - int local_int; - - mem_handle = mem_handle_info->mem_handle; - if (unlikely(mem_handle->ep_rkey_host == NULL)) { - ucs_rc = ucp_ep_rkey_unpack(ep, mem_handle->rkey_packed_buf, &mem_handle->ep_rkey_host); - if (ucs_rc != UCS_OK) { - NVSHMEMI_ERROR_EXIT("Unable to unpack rkey in UCS transport! Exiting.\n"); - } - } - rkey = mem_handle->ep_rkey_host; - - param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK; - param.cb.send = nvshmemt_ucx_send_request_cb; - - ucs_ptr_rc = - ucp_get_nbx(ep, &local_int, sizeof(int), (uint64_t)mem_handle_info->ptr, rkey, ¶m); - - /* Wait for completion of get. */ - if (ucs_ptr_rc != NULL) { - if (UCS_PTR_IS_ERR(ucs_ptr_rc)) { - NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } else { - do { - ucs_rc = ucp_request_check_status(ucs_ptr_rc); - ucp_worker_progress(ucx_state->worker_context); - } while (ucs_rc == UCS_INPROGRESS); - if (ucs_rc != UCS_OK) { - NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } - } - } - - return 0; -} - int nvshmemt_ucx_show_info(struct nvshmem_transport *transport, int style) { NVSHMEMI_ERROR_PRINT("UCX show info not implemented"); return 0; @@ -1446,7 +1385,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_ucx_finalize; transport->host_ops.show_info = nvshmemt_ucx_show_info; transport->host_ops.progress = nvshmemt_ucx_progress; - transport->host_ops.enforce_cst = nvshmemt_ucx_enforce_cst_at_target; transport->host_ops.enforce_cst_at_target = NULL; transport->host_ops.put_signal = nvshmemt_put_signal; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; From 9f4c4586b6a81098cf9fb50136ed54b8c6bf101e Mon Sep 17 00:00:00 2001 From: Seth Zegelstein Date: Mon, 20 Oct 2025 18:40:31 +0000 Subject: [PATCH 2/4] transport/libfabric: Rename is_proxy to qp_index The previous is_proxy variable equals qp_index. Change the name everywhere for consistency. Signed-off-by: Seth Zegelstein --- src/modules/transport/libfabric/libfabric.cpp | 58 +++++-------------- 1 file changed, 15 insertions(+), 43 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 001a1ad..1e57201 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -546,10 +546,9 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; nvshmemt_libfabric_endpoint_t *ep; - int is_proxy = qp_index != NVSHMEMX_QP_HOST; int status = 0; - if (is_proxy) { + if (qp_index) { ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; } else { ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; @@ -604,7 +603,7 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int is_proxy, + rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data) { nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle; nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; @@ -623,12 +622,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, memset(&p_op_msg, 0, sizeof(struct fi_msg_rma)); memset(&p_op_r_iov, 0, sizeof(struct fi_rma_iov)); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; ep = &libfabric_state->eps[ep_idx]; target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; @@ -731,13 +725,13 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int is_proxy) { - return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, is_proxy, NULL); + rma_bytesdesc_t bytesdesc, int qp_index) { + return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL); } static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, - amo_bytesdesc_t bytesdesc, int is_proxy) { + amo_bytesdesc_t bytesdesc, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_endpoint_t *ep; nvshmemt_libfabric_gdr_op_ctx_t *amo; @@ -745,12 +739,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p int target_ep, ep_idx; int status = 0; - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; ep = &libfabric_state->eps[ep_idx]; target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; @@ -793,7 +782,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, amo_bytesdesc_t bytesdesc, - int is_proxy) { + int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_mem_handle_ep_t *remote_handle = NULL, *local_handle = NULL; nvshmemt_libfabric_endpoint_t *ep; @@ -815,12 +804,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v memset(&fi_ret_iov, 0, sizeof(struct fi_ioc)); memset(&fi_remote_iov, 0, sizeof(struct fi_rma_ioc)); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; ep = &libfabric_state->eps[ep_idx]; target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; @@ -945,7 +929,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, - amo_bytesdesc_t bytesdesc, int is_proxy, + amo_bytesdesc_t bytesdesc, int qp_index, uint32_t sequence_count, uint16_t num_writes) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_endpoint_t *ep; @@ -955,12 +939,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in int target_ep, ep_idx; int status = 0; - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; ep = &libfabric_state->eps[ep_idx]; target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; @@ -1006,18 +985,11 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v std::vector &write_local, std::vector &write_bytes_desc, amo_verb_t sig_verb, amo_memdesc_t *sig_target, - amo_bytesdesc_t sig_bytes_desc, int is_proxy) { + amo_bytesdesc_t sig_bytes_desc, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; int status; uint32_t sequence_count = 0; - int ep_idx; - - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + int ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; nvshmemt_libfabric_endpoint_t &ep = libfabric_state->eps[ep_idx]; /* Get sequence number for this put-signal, with retry */ @@ -1042,7 +1014,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v for (size_t i = 0; i < write_remote.size(); i++) { status = nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i], - write_bytes_desc[i], is_proxy, &sequence_count); + write_bytes_desc[i], qp_index, &sequence_count); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT( "Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i); @@ -1052,7 +1024,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v assert(use_staged_atomics == true); status = nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, - is_proxy, sequence_count, (uint16_t)write_remote.size()); + qp_index, sequence_count, (uint16_t)write_remote.size()); out: if (status) { NVSHMEMI_ERROR_PRINT( From af23d0bb1d1169b52e98f8bc6ee84d7605d4e456 Mon Sep 17 00:00:00 2001 From: Seth Zegelstein Date: Thu, 9 Oct 2025 05:29:48 +0000 Subject: [PATCH 3/4] transport/libfabric: Optimize Progress Attempt to request FI_PROGRESS_AUTO to see if the libfabric provider supports it, if it doesn't fall back to FI_PROGRESS_MANUAL. FI_PROGRESS_AUTO means that we do not need to call into the progress engine for submitted operations to complete. This means that we can remove the host endpoint from the progress call, and we only need to progress the host endpoint when user calls nvshmem_quiet() from the host. This allows us to set the threading model as FI_THREAD_COMPELTION because the host only progress the host EP, and the proxy only progresses the proxy EP, leading to compliance with FI_THREAD_COMPLETION. An edge case exists here where the user calls nvshmem_quiet() on the host QP_IDX from a GPU kernel, but this is illegial because the user shouldn't be calling QP API's on QP's not provided to them via the qp creation API's. This patch should offer a performance improvement because it reduces the number of EP's that are progressed in the critical path, and it allows the libfabric provider to reduce locking b/c of threading model FI_THREAD_COMPLETION. Signed-off-by: Seth Zegelstein --- src/modules/transport/libfabric/libfabric.cpp | 281 +++++++++++------- src/modules/transport/libfabric/libfabric.h | 1 + 2 files changed, 179 insertions(+), 103 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 1e57201..4b2952e 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -66,6 +66,8 @@ static bool use_gdrcopy = false; sizeof(nvshmemt_libfabric_gdr_op_ctx_t) - sizeof(struct fi_context2) - sizeof(fi_addr_t) static bool use_staged_atomics = false; +static bool use_auto_progress = false; + threadSafeOpQueue nvshmemtLibfabricOpQueue; std::mutex gdrRecvMutex; @@ -139,82 +141,114 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo return status; } -static int nvshmemt_libfabric_progress(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; +static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, + nvshmemt_libfabric_endpoint_t *ep) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + char buf[MAX_COMPLETIONS_PER_CQ_POLL * sizeof(struct fi_cq_data_entry)]; + fi_addr_t src_addr[MAX_COMPLETIONS_PER_CQ_POLL]; + fi_addr_t *addr; + ssize_t qstatus; + struct fi_cq_data_entry *entry; + uint64_t cnt; int status; - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - uint64_t cnt = fi_cntr_readerr(libfabric_state->eps[i].counter); - - if (cnt > 0) { - NVSHMEMI_WARN_PRINT("Nonzero error count progressing EP %d (%" PRIu64 ")\n", i, cnt); - - struct fi_cq_err_entry err; - memset(&err, 0, sizeof(struct fi_cq_err_entry)); - ssize_t nerr = fi_cq_readerr(libfabric_state->eps[i].cq, &err, 0); + cnt = fi_cntr_readerr(ep->counter); + if (cnt > 0) { + NVSHMEMI_WARN_PRINT("Nonzero error count progressing EP (%" PRIu64 ")\n", cnt); + struct fi_cq_err_entry err; + memset(&err, 0, sizeof(struct fi_cq_err_entry)); + ssize_t nerr = fi_cq_readerr(ep->cq, &err, 0); + + if (nerr > 0) { + char str[100] = "\0"; + const char *err_str = fi_cq_strerror(ep->cq, err.prov_errno, err.err_data, str, 100); + NVSHMEMI_WARN_PRINT( + "CQ reported error (%d): %s\n\tProvider error: %s\n\tSupplemental error " + "info: %s\n", + err.err, fi_strerror(err.err), err_str ? err_str : "none", + strlen(str) ? str : "none"); + } else if (nerr == -FI_EAGAIN) { + NVSHMEMI_WARN_PRINT("fi_cq_readerr returned -FI_EAGAIN\n"); + } else { + NVSHMEMI_WARN_PRINT("fi_cq_readerr returned %zd: %s\n", nerr, fi_strerror(-1 * nerr)); + } + return NVSHMEMX_ERROR_INTERNAL; + } - if (nerr > 0) { - char str[100] = "\0"; - const char *err_str = fi_cq_strerror(libfabric_state->eps[i].cq, err.prov_errno, - err.err_data, str, 100); - NVSHMEMI_WARN_PRINT( - "CQ reported error (%d): %s\n\tProvider error: %s\n\tSupplemental error " - "info: %s\n", - err.err, fi_strerror(err.err), err_str ? err_str : "none", - strlen(str) ? str : "none"); - } else if (nerr == -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned -FI_EAGAIN\n"); + do { + qstatus = fi_cq_readfrom(ep->cq, buf, MAX_COMPLETIONS_PER_CQ_POLL, src_addr); + /* Note - EFA provider does not support selective completions */ + if (qstatus > 0) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + entry = (struct fi_cq_data_entry *)buf; + addr = src_addr; + for (int i = 0; i < qstatus; i++, entry++, addr++) { + status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, addr); + if (status) return NVSHMEMX_ERROR_INTERNAL; + } } else { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned %zd: %s\n", nerr, - fi_strerror(-1 * nerr)); + NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP\n", qstatus); } - return NVSHMEMX_ERROR_INTERNAL; } + } while (qstatus > 0); + if (qstatus < 0 && qstatus != -FI_EAGAIN) { + NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, fi_strerror(qstatus * -1)); + return NVSHMEMX_ERROR_INTERNAL; + } - { - char buf[MAX_COMPLETIONS_PER_CQ_POLL * sizeof(struct fi_cq_data_entry)]; - fi_addr_t src_addr[MAX_COMPLETIONS_PER_CQ_POLL]; - ssize_t qstatus; - nvshmemt_libfabric_endpoint_t *ep = &libfabric_state->eps[i]; - do { - qstatus = fi_cq_readfrom(ep->cq, buf, MAX_COMPLETIONS_PER_CQ_POLL, src_addr); - /* Note - EFA provider does not support selective completions */ - if (qstatus > 0) { - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - struct fi_cq_data_entry *entry = (struct fi_cq_data_entry *)buf; - fi_addr_t *addr = src_addr; - for (int i = 0; i < qstatus; i++, entry++, addr++) { - status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, - addr); - if (status) return NVSHMEMX_ERROR_INTERNAL; - } - } else { - NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP %d\n", qstatus, i); - } - } - } while (qstatus > 0); - if (qstatus < 0 && qstatus != -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, - fi_strerror(qstatus * -1)); - return NVSHMEMX_ERROR_INTERNAL; - } + return 0; +} + +static int nvshmemt_libfabric_auto_progress(nvshmem_transport_t transport, int qp_index) { + int status; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + nvshmemt_libfabric_endpoint_t *ep; + + if (qp_index) + ep = &state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; + else + ep = &state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; + + status = nvshmemt_libfabric_single_ep_progress(transport, ep); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", status); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + if (gdrRecvMutex.try_lock()) { + status = nvshmemt_libfabric_gdr_process_amos(transport); + gdrRecvMutex.unlock(); } } - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { +out: + return status; +} + +static int nvshmemt_libfabric_auto_proxy_progress(nvshmem_transport_t transport) { + return nvshmemt_libfabric_auto_progress(transport, NVSHMEMT_LIBFABRIC_PROXY_EP_IDX); +} + +static int nvshmemt_libfabric_manual_progress(nvshmem_transport_t transport) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + int status; + for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, &state->eps[i]); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", + status); + } + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { if (gdrRecvMutex.try_lock()) { status = nvshmemt_libfabric_gdr_process_amos(transport); gdrRecvMutex.unlock(); - if (status) { - return NVSHMEMX_ERROR_INTERNAL; - } } } - return 0; +out: + return status; } -static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t *num_retries) { +static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t *num_retries, + int qp_index) { if (likely(*status == 0)) { return 0; } @@ -227,7 +261,10 @@ static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t return 0; } (*num_retries)++; - *status = nvshmemt_libfabric_progress(transport); + if (use_auto_progress) + *status = nvshmemt_libfabric_auto_progress(transport, qp_index); + else + *status = nvshmemt_libfabric_manual_progress(transport); } if (*status != 0) { @@ -251,7 +288,7 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t do { resp_op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); status = resp_op == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, ep->qp_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable allocate buffer for atomic ack.\n"); @@ -264,7 +301,7 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t status = fi_writedata(ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr), imm_data, dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], libfabric_state->rkey_staged_amo_ack[pe], &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, ep->qp_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to write atomic ack.\n"); ep->submitted_ops++; @@ -356,7 +393,7 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op do { resp_op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); status = resp_op == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, op->ep->qp_index)); num_retries = 0; NVSHMEMI_NULL_ERROR_JMP(resp_op, status, NVSHMEMX_ERROR_INTERNAL, out, @@ -369,7 +406,7 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op do { status = fi_send(op->ep->endpoint, (void *)resp_op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, fi_mr_desc(libfabric_state->mr), op->src_addr, &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, op->ep->qp_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to respond to atomic request.\n"); op->ep->submitted_ops++; @@ -545,32 +582,40 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + uint64_t completed; nvshmemt_libfabric_endpoint_t *ep; int status = 0; - if (qp_index) { + if (qp_index) ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - } else { + else ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; - } - if (likely(libfabric_state->prov_info->domain_attr->control_progress == FI_PROGRESS_MANUAL) || - (libfabric_state->prov_info->domain_attr->data_progress == FI_PROGRESS_MANUAL) || - (use_staged_atomics == true) -#ifdef NVSHMEM_USE_GDRCOPY - || (use_gdrcopy == true) -#endif - ) { - uint64_t submitted, completed; - for (;;) { - completed = fi_cntr_read(ep->counter); - submitted = ep->submitted_ops; - if (completed + ep->completed_staged_atomics == submitted) - break; - else { - if (nvshmemt_libfabric_progress(tcurr)) { - status = NVSHMEMX_ERROR_INTERNAL; + if (use_staged_atomics) { + if (use_auto_progress) { + for (;;) { + completed = fi_cntr_read(ep->counter); + if (completed + ep->completed_staged_atomics == ep->submitted_ops) { break; + } else { + status = nvshmemt_libfabric_auto_progress(tcurr, qp_index); + if (status) { + status = NVSHMEMX_ERROR_INTERNAL; + break; + } + } + } + } else { + for (;;) { + completed = fi_cntr_read(ep->counter); + if (completed + ep->completed_staged_atomics == ep->submitted_ops) + break; + else { + status = nvshmemt_libfabric_manual_progress(tcurr); + if (status) { + status = NVSHMEMX_ERROR_INTERNAL; + break; + } } } } @@ -631,7 +676,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, do { gdr_ctx = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); status = gdr_ctx == NULL ? -EAGAIN : 0; - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); NVSHMEMI_NULL_ERROR_JMP(gdr_ctx, status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to get context buffer for put request.\n"); context = &gdr_ctx->ofi_context; @@ -652,7 +697,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, status = fi_write(ep->endpoint, &p_buf->p_op.value, op_size, fi_mr_desc(libfabric_state->mr), target_ep, (uintptr_t)remote->ptr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); } else { p_op_msg.msg_iov = &p_op_l_iov; p_op_msg.desc = NULL; // Local buffer is on the stack @@ -676,7 +721,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, */ do { status = fi_writemsg(ep->endpoint, &p_op_msg, FI_INJECT); - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); } } else if (verb.desc == NVSHMEMI_OP_PUT) { uintptr_t remote_addr; @@ -693,7 +738,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, else status = fi_write(ep->endpoint, local->ptr, op_size, local_handle->local_desc, target_ep, remote_addr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); } else if (verb.desc == NVSHMEMI_OP_G || verb.desc == NVSHMEMI_OP_GET) { assert( !imm_data); // Write w/ imm not suppored with NVSHMEMI_OP_G/GET on Libfabric transport @@ -706,7 +751,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, do { status = fi_read(ep->endpoint, local->ptr, op_size, local_handle->local_desc, target_ep, remote_addr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, "Invalid RMA operation specified.\n"); @@ -746,7 +791,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p do { amo = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); status = amo == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { NVSHMEMI_ERROR_PRINT("Unable to retrieve AMO operation."); @@ -767,7 +812,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p do { status = fi_send(ep->endpoint, (void *)amo, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, fi_mr_desc(libfabric_state->mr), target_ep, &amo->ofi_context); - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { NVSHMEMI_ERROR_PRINT("Received an error when trying to post an AMO operation.\n"); @@ -913,7 +958,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v status = fi_fetch_atomicmsg(ep->endpoint, &amo_msg, &fi_ret_iov, &local_handle->local_desc, 1, FI_INJECT); } - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) goto out; // Status set by try_again @@ -948,7 +993,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in do { context = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); status = context == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { NVSHMEMI_ERROR_PRINT("Unable to retrieve signal operation buffer."); @@ -967,7 +1012,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in do { status = fi_send(ep->endpoint, (void *)signal, sizeof(nvshmemt_libfabric_gdr_signal_op_t), fi_mr_desc(libfabric_state->mr), target_ep, &context->ofi_context); - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { NVSHMEMI_ERROR_PRINT("Received an error when trying to post a signal operation.\n"); @@ -1002,7 +1047,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v sequence_count = seq_num; status = 0; } - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT("Error in nvshmemt_put_signal_unordered while waiting for category\n"); @@ -1476,6 +1521,7 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele "Unable to allocate array of endpoint names."); for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { + state->eps[i].qp_index = i; status = fi_endpoint(state->domain, state->prov_info, &state->eps[i].endpoint, NULL); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to allocate endpoint: %d: %s\n", status, @@ -1803,7 +1849,8 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { return 0; } -static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state) { +static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state, + struct nvshmemi_options_s *options) { struct fi_info info; struct fi_tx_attr tx_attr; struct fi_rx_attr rx_attr; @@ -1854,20 +1901,44 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr info.mode |= FI_CONTEXT2; } - /* Be thread safe at the level of the endpoint completion context. */ - domain_attr.threading = FI_THREAD_SAFE; - + ep_attr.type = FI_EP_RDM; /* Reliable datagrams */ /* Require completion RMA completion at target for correctness of quiet */ info.tx_attr->op_flags = FI_DELIVERY_COMPLETE; - ep_attr.type = FI_EP_RDM; // Reliable datagrams + /* nvshmemt_libfabric_auto_progress relaxes threading requirement */ + domain_attr.threading = FI_THREAD_COMPLETION; + info.domain_attr->data_progress = FI_PROGRESS_AUTO; status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), NULL, NULL, 0, &info, &returned_fabrics); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "No providers matched fi_getinfo query: %d: %s\n", status, - fi_strerror(status * -1)); + /* + * 1. Ensure that at least one fabric was returned + * 2. Make sure returned fabric matches the name of selected provider + * + * This has an assumption that the provided fabric option + * options.LIBFABRIC_PROVIDER will be a substr of the returned fabric + * name + */ + if (!status && strstr(returned_fabrics->fabric_attr->name, options->LIBFABRIC_PROVIDER)) { + use_auto_progress = true; + } else { + fi_freeinfo(returned_fabrics); + + /* + * Fallback to FI_PROGRESS_MANUAL path + * nvshmemt_libfabric_slow_progress requires FI_THREAD_SAFE + */ + domain_attr.threading = FI_THREAD_SAFE; + info.domain_attr->data_progress = FI_PROGRESS_MANUAL; + status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), + NULL, NULL, 0, &info, &returned_fabrics); + + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "No providers matched fi_getinfo query: %d: %s\n", status, + fi_strerror(status * -1)); + } + state->all_prov_info = returned_fabrics; for (current_fabric = returned_fabrics; current_fabric != NULL; current_fabric = current_fabric->next) { @@ -1987,7 +2058,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.quiet = nvshmemt_libfabric_quiet; transport->host_ops.finalize = nvshmemt_libfabric_finalize; transport->host_ops.show_info = nvshmemt_libfabric_show_info; - transport->host_ops.progress = nvshmemt_libfabric_progress; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; transport->is_successfully_initialized = true; @@ -2103,12 +2173,17 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, #undef NVSHMEMI_SET_ENV_VAR /* Prepare fabric state information. */ - status = nvshmemi_libfabric_init_state(transport, libfabric_state); + status = nvshmemi_libfabric_init_state(transport, libfabric_state, &options); if (status) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out_clean, "Failed to initialize the libfabric state.\n"); } + if (use_auto_progress) + transport->host_ops.progress = nvshmemt_libfabric_auto_proxy_progress; + else + transport->host_ops.progress = nvshmemt_libfabric_manual_progress; + *t = transport; out: if (status) { diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index 17cb7c7..8fb47de 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -184,6 +184,7 @@ typedef struct { nvshmemt_libfabric_endpoint_seq_counter_t put_signal_seq_counter; std::unordered_map> *proxy_put_signal_comp_map; + int qp_index; } nvshmemt_libfabric_endpoint_t; typedef enum { From 0e456ef22152102e39421ae25c2116d52b03cb09 Mon Sep 17 00:00:00 2001 From: Seth Zegelstein Date: Thu, 9 Oct 2025 17:33:11 +0000 Subject: [PATCH 4/4] transport/libfabric: Implement multi-rail This change implements multi-rail support for the libfabric host proxy transport. The transport changes from having 1 domain with 2 EP's to having 1 host domain on NIC 1 and one proxy domain per NIC. Splitting the host EP and proxy EP into seperate domains was done for simplicity of the code. Every domain resource (including AV) was bound on a 1-1 basis per EP so this change should be a functional no-op. In the future when one implements the QP API on the libfabric host proxy transport, N EP's per domain can be easily extended on this. This code uses a round robin based load balancer to assign messages to NIC's. One NIC will be used for the entire operation call into the libfabric transport (including put-signal), but not including messages that are segmented due to size or MR boundaries. The number of NIC's (domains) per PE are limited by the size of the struct nvshmemt_libfabric_mem_handle_t. A new env variable NVSHMEM_LIBFABRIC_MAX_NIC_PER_PE controls the max number of NIC's per PE. Thank you Justin for contributing an initial implementation of multi-rail which I built on top of. Co-authored-by: Justin Chui Signed-off-by: Seth Zegelstein --- src/modules/transport/common/env_defs.h | 3 + src/modules/transport/libfabric/libfabric.cpp | 945 +++++++++--------- src/modules/transport/libfabric/libfabric.h | 45 +- 3 files changed, 516 insertions(+), 477 deletions(-) diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h index aafc312..00d593c 100644 --- a/src/modules/transport/common/env_defs.h +++ b/src/modules/transport/common/env_defs.h @@ -98,6 +98,9 @@ NVSHMEMI_ENV_DEF(DISABLE_LOCAL_ONLY_PROXY, bool, false, NVSHMEMI_ENV_CAT_TRANSPO NVSHMEMI_ENV_DEF(LIBFABRIC_PROVIDER, string, "cxi", NVSHMEMI_ENV_CAT_TRANSPORT, "Set the feature set provider for the libfabric transport: cxi, efa, verbs") +NVSHMEMI_ENV_DEF(LIBFABRIC_MAX_NIC_PER_PE, int, 16, NVSHMEMI_ENV_CAT_TRANSPORT, + "Set the maximum number of NIC's per PE to use for libfabric provider") + #if defined(NVSHMEM_IBGDA_SUPPORT) || defined(NVSHMEM_ENV_ALL) /** GPU-initiated communication **/ NVSHMEMI_ENV_DEF(IBGDA_ENABLE_MULTI_PORT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 4b2952e..eb5cd43 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -65,13 +65,13 @@ static bool use_gdrcopy = false; #define NVSHMEM_STAGED_AMO_WIREDATA_SIZE \ sizeof(nvshmemt_libfabric_gdr_op_ctx_t) - sizeof(struct fi_context2) - sizeof(fi_addr_t) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + static bool use_staged_atomics = false; static bool use_auto_progress = false; - -threadSafeOpQueue nvshmemtLibfabricOpQueue; std::mutex gdrRecvMutex; -int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport); +int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index); int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, fi_addr_t *addr); @@ -81,6 +81,27 @@ static nvshmemt_libfabric_imm_cq_data_hdr_t nvshmemt_get_write_with_imm_hdr(uint NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT); } +static inline nvshmemt_libfabric_endpoint_t *nvshmemt_libfabric_get_next_ep( + nvshmemt_libfabric_state_t *state, int qp_index) { + int selected_ep; + + if (qp_index == NVSHMEMX_QP_HOST) { + selected_ep = 0; + } else { + /* + * Return the current EP, and increment the next EP in round robin fashion + * between 1 and state->num_selected_domains - 1. state->cur_proxy_ep_index + * is initialized to 1. This round-robin goes through the proxy EP's and + * ignores the host EP. + */ + selected_ep = state->cur_proxy_ep_index; + state->cur_proxy_ep_index = (state->cur_proxy_ep_index + 1) % state->num_selected_domains; + if (!state->cur_proxy_ep_index) state->cur_proxy_ep_index = 1; + } + + return state->eps[selected_ep]; +} + static void nvshmemt_libfabric_put_signal_ack_completion(nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry) { uint32_t seq_num = entry->data & NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK; @@ -98,6 +119,7 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo fi_addr_t *addr) { int status = 0; nvshmemt_libfabric_gdr_op_ctx_t *op; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; /* Write w/imm doesn't have op->op_context, must be checked first */ if (entry->flags & FI_REMOTE_CQ_DATA) { @@ -122,16 +144,16 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo op->src_addr = *addr; if (entry->flags & FI_SEND) { - nvshmemtLibfabricOpQueue.putToSend(op); + state->op_queue[ep->domain_index]->putToSend(op); } else if (entry->flags & FI_RMA) { /* inlined p ops or atomic responses */ - nvshmemtLibfabricOpQueue.putToSend(op); + state->op_queue[ep->domain_index]->putToSend(op); } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { /* Must happen after entry->flags & FI_SEND to avoid send completions */ status = nvshmemt_libfabric_put_signal_completion(transport, ep, entry, addr); } else if (entry->flags & FI_RECV) { op->ep = ep; - nvshmemtLibfabricOpQueue.putToRecv(op); + state->op_queue[ep->domain_index]->putToRecv(op); } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, "Found an invalid message type in an ep completion.\n"); @@ -202,19 +224,24 @@ static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, static int nvshmemt_libfabric_auto_progress(nvshmem_transport_t transport, int qp_index) { int status; nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_endpoint_t *ep; + int end_iter; - if (qp_index) - ep = &state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - else - ep = &state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else { + end_iter = state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } - status = nvshmemt_libfabric_single_ep_progress(transport, ep); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", status); + for (int i = qp_index; i < end_iter; i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, state->eps[i]); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", + status); + } if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { if (gdrRecvMutex.try_lock()) { - status = nvshmemt_libfabric_gdr_process_amos(transport); + status = nvshmemt_libfabric_gdr_process_amos(transport, qp_index); gdrRecvMutex.unlock(); } } @@ -230,15 +257,15 @@ static int nvshmemt_libfabric_auto_proxy_progress(nvshmem_transport_t transport) static int nvshmemt_libfabric_manual_progress(nvshmem_transport_t transport) { nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; int status; - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = nvshmemt_libfabric_single_ep_progress(transport, &state->eps[i]); + for (size_t i = 0; i < state->eps.size(); i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, state->eps[i]); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", status); } if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { if (gdrRecvMutex.try_lock()) { - status = nvshmemt_libfabric_gdr_process_amos(transport); + status = nvshmemt_libfabric_gdr_process_amos(transport, NVSHMEMX_QP_ALL); gdrRecvMutex.unlock(); } } @@ -284,11 +311,13 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t uint64_t num_retries = 0; int status; uint64_t imm_data = 0; + uint64_t rkey_index = pe * libfabric_state->num_selected_domains + ep->domain_index; do { - resp_op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + resp_op = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[ep->domain_index] + ->getNextSend(); status = resp_op == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries, ep->qp_index)); + } while (try_again(transport, &status, &num_retries, ep->domain_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable allocate buffer for atomic ack.\n"); @@ -298,10 +327,11 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | sequence_count; do { - status = fi_writedata(ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr), imm_data, - dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], - libfabric_state->rkey_staged_amo_ack[pe], &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries, ep->qp_index)); + status = fi_writedata( + ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr[ep->domain_index]), imm_data, + dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], + libfabric_state->rkey_staged_amo_ack[rkey_index], &resp_op->ofi_context); + } while (try_again(transport, &status, &num_retries, ep->domain_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to write atomic ack.\n"); ep->submitted_ops++; @@ -391,9 +421,11 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op if (received_op->op > NVSHMEMI_AMO_END_OF_NONFETCH) { do { - resp_op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + resp_op = + (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[op->ep->domain_index] + ->getNextSend(); status = resp_op == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries, op->ep->qp_index)); + } while (try_again(transport, &status, &num_retries, op->ep->domain_index)); num_retries = 0; NVSHMEMI_NULL_ERROR_JMP(resp_op, status, NVSHMEMX_ERROR_INTERNAL, out, @@ -405,8 +437,9 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op do { status = fi_send(op->ep->endpoint, (void *)resp_op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), op->src_addr, &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries, op->ep->qp_index)); + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), op->src_addr, + &resp_op->ofi_context); + } while (try_again(transport, &status, &num_retries, op->ep->domain_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to respond to atomic request.\n"); op->ep->submitted_ops++; @@ -464,27 +497,44 @@ int nvshmemt_libfabric_gdr_process_ack(nvshmem_transport_t transport, return 0; } -int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport) { +int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *op; int status = 0; + int end_iter; - op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextRecv(); - while (op) { - if (op->type == NVSHMEMT_LIBFABRIC_SEND) { - status = nvshmemt_libfabric_gdr_process_amo(transport, op); - } else { - status = nvshmemt_libfabric_gdr_process_ack(transport, op); - } + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else if (qp_index == NVSHMEMX_QP_ALL) { + qp_index = 0; + end_iter = libfabric_state->eps.size(); + } else { + end_iter = libfabric_state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to process atomic.\n"); - status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), FI_ADDR_UNSPEC, &op->ofi_context); - if (status) { - NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to re-post recv.\n"); + for (int i = qp_index; i < end_iter; i++) { + op = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[i]->getNextRecv(); + while (op) { + if (op->type == NVSHMEMT_LIBFABRIC_SEND) { + status = nvshmemt_libfabric_gdr_process_amo(transport, op); + } else { + status = nvshmemt_libfabric_gdr_process_ack(transport, op); + } + + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), FI_ADDR_UNSPEC, + &op->ofi_context); + if (status) { + NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to re-post recv.\n"); + } + op = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[i]->getNextRecv(); } - op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextRecv(); } + out: return status; } @@ -573,7 +623,8 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, "Error in put_signal_completion gdrcopy signaling operation.\n"); ep->proxy_put_signal_comp_map->erase(iter); status = fi_recv(ep->endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), FI_ADDR_UNSPEC, &op->ofi_context); + fi_mr_desc(libfabric_state->mr[ep->domain_index]), FI_ADDR_UNSPEC, + &op->ofi_context); } out: @@ -581,53 +632,66 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, } static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; uint64_t completed; - nvshmemt_libfabric_endpoint_t *ep; + bool all_nics_quieted; int status = 0; + int end_iter; - if (qp_index) - ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - else - ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else { + end_iter = state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } if (use_staged_atomics) { if (use_auto_progress) { for (;;) { - completed = fi_cntr_read(ep->counter); - if (completed + ep->completed_staged_atomics == ep->submitted_ops) { - break; - } else { - status = nvshmemt_libfabric_auto_progress(tcurr, qp_index); - if (status) { - status = NVSHMEMX_ERROR_INTERNAL; - break; + all_nics_quieted = true; + for (int i = qp_index; i < end_iter; i++) { + completed = fi_cntr_read(state->eps[i]->counter) + + state->eps[i]->completed_staged_atomics; + if (state->eps[i]->submitted_ops != completed) { + all_nics_quieted = false; + if (nvshmemt_libfabric_auto_progress(tcurr, qp_index)) { + status = NVSHMEMX_ERROR_INTERNAL; + break; + } } } + + if (status || all_nics_quieted) break; } } else { for (;;) { - completed = fi_cntr_read(ep->counter); - if (completed + ep->completed_staged_atomics == ep->submitted_ops) + all_nics_quieted = true; + for (int i = qp_index; i < end_iter; i++) { + completed = fi_cntr_read(state->eps[i]->counter) + + state->eps[i]->completed_staged_atomics; + if (state->eps[i]->submitted_ops != completed) all_nics_quieted = false; + } + if (all_nics_quieted) break; + + /* FI_PROGRESS_MANUAL requires progress on every endpoint */ + if (nvshmemt_libfabric_manual_progress(tcurr)) { + status = NVSHMEMX_ERROR_INTERNAL; break; - else { - status = nvshmemt_libfabric_manual_progress(tcurr); - if (status) { - status = NVSHMEMX_ERROR_INTERNAL; - break; - } } } } } else { - status = fi_cntr_wait(ep->counter, ep->submitted_ops, NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS); - if (status) { - /* note - Status is negative for this function in error cases but - * fi_strerror only accepts positive values. - */ - NVSHMEMI_ERROR_PRINT("Error in quiet operation (%d): %s.\n", status, - fi_strerror(status * -1)); - status = NVSHMEMX_ERROR_INTERNAL; + for (int i = qp_index; i < end_iter; i++) { + status = fi_cntr_wait(state->eps[i]->counter, state->eps[i]->submitted_ops, + NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS); + if (status) { + /* note - Status is negative for this function in error cases but + * fi_strerror only accepts positive values. + */ + NVSHMEMI_ERROR_PRINT("Error in quiet operation (%d): %s.\n", status, + fi_strerror(status * -1)); + status = NVSHMEMX_ERROR_INTERNAL; + } } } @@ -648,33 +712,33 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int qp_index, - uint32_t *imm_data) { + rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data, + nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle; nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; struct iovec p_op_l_iov; struct fi_msg_rma p_op_msg; struct fi_rma_iov p_op_r_iov; - nvshmemt_libfabric_endpoint_t *ep; size_t op_size; uint64_t num_retries = 0; int status = 0; int target_ep; - int ep_idx = 0; void *context = NULL; memset(&p_op_l_iov, 0, sizeof(struct iovec)); memset(&p_op_msg, 0, sizeof(struct fi_msg_rma)); memset(&p_op_r_iov, 0, sizeof(struct fi_rma_iov)); - ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + /* put_signal passes in EP to ensure that both operations go through same EP */ + if (!ep) ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { nvshmemt_libfabric_gdr_op_ctx_t *gdr_ctx; do { - gdr_ctx = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + gdr_ctx = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[ep->domain_index] + ->getNextSend(); status = gdr_ctx == NULL ? -EAGAIN : 0; } while (try_again(tcurr, &status, &num_retries, qp_index)); NVSHMEMI_NULL_ERROR_JMP(gdr_ctx, status, NVSHMEMX_ERROR_INTERNAL, out, @@ -682,8 +746,9 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, context = &gdr_ctx->ofi_context; } - remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[ep_idx]; - local_handle = &((nvshmemt_libfabric_mem_handle_t *)local->handle)->hdls[ep_idx]; + remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[ep->domain_index]; + local_handle = &((nvshmemt_libfabric_mem_handle_t *)local->handle)->hdls[ep->domain_index]; + op_size = bytesdesc.elembytes * bytesdesc.nelems; if (verb.desc == NVSHMEMI_OP_P) { @@ -695,7 +760,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, do { p_buf->p_op.value = *(uint64_t *)local->ptr; status = fi_write(ep->endpoint, &p_buf->p_op.value, op_size, - fi_mr_desc(libfabric_state->mr), target_ep, + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, (uintptr_t)remote->ptr, remote_handle->key, context); } while (try_again(tcurr, &status, &num_retries, qp_index)); } else { @@ -709,7 +774,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, p_op_l_iov.iov_base = local->ptr; p_op_l_iov.iov_len = op_size; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & + FI_MR_VIRT_ADDR) p_op_r_iov.addr = (uintptr_t)remote->ptr; else p_op_r_iov.addr = (uintptr_t)remote->offset; @@ -725,7 +791,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, } } else if (verb.desc == NVSHMEMI_OP_PUT) { uintptr_t remote_addr; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) remote_addr = (uintptr_t)remote->ptr; else remote_addr = (uintptr_t)remote->offset; @@ -743,7 +809,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, assert( !imm_data); // Write w/ imm not suppored with NVSHMEMI_OP_G/GET on Libfabric transport uintptr_t remote_addr; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) remote_addr = (uintptr_t)remote->ptr; else remote_addr = (uintptr_t)remote->offset; @@ -771,7 +837,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, rma_bytesdesc_t bytesdesc, int qp_index) { - return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL); + return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL, + NULL); } static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr, @@ -781,15 +848,15 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p nvshmemt_libfabric_endpoint_t *ep; nvshmemt_libfabric_gdr_op_ctx_t *amo; uint64_t num_retries = 0; - int target_ep, ep_idx; + int target_ep; int status = 0; - ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; do { - amo = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + amo = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[ep->domain_index] + ->getNextSend(); status = amo == NULL ? -EAGAIN : 0; } while (try_again(transport, &status, &num_retries, qp_index)); @@ -811,7 +878,8 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p num_retries = 0; do { status = fi_send(ep->endpoint, (void *)amo, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), target_ep, &amo->ofi_context); + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, + &amo->ofi_context); } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { @@ -841,7 +909,6 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v uint64_t num_retries = 0; int target_ep; int status = 0; - int ep_idx; memset(&amo_msg, 0, sizeof(struct fi_msg_atomic)); memset(&fi_local_iov, 0, sizeof(struct fi_ioc)); @@ -849,14 +916,14 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v memset(&fi_ret_iov, 0, sizeof(struct fi_ioc)); memset(&fi_remote_iov, 0, sizeof(struct fi_rma_ioc)); - ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; remote_handle = - &((nvshmemt_libfabric_mem_handle_t *)remote->remote_memdesc.handle)->hdls[ep_idx]; + &((nvshmemt_libfabric_mem_handle_t *)remote->remote_memdesc.handle)->hdls[ep->domain_index]; if (verb.desc > NVSHMEMI_AMO_END_OF_NONFETCH) { - local_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->ret_handle)->hdls[ep_idx]; + local_handle = + &((nvshmemt_libfabric_mem_handle_t *)remote->ret_handle)->hdls[ep->domain_index]; } if (bytesdesc.elembytes == 8) { @@ -923,7 +990,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v amo_msg.addr = target_ep; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) fi_remote_iov.addr = (uintptr_t)remote->remote_memdesc.ptr; else fi_remote_iov.addr = (uintptr_t)remote->remote_memdesc.offset; @@ -975,23 +1042,23 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, amo_bytesdesc_t bytesdesc, int qp_index, - uint32_t sequence_count, uint16_t num_writes) { + uint32_t sequence_count, uint16_t num_writes, + nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_endpoint_t *ep; + nvshmemt_libfabric_gdr_op_ctx_t *context; nvshmemt_libfabric_gdr_signal_op_t *signal; uint64_t num_retries = 0; - int target_ep, ep_idx; + int target_ep; int status = 0; - ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; static_assert(sizeof(nvshmemt_libfabric_gdr_op_ctx) >= sizeof(nvshmemt_libfabric_gdr_signal_op_t)); do { - context = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + context = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[ep->domain_index] + ->getNextSend(); status = context == NULL ? -EAGAIN : 0; } while (try_again(transport, &status, &num_retries, qp_index)); @@ -1011,7 +1078,8 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in num_retries = 0; do { status = fi_send(ep->endpoint, (void *)signal, sizeof(nvshmemt_libfabric_gdr_signal_op_t), - fi_mr_desc(libfabric_state->mr), target_ep, &context->ofi_context); + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, + &context->ofi_context); } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { @@ -1031,16 +1099,15 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v std::vector &write_bytes_desc, amo_verb_t sig_verb, amo_memdesc_t *sig_target, amo_bytesdesc_t sig_bytes_desc, int qp_index) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; int status; uint32_t sequence_count = 0; - int ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - nvshmemt_libfabric_endpoint_t &ep = libfabric_state->eps[ep_idx]; + nvshmemt_libfabric_endpoint_t *ep = nvshmemt_libfabric_get_next_ep(state, qp_index); /* Get sequence number for this put-signal, with retry */ uint64_t num_retries = 0; do { - int32_t seq_num = ep.put_signal_seq_counter.next_seq_num(); + int32_t seq_num = ep->put_signal_seq_counter.next_seq_num(); if (seq_num < 0) { status = -EAGAIN; } else { @@ -1059,7 +1126,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v for (size_t i = 0; i < write_remote.size(); i++) { status = nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i], - write_bytes_desc[i], qp_index, &sequence_count); + write_bytes_desc[i], qp_index, &sequence_count, ep); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT( "Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i); @@ -1068,8 +1135,9 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v } assert(use_staged_atomics == true); - status = nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, - qp_index, sequence_count, (uint16_t)write_remote.size()); + status = + nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, + qp_index, sequence_count, (uint16_t)write_remote.size(), ep); out: if (status) { NVSHMEMI_ERROR_PRINT( @@ -1085,7 +1153,7 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)t->state; nvshmemt_libfabric_mem_handle_t *fabric_handle; void *curr_ptr; - int max_reg, status = 0; + int status = 0; assert(mem_handle != NULL); fabric_handle = (nvshmemt_libfabric_mem_handle_t *)mem_handle; @@ -1115,15 +1183,10 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl } } - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_ENDPOINT) - max_reg = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - else - max_reg = 1; - - for (int i = 0; i < max_reg; i++) { + for (size_t i = 0; i < libfabric_state->domains.size(); i++) { int status = fi_close(&fabric_handle->hdls[i].mr->fid); if (status) { - NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %d (%d): %s\n", i, status, + NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %zu (%d): %s\n", i, status, fi_strerror(status * -1)); } } @@ -1132,6 +1195,7 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl return status; } +static_assert(sizeof(nvshmemt_libfabric_mem_handle_t) < sizeof(nvshmem_mem_handle_t)); static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, void *buf, size_t length, nvshmem_transport_t t, bool local_only) { @@ -1167,6 +1231,7 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v assert(mem_handle != NULL); fabric_handle = (nvshmemt_libfabric_mem_handle_t *)mem_handle; + fabric_handle->buf = buf; status = cudaPointerGetAttributes(&attr, buf); if (status != cudaSuccess) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, @@ -1197,40 +1262,15 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v mr_attr.iface = FI_HMEM_SYSTEM; } - fabric_handle->buf = buf; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_ENDPOINT) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = - fi_mr_regattr(libfabric_state->domain, &mr_attr, 0, &fabric_handle->hdls[i].mr); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Error registering memory region: %s\n", - fi_strerror(status * -1)); - - status = - fi_mr_bind(fabric_handle->hdls[i].mr, &libfabric_state->eps[i].endpoint->fid, 0); - - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Error binding MR to EP %d: %s\n", i, fi_strerror(status * -1)); - - status = fi_mr_enable(fabric_handle->hdls[i].mr); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error enabling MR: %s\n", - fi_strerror(status * -1)); - - fabric_handle->hdls[i].key = fi_mr_key(fabric_handle->hdls[i].mr); - fabric_handle->hdls[i].local_desc = fi_mr_desc(fabric_handle->hdls[i].mr); - } - } else { + for (size_t i = 0; i < libfabric_state->domains.size(); i++) { struct fid_mr *mr; - - status = fi_mr_regattr(libfabric_state->domain, &mr_attr, 0, &mr); + status = fi_mr_regattr(libfabric_state->domains[i], &mr_attr, 0, &mr); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error registering memory region: %s\n", fi_strerror(status * -1)); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - fabric_handle->hdls[i].mr = mr; - fabric_handle->hdls[i].key = fi_mr_key(mr); - fabric_handle->hdls[i].local_desc = fi_mr_desc(mr); - } + fabric_handle->hdls[i].mr = mr; + fabric_handle->hdls[i].key = fi_mr_key(mr); + fabric_handle->hdls[i].local_desc = fi_mr_desc(mr); } if (!local_only && libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { @@ -1370,174 +1410,186 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)t->state; nvshmemt_libfabric_ep_name_t *all_ep_names = NULL; nvshmemt_libfabric_ep_name_t *local_ep_names = NULL; - struct fi_info *current_fabric; + struct fi_info *current_info; + struct fid_fabric *fabric; + struct fid_domain *domain; + struct fid_av *address; + struct fid_mr *mr; struct fi_av_attr av_attr; struct fi_cq_attr cq_attr; struct fi_cntr_attr cntr_attr; size_t ep_namelen = NVSHMEMT_LIBFABRIC_EP_LEN; int status = 0; int total_num_eps; - size_t num_recvs_per_pe = 0; + size_t num_recvs_per_ep = 0; int n_pes = t->n_pes; - - if (state->eps) { - NVSHMEMI_WARN_PRINT( - "Device already selected. libfabric only supports one NIC per PE and doesn't support " - "additional QPs.\n"); - goto out_already_connected; + size_t num_sends; + size_t num_recvs; + size_t elem_size; + uint64_t flags; + state->num_selected_devs = MIN(num_selected_devs, state->max_nic_per_pe); + + if (state->eps.size()) { + NVSHMEMI_ERROR_PRINT("PE has previously called connect_endpoints()\n"); + return NVSHMEMX_ERROR_INTERNAL; } - state->eps = (nvshmemt_libfabric_endpoint_t *)calloc(NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS, - sizeof(nvshmemt_libfabric_endpoint_t)); - NVSHMEMI_NULL_ERROR_JMP(state->eps, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate EPs."); - - current_fabric = state->all_prov_info; - do { - if (!strncmp(current_fabric->nic->device_attr->name, - state->domain_names[selected_dev_ids[0]].name, - NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { - break; - } - current_fabric = current_fabric->next; - } while (current_fabric != NULL); - NVSHMEMI_NULL_ERROR_JMP(current_fabric, status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to find the selected fabric.\n"); - - state->prov_info = fi_dupinfo(current_fabric); - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA && - strcmp(state->prov_info->fabric_attr->name, "efa-direct")) + if (state->num_selected_devs > NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE) { + state->num_selected_devs = NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE; NVSHMEMI_WARN_PRINT( - "Libfabric transport is using efa fabric instead of efa-direct, " - "use libfabric v2.1.0 or newer for improved performance\n"); - - status = fi_fabric(state->prov_info->fabric_attr, &state->fabric, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate fabric: %d: %s\n", status, fi_strerror(status * -1)); - - status = fi_domain(state->fabric, state->prov_info, &state->domain, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate domain: %d: %s\n", status, fi_strerror(status * -1)); - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - state->num_sends = current_fabric->tx_attr->size * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - state->num_recvs = current_fabric->rx_attr->size * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - size_t elem_size = sizeof(nvshmemt_libfabric_gdr_op_ctx_t); - - num_recvs_per_pe = state->num_recvs / NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - - state->recv_buf = calloc(state->num_sends + state->num_recvs, elem_size); - NVSHMEMI_NULL_ERROR_JMP(state->recv_buf, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate EFA msg buffer.\n"); - state->send_buf = (char *)state->recv_buf + (elem_size * state->num_recvs); - - status = fi_mr_reg(state->domain, state->recv_buf, - (state->num_sends + state->num_recvs) * elem_size, FI_SEND | FI_RECV, 0, - 0, 0, &state->mr, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to register EFA msg buffer: %d: %s\n", status, - fi_strerror(status * -1)); - - nvshmemtLibfabricOpQueue.putToSendBulk((char *)state->send_buf, elem_size, - state->num_sends); - } - - t->max_op_len = state->prov_info->ep_attr->max_msg_size; - av_attr.type = FI_AV_TABLE; - av_attr.rx_ctx_bits = 0; - av_attr.count = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * n_pes; - av_attr.ep_per_node = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - av_attr.name = NULL; - av_attr.map_addr = NULL; - av_attr.flags = 0; - - /* Note - This is needed because EFA will only bind AVs to EPs on a 1:1 basis. - * If EFA ever lifts this requirement, we can reduce the number of AVs required. - */ - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_av_open(state->domain, &av_attr, &state->addresses[i], NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate address vector: %d: %s\n", status, - fi_strerror(status * -1)); + "PE selected %d devices, but the libfabric transport only supports a max of %d " + "devices. Continuing using %d devices.\n", + state->num_selected_devs, NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE, + NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE); } + state->num_selected_domains = state->num_selected_devs + 1; - INFO(state->log_level, "Selected provider %s, fabric %s, nic %s, hmem %s", - state->prov_info->fabric_attr->prov_name, state->prov_info->fabric_attr->name, - state->prov_info->nic->device_attr->name, state->prov_info->caps & FI_HMEM ? "yes" : "no"); - - assert(state->eps); + /* Initialize configuration which only need to be set once */ + t->max_op_len = UINT64_MAX; /* Set as sential value */ + state->cur_proxy_ep_index = 1; memset(&cq_attr, 0, sizeof(struct fi_cq_attr)); - memset(&cntr_attr, 0, sizeof(struct fi_cntr_attr)); - - state->prov_info->ep_attr->tx_ctx_cnt = 0; - state->prov_info->caps = FI_RMA; - if ((state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) || - (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS)) { - state->prov_info->caps |= FI_ATOMIC; - } else { - state->prov_info->caps |= FI_MSG; - state->prov_info->caps |= FI_SOURCE; - } - state->prov_info->tx_attr->op_flags = 0; - state->prov_info->tx_attr->mode = 0; - state->prov_info->rx_attr->mode = 0; - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - state->prov_info->mode = FI_CONTEXT2; - } else { - state->prov_info->mode = 0; - } - - state->prov_info->tx_attr->op_flags = FI_DELIVERY_COMPLETE; - - cntr_attr.events = FI_CNTR_EVENTS_COMP; - cntr_attr.wait_obj = FI_WAIT_UNSPEC; - cntr_attr.wait_set = NULL; - cntr_attr.flags = 0; - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { cq_attr.size = 16; /* CQ is only used to capture error events */ cq_attr.format = FI_CQ_FORMAT_UNSPEC; cq_attr.wait_obj = FI_WAIT_NONE; - } - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { cq_attr.format = FI_CQ_FORMAT_DATA; cq_attr.wait_obj = FI_WAIT_NONE; cq_attr.size = 32768; } - local_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS, + memset(&av_attr, 0, sizeof(struct fi_av_attr)); + av_attr.type = FI_AV_TABLE; + av_attr.count = state->num_selected_domains * n_pes; + + memset(&cntr_attr, 0, sizeof(struct fi_cntr_attr)); + cntr_attr.events = FI_CNTR_EVENTS_COMP; + cntr_attr.wait_obj = FI_WAIT_UNSPEC; + + /* Find fabric info for each selected device */ + for (int dev_idx = 0; dev_idx < state->num_selected_devs; dev_idx++) { + current_info = state->all_prov_info; + do { + if (!strncmp(current_info->nic->device_attr->name, + state->domain_names[selected_dev_ids[dev_idx]].name, + NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { + break; + } + current_info = current_info->next; + } while (current_info != NULL); + NVSHMEMI_NULL_ERROR_JMP(current_info, status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to find fabric for device %d.\n", dev_idx); + + /* + * Create two domains (host/proxy domain) for the first NIC. + */ + if (state->prov_infos.size() == 0) state->prov_infos.push_back(current_info); + + state->prov_infos.push_back(current_info); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA && + strcmp(current_info->fabric_attr->name, "efa-direct")) + NVSHMEMI_WARN_PRINT( + "Libfabric transport is using efa fabric instead of efa-direct, " + "use libfabric v2.1.0 or newer for improved performance\n"); + } + + /* Allocate out of band AV name exchange buffers */ + local_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(state->num_selected_domains, sizeof(nvshmemt_libfabric_ep_name_t)); NVSHMEMI_NULL_ERROR_JMP(local_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); - total_num_eps = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * n_pes; + total_num_eps = n_pes * state->num_selected_domains; all_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(total_num_eps, sizeof(nvshmemt_libfabric_ep_name_t)); NVSHMEMI_NULL_ERROR_JMP(all_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - state->eps[i].qp_index = i; - status = fi_endpoint(state->domain, state->prov_info, &state->eps[i].endpoint, NULL); + /* Create Resources For Each Selected Device */ + for (size_t i = 0; i < state->prov_infos.size(); i++) { + INFO(state->log_level, + "Selected provider %s, fabric %s, nic %s, hmem %s multi-rail %zu/%d\n", + state->prov_infos[i]->fabric_attr->prov_name, state->prov_infos[i]->fabric_attr->name, + state->prov_infos[i]->nic->device_attr->name, + state->prov_infos[i]->caps & FI_HMEM ? "yes" : "no", i + 1, num_selected_devs); + + if (state->prov_infos[i]->ep_attr->max_msg_size < t->max_op_len) + t->max_op_len = state->prov_infos[i]->ep_attr->max_msg_size; + + status = fi_fabric(state->prov_infos[i]->fabric_attr, &fabric, NULL); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to allocate endpoint: %d: %s\n", status, + "Failed to allocate fabric: %d: %s\n", status, fi_strerror(status * -1)); + state->fabrics.push_back(fabric); + + status = fi_domain(fabric, state->prov_infos[i], &domain, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to allocate domain: %d: %s\n", status, + fi_strerror(status * -1)); + state->domains.push_back(domain); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + num_sends = state->prov_infos[i]->tx_attr->size; + num_recvs = state->prov_infos[i]->rx_attr->size; + elem_size = sizeof(nvshmemt_libfabric_gdr_op_ctx_t); + num_recvs_per_ep = num_recvs; + + state->recv_buf.push_back(calloc(num_sends + num_recvs, elem_size)); + NVSHMEMI_NULL_ERROR_JMP(state->recv_buf[i], status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate EFA msg buffer.\n"); + state->send_buf.push_back((char *)state->recv_buf[i] + (elem_size * num_recvs)); + + status = fi_mr_reg(domain, state->recv_buf[i], (num_sends + num_recvs) * elem_size, + FI_SEND | FI_RECV | FI_WRITE, 0, 0, 0, &mr, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to register EFA msg buffer: %d: %s\n", status, + fi_strerror(status * -1)); + state->mr.push_back(mr); + + state->op_queue.push_back(new threadSafeOpQueue); + state->op_queue[i]->putToSendBulk((char *)state->send_buf[i], elem_size, num_sends); + } + + status = fi_av_open(domain, &av_attr, &address, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to allocate address vector: %d: %s\n", status, + fi_strerror(status * -1)); + state->addresses.push_back(address); + + /* Create nvshmemt_libfabric_endpoint_t resources */ + state->eps.push_back( + (nvshmemt_libfabric_endpoint_t *)calloc(1, sizeof(nvshmemt_libfabric_endpoint_t))); + NVSHMEMI_NULL_ERROR_JMP(state->eps[i], status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to alloc libfabric_tx_progress_group struct.\n"); + state->eps[i]->domain_index = i; /* Initialize per-endpoint proxy_put_signal_comp_map */ - state->eps[i].proxy_put_signal_comp_map = + state->eps[i]->proxy_put_signal_comp_map = new std::unordered_map>(); - state->eps[i].put_signal_seq_counter.reset(); - state->eps[i].completed_staged_atomics = 0; + state->eps[i]->put_signal_seq_counter.reset(); + state->eps[i]->completed_staged_atomics = 0; + + status = fi_cq_open(domain, &cq_attr, &state->eps[i]->cq, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to open completion queue for endpoint: %d: %s\n", status, + fi_strerror(status * -1)); + status = fi_cntr_open(domain, &cntr_attr, &state->eps[i]->counter, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to open counter for endpoint: %d: %s\n", status, + fi_strerror(status * -1)); + + status = fi_endpoint(domain, state->prov_infos[i], &state->eps[i]->endpoint, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to allocate endpoint: %d: %s\n", status, + fi_strerror(status * -1)); /* FI_OPT_CUDA_API_PERMITTED was introduced in libfabric 1.18.0 */ if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { bool prohibit_cuda_api = false; - status = fi_setopt(&state->eps[i].endpoint->fid, FI_OPT_ENDPOINT, + status = fi_setopt(&state->eps[i]->endpoint->fid, FI_OPT_ENDPOINT, FI_OPT_CUDA_API_PERMITTED, &prohibit_cuda_api, sizeof(bool)); if (status == -FI_ENOPROTOOPT) { NVSHMEMI_WARN_PRINT( @@ -1551,112 +1603,90 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele } } - status = fi_cq_open(state->domain, &cq_attr, &state->eps[i].cq, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open completion queue for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - - status = fi_cntr_open(state->domain, &cntr_attr, &state->eps[i].counter, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open counter for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - - status = fi_ep_bind(state->eps[i].endpoint, &state->addresses[i]->fid, 0); + /* Bind Resources To EP */ + status = fi_ep_bind(state->eps[i]->endpoint, &state->addresses[i]->fid, 0); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to bind endpoint to address vector: %d: %s\n", status, fi_strerror(status * -1)); - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, - FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV); - } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, - FI_SELECTIVE_COMPLETION | FI_TRANSMIT); - } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) + flags = FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV; + else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) + flags = FI_SELECTIVE_COMPLETION | FI_TRANSMIT; + else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) /* EFA is documented as not supporting FI_SELECTIVE_COMPLETION */ - status = - fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, FI_TRANSMIT | FI_RECV); - } else { + flags = FI_TRANSMIT | FI_RECV; + else { NVSHMEMI_ERROR_PRINT( "Invalid provider identified. This should be impossible. " "Possible memory corruption in the state pointer?"); status = NVSHMEMX_ERROR_INTERNAL; goto out; } + + status = fi_ep_bind(state->eps[i]->endpoint, &state->eps[i]->cq->fid, flags); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to bind endpoint to completion queue: %d: %s\n", status, fi_strerror(status * -1)); -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].counter->fid, - FI_READ | FI_WRITE | FI_SEND); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - } else -#endif - { - int flags = FI_READ | FI_WRITE; - if (use_staged_atomics) { - flags |= FI_SEND; - } - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].counter->fid, flags); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - } + flags = FI_READ | FI_WRITE; + if (use_staged_atomics) flags |= FI_SEND; + status = fi_ep_bind(state->eps[i]->endpoint, &state->eps[i]->counter->fid, flags); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to bind endpoint to completion counter: %d: %s\n", status, + fi_strerror(status * -1)); - status = fi_enable(state->eps[i].endpoint); + status = fi_enable(state->eps[i]->endpoint); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to enable endpoint: %d: %s\n", status, fi_strerror(status * -1)); if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - for (size_t j = 0; j < num_recvs_per_pe; j++) { - nvshmemt_libfabric_gdr_op_ctx_t *op; - op = (nvshmemt_libfabric_gdr_op_ctx_t *)state->recv_buf; - op = op + ((num_recvs_per_pe * i) + j); + nvshmemt_libfabric_gdr_op_ctx_t *op; + op = (nvshmemt_libfabric_gdr_op_ctx_t *)state->recv_buf[i]; + for (size_t j = 0; j < num_recvs_per_ep; j++, op++) { assert(op != NULL); - status = fi_recv(state->eps[i].endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(state->mr), FI_ADDR_UNSPEC, &op->ofi_context); + status = fi_recv(state->eps[i]->endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(state->mr[i]), FI_ADDR_UNSPEC, &op->ofi_context); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to post recv to ep. Error: %d: %s\n", status, + fi_strerror(status * -1)); } - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to post recv to ep. Error: %d: %s\n", status, - fi_strerror(status * -1)); } - status = fi_getname(&state->eps[i].endpoint->fid, local_ep_names[i].name, &ep_namelen); + status = fi_getname(&state->eps[i]->endpoint->fid, local_ep_names[i].name, &ep_namelen); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to get name for endpoint: %d: %s\n", status, fi_strerror(status * -1)); - if (ep_namelen > NVSHMEMT_LIBFABRIC_EP_LEN) { + if (ep_namelen > NVSHMEMT_LIBFABRIC_EP_LEN) NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Name of EP is too long."); - } } + /* Perform out of band address exchange */ status = t->boot_handle->allgather( local_ep_names, all_ep_names, - NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * sizeof(nvshmemt_libfabric_ep_name_t), t->boot_handle); + state->num_selected_domains * sizeof(nvshmemt_libfabric_ep_name_t), t->boot_handle); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather endpoint names.\n"); /* We need to insert one at a time since each buffer is larger than the address. */ - for (int j = 0; j < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; j++) { - for (int i = 0; i < total_num_eps; i++) { - status = fi_av_insert(state->addresses[j], &all_ep_names[i], 1, NULL, 0, NULL); + for (int i = 0; i < state->num_selected_domains; i++) { + for (int j = 0; j < total_num_eps; j++) { + status = fi_av_insert(state->addresses[i], &all_ep_names[j], 1, NULL, 0, NULL); if (status < 1) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to insert ep names in address vector: %d: %s\n", status, fi_strerror(status * -1)); } - status = NVSHMEMX_SUCCESS; } } + /* Out of bounds exchange a pre-registered write w/imm target for staged_amo acks */ if (use_staged_atomics) { state->remote_addr_staged_amo_ack = (void **)calloc(sizeof(void *), t->n_pes); + state->rkey_staged_amo_ack = + (uint64_t *)calloc(sizeof(uint64_t), t->n_pes * state->num_selected_domains); NVSHMEMI_NULL_ERROR_JMP(state->remote_addr_staged_amo_ack, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate remote address array for staged atomic ack.\n"); @@ -1665,13 +1695,15 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to allocate CUDA memory for staged atomic ack.\n"); - status = fi_mr_reg(state->domain, state->remote_addr_staged_amo_ack[t->my_pe], sizeof(int), - FI_REMOTE_WRITE, 0, 0, 0, &state->mr_staged_amo_ack, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to register EFA msg buffer: %d: %s\n", status, - fi_strerror(status * -1)); - state->rkey_staged_amo_ack = (uint64_t *)calloc(sizeof(uint64_t), t->n_pes); - state->rkey_staged_amo_ack[t->my_pe] = fi_mr_key(state->mr_staged_amo_ack); + for (size_t i = 0; i < state->domains.size(); i++) { + status = fi_mr_reg(state->domains[i], state->remote_addr_staged_amo_ack[t->my_pe], + sizeof(int), FI_REMOTE_WRITE, 0, 0, 0, &mr, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to register EFA msg buffer: %d: %s\n", status, + fi_strerror(status * -1)); + state->rkey_staged_amo_ack[t->my_pe * state->num_selected_domains + i] = fi_mr_key(mr); + state->mr_staged_amo_ack.push_back(mr); + } status = t->boot_handle->allgather(&state->remote_addr_staged_amo_ack[t->my_pe], state->remote_addr_staged_amo_ack, sizeof(void *), @@ -1679,9 +1711,10 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather remote addresses.\n"); - status = - t->boot_handle->allgather(&state->rkey_staged_amo_ack[t->my_pe], - state->rkey_staged_amo_ack, sizeof(uint64_t), t->boot_handle); + status = t->boot_handle->allgather( + &state->rkey_staged_amo_ack[t->my_pe * state->num_selected_domains], + state->rkey_staged_amo_ack, sizeof(uint64_t) * state->num_selected_domains, + t->boot_handle); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather remote keys.\n"); } @@ -1694,30 +1727,28 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele free(state->remote_addr_staged_amo_ack); } if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); - if (state->mr_staged_amo_ack) fi_close(&state->mr_staged_amo_ack->fid); - if (state->eps) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - if (state->eps[i].proxy_put_signal_comp_map) - delete state->eps[i].proxy_put_signal_comp_map; - if (state->eps[i].endpoint) { - fi_close(&state->eps[i].endpoint->fid); - state->eps[i].endpoint = NULL; - } - if (state->eps[i].cq) { - fi_close(&state->eps[i].cq->fid); - state->eps[i].cq = NULL; - } - if (state->eps[i].counter) { - fi_close(&state->eps[i].counter->fid); - state->eps[i].counter = NULL; - } + for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) + fi_close(&state->mr_staged_amo_ack[i]->fid); + for (size_t i = 0; i < state->eps.size(); i++) { + if (state->eps[i]->proxy_put_signal_comp_map) + delete state->eps[i]->proxy_put_signal_comp_map; + if (state->eps[i]->endpoint) { + fi_close(&state->eps[i]->endpoint->fid); + state->eps[i]->endpoint = NULL; } - free(state->eps); - state->eps = NULL; + if (state->eps[i]->cq) { + fi_close(&state->eps[i]->cq->fid); + state->eps[i]->cq = NULL; + } + if (state->eps[i]->counter) { + fi_close(&state->eps[i]->counter->fid); + state->eps[i]->counter = NULL; + } + free(state->eps[i]); + state->eps[i] = NULL; } } -out_already_connected: free(local_ep_names); free(all_ep_names); @@ -1725,12 +1756,12 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele } static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state; + nvshmemt_libfabric_state_t *state; int status; assert(transport); - libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; + state = (nvshmemt_libfabric_state_t *)transport->state; if (transport->device_pci_paths) { for (int i = 0; i < transport->n_devices; i++) { @@ -1742,19 +1773,19 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { size_t mem_handle_cache_size; nvshmemt_libfabric_memhandle_info_t *handle_info = NULL, *previous_handle_info = NULL; - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - mem_handle_cache_size = nvshmemt_mem_handle_cache_get_size(libfabric_state->cache); + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + mem_handle_cache_size = nvshmemt_mem_handle_cache_get_size(state->cache); for (size_t i = 0; i < mem_handle_cache_size; i++) { handle_info = (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx( - libfabric_state->cache, i); + state->cache, i); if (handle_info && handle_info != previous_handle_info) { free(handle_info); } previous_handle_info = handle_info; } - nvshmemt_mem_handle_cache_fini(libfabric_state->cache); + nvshmemt_mem_handle_cache_fini(state->cache); #ifdef NVSHMEM_USE_GDRCOPY if (use_gdrcopy) { nvshmemt_gdrcopy_ftable_fini(&gdrcopy_ftable, &gdr_desc, &gdrcopy_handle); @@ -1762,88 +1793,88 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { #endif } - if (libfabric_state->prov_info) { - fi_freeinfo(libfabric_state->prov_info); - } - - if (libfabric_state->eps) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - if (libfabric_state->eps[i].proxy_put_signal_comp_map) - delete libfabric_state->eps[i].proxy_put_signal_comp_map; - if (libfabric_state->eps[i].endpoint) { - status = fi_close(&libfabric_state->eps[i].endpoint->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric endpoint.: %d: %s\n", status, - fi_strerror(status * -1)); - } + /* + * Since fi_dupinfo() is not called, we don't need to clean + * we do not need to clean prov_infos + */ + if (state->all_prov_info) fi_freeinfo(state->all_prov_info); + + for (size_t i = 0; i < state->eps.size(); i++) { + if (state->eps[i]->proxy_put_signal_comp_map) + delete state->eps[i]->proxy_put_signal_comp_map; + if (state->eps[i]->endpoint) { + status = fi_close(&state->eps[i]->endpoint->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric endpoint.: %d: %s\n", status, + fi_strerror(status * -1)); } - if (libfabric_state->eps[i].cq) { - status = fi_close(&libfabric_state->eps[i].cq->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric cq: %d: %s\n", status, - fi_strerror(status * -1)); - } + } + if (state->eps[i]->cq) { + status = fi_close(&state->eps[i]->cq->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric cq: %d: %s\n", status, + fi_strerror(status * -1)); } - if (libfabric_state->eps[i].counter) { - status = fi_close(&libfabric_state->eps[i].counter->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric counter: %d: %s\n", status, - fi_strerror(status * -1)); - } + } + if (state->eps[i]->counter) { + status = fi_close(&state->eps[i]->counter->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric counter: %d: %s\n", status, + fi_strerror(status * -1)); } } - free(libfabric_state->eps); + free(state->eps[i]); } - if (libfabric_state->remote_addr_staged_amo_ack) { - if (libfabric_state->remote_addr_staged_amo_ack[transport->my_pe]) - cudaFree(libfabric_state->remote_addr_staged_amo_ack[transport->my_pe]); - free(libfabric_state->remote_addr_staged_amo_ack); + if (state->remote_addr_staged_amo_ack) { + if (state->remote_addr_staged_amo_ack[transport->my_pe]) + cudaFree(state->remote_addr_staged_amo_ack[transport->my_pe]); + free(state->remote_addr_staged_amo_ack); } - if (libfabric_state->rkey_staged_amo_ack) free(libfabric_state->rkey_staged_amo_ack); - if (libfabric_state->mr_staged_amo_ack) { - status = fi_close(&libfabric_state->mr_staged_amo_ack->fid); + if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); + for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) { + status = fi_close(&state->mr_staged_amo_ack[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close staged atomic ack MR: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->mr) { - status = fi_close(&libfabric_state->mr->fid); + for (size_t i = 0; i < state->mr.size(); i++) { + status = fi_close(&state->mr[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric MR: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->recv_buf) free(libfabric_state->recv_buf); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_close(&libfabric_state->addresses[i]->fid); + for (size_t i = 0; i < state->recv_buf.size(); i++) free(state->recv_buf[i]); + + for (size_t i = 0; i < state->addresses.size(); i++) { + status = fi_close(&state->addresses[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric address vector: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->domain) { - status = fi_close(&libfabric_state->domain->fid); + for (size_t i = 0; i < state->domains.size(); i++) { + status = fi_close(&state->domains[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric domain: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->fabric) { - status = fi_close(&libfabric_state->fabric->fid); + for (size_t i = 0; i < state->fabrics.size(); i++) { + status = fi_close(&state->fabrics[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric: %d: %s\n", status, fi_strerror(status * -1)); } } - free(libfabric_state); - + free(state); free(transport); return 0; @@ -1851,7 +1882,7 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state, struct nvshmemi_options_s *options) { - struct fi_info info; + struct fi_info hints; struct fi_tx_attr tx_attr; struct fi_rx_attr rx_attr; struct fi_ep_attr ep_attr; @@ -1859,58 +1890,58 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr struct fi_fabric_attr fabric_attr; struct fid_nic nic; struct fi_av_attr av_attr; - struct fi_info *returned_fabrics, *current_fabric; + struct fi_info *all_infos, *current_info; int num_fabrics_returned = 0; int status = 0; memset(&ep_attr, 0, sizeof(struct fi_ep_attr)); memset(&av_attr, 0, sizeof(struct fi_av_attr)); - memset(&info, 0, sizeof(struct fi_info)); + memset(&hints, 0, sizeof(struct fi_info)); memset(&tx_attr, 0, sizeof(struct fi_tx_attr)); memset(&rx_attr, 0, sizeof(struct fi_rx_attr)); memset(&domain_attr, 0, sizeof(struct fi_domain_attr)); memset(&fabric_attr, 0, sizeof(struct fi_fabric_attr)); memset(&nic, 0, sizeof(struct fid_nic)); - info.tx_attr = &tx_attr; - info.rx_attr = &rx_attr; - info.ep_attr = &ep_attr; - info.domain_attr = &domain_attr; - info.fabric_attr = &fabric_attr; - info.nic = &nic; + hints.tx_attr = &tx_attr; + hints.rx_attr = &rx_attr; + hints.ep_attr = &ep_attr; + hints.domain_attr = &domain_attr; + hints.fabric_attr = &fabric_attr; + hints.nic = &nic; - info.addr_format = FI_FORMAT_UNSPEC; - info.caps = FI_RMA | FI_HMEM; + hints.addr_format = FI_FORMAT_UNSPEC; + hints.caps = FI_RMA | FI_HMEM; if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) { - info.caps |= FI_ATOMIC; + hints.caps |= FI_ATOMIC; domain_attr.mr_mode = FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY; } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { /* TODO: Use FI_FENCE to optimize put_with_signal */ - info.caps |= FI_FENCE | FI_ATOMIC; + hints.caps |= FI_FENCE | FI_ATOMIC; domain_attr.mr_mode = FI_MR_ENDPOINT | FI_MR_ALLOCATED | FI_MR_PROV_KEY; } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { domain_attr.mr_mode = FI_MR_LOCAL | FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_HMEM; - info.caps |= FI_MSG; - info.caps |= FI_SOURCE; + hints.caps |= FI_MSG; + hints.caps |= FI_SOURCE; } if (use_staged_atomics) { - info.mode |= FI_CONTEXT2; + hints.mode |= FI_CONTEXT2; } ep_attr.type = FI_EP_RDM; /* Reliable datagrams */ /* Require completion RMA completion at target for correctness of quiet */ - info.tx_attr->op_flags = FI_DELIVERY_COMPLETE; + hints.tx_attr->op_flags = FI_DELIVERY_COMPLETE; /* nvshmemt_libfabric_auto_progress relaxes threading requirement */ domain_attr.threading = FI_THREAD_COMPLETION; - info.domain_attr->data_progress = FI_PROGRESS_AUTO; + hints.domain_attr->data_progress = FI_PROGRESS_AUTO; status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), NULL, - NULL, 0, &info, &returned_fabrics); + NULL, 0, &hints, &all_infos); /* * 1. Ensure that at least one fabric was returned @@ -1920,28 +1951,27 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr * options.LIBFABRIC_PROVIDER will be a substr of the returned fabric * name */ - if (!status && strstr(returned_fabrics->fabric_attr->name, options->LIBFABRIC_PROVIDER)) { + if (!status && strstr(all_infos->fabric_attr->name, options->LIBFABRIC_PROVIDER)) { use_auto_progress = true; } else { - fi_freeinfo(returned_fabrics); + fi_freeinfo(all_infos); /* * Fallback to FI_PROGRESS_MANUAL path * nvshmemt_libfabric_slow_progress requires FI_THREAD_SAFE */ domain_attr.threading = FI_THREAD_SAFE; - info.domain_attr->data_progress = FI_PROGRESS_MANUAL; + hints.domain_attr->data_progress = FI_PROGRESS_MANUAL; status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), - NULL, NULL, 0, &info, &returned_fabrics); + NULL, NULL, 0, &hints, &all_infos); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "No providers matched fi_getinfo query: %d: %s\n", status, fi_strerror(status * -1)); } - state->all_prov_info = returned_fabrics; - for (current_fabric = returned_fabrics; current_fabric != NULL; - current_fabric = current_fabric->next) { + state->all_prov_info = all_infos; + for (current_info = all_infos; current_info != NULL; current_info = current_info->next) { num_fabrics_returned++; } @@ -1952,53 +1982,51 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr /* Only select unique devices. */ state->num_domains = 0; - for (current_fabric = returned_fabrics; current_fabric != NULL; - current_fabric = current_fabric->next) { - if (!current_fabric->nic) { + for (current_info = all_infos; current_info != NULL; current_info = current_info->next) { + if (!current_info->nic) { INFO(state->log_level, "Interface did not return NIC structure to fi_getinfo. Skipping.\n"); continue; } - if (!current_fabric->tx_attr) { + if (!current_info->tx_attr) { INFO(state->log_level, "Interface did not return TX_ATTR structure to fi_getinfo. Skipping.\n"); continue; } TRACE(state->log_level, "fi_getinfo returned provider %s, fabric %s, nic %s", - current_fabric->fabric_attr->prov_name, current_fabric->fabric_attr->name, - current_fabric->nic->device_attr->name); + current_info->fabric_attr->prov_name, current_info->fabric_attr->name, + current_info->nic->device_attr->name); if (state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - if (current_fabric->tx_attr->inject_size < NVSHMEMT_LIBFABRIC_INJECT_BYTES) { + if (current_info->tx_attr->inject_size < NVSHMEMT_LIBFABRIC_INJECT_BYTES) { INFO(state->log_level, "Disabling interface due to insufficient inject data size. reported %lu, " "expected " "%u", - current_fabric->tx_attr->inject_size, NVSHMEMT_LIBFABRIC_INJECT_BYTES); + current_info->tx_attr->inject_size, NVSHMEMT_LIBFABRIC_INJECT_BYTES); continue; } } - if ((current_fabric->domain_attr->mr_mode & FI_MR_PROV_KEY) == 0) { + if ((current_info->domain_attr->mr_mode & FI_MR_PROV_KEY) == 0) { INFO(state->log_level, "Disabling interface due to FI_MR_PROV_KEY support"); continue; } for (int i = 0; i <= state->num_domains; i++) { - if (!strncmp(current_fabric->nic->device_attr->name, state->domain_names[i].name, + if (!strncmp(current_info->nic->device_attr->name, state->domain_names[i].name, NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { break; } else if (i == state->num_domains) { - size_t name_len = strlen(current_fabric->nic->device_attr->name); + size_t name_len = strlen(current_info->nic->device_attr->name); if (name_len >= NVSHMEMT_LIBFABRIC_DOMAIN_LEN) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to copy domain name for libfabric transport."); } (void)strncpy(state->domain_names[state->num_domains].name, - current_fabric->nic->device_attr->name, - NVSHMEMT_LIBFABRIC_DOMAIN_LEN); + current_info->nic->device_attr->name, NVSHMEMT_LIBFABRIC_DOMAIN_LEN); state->num_domains++; break; } @@ -2020,8 +2048,6 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr nvshmemt_libfabric_finalize(t); } - free(info.fabric_attr->name); - return status; } @@ -2070,6 +2096,7 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, "Unable to initialize env options."); libfabric_state->log_level = nvshmemt_common_get_log_level(&options); + libfabric_state->max_nic_per_pe = options.LIBFABRIC_MAX_NIC_PER_PE; if (strcmp(options.LIBFABRIC_PROVIDER, "verbs") == 0) { libfabric_state->provider = NVSHMEMT_LIBFABRIC_PROVIDER_VERBS; diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index 8fb47de..c37729e 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -30,12 +30,9 @@ #define NVSHMEMT_LIBFABRIC_DOMAIN_LEN 32 #define NVSHMEMT_LIBFABRIC_PROVIDER_LEN 32 #define NVSHMEMT_LIBFABRIC_EP_LEN 128 - -/* one EP for all proxy ops, one for host ops */ -#define NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS 2 +/* Constrainted by memhandle size */ +#define NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE 16 #define NVSHMEMT_LIBFABRIC_PROXY_EP_IDX 1 -#define NVSHMEMT_LIBFABRIC_HOST_EP_IDX 0 - #define NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS 20 /* Maximum size of inject data. Currently @@ -184,7 +181,7 @@ typedef struct { nvshmemt_libfabric_endpoint_seq_counter_t put_signal_seq_counter; std::unordered_map> *proxy_put_signal_comp_map; - int qp_index; + int domain_index; } nvshmemt_libfabric_endpoint_t; typedef enum { @@ -212,6 +209,10 @@ class threadSafeOpQueue { std::deque recv; public: + threadSafeOpQueue() = default; + threadSafeOpQueue(const threadSafeOpQueue &) = delete; + threadSafeOpQueue &operator=(const threadSafeOpQueue &) = delete; + void *getNextSend() { void *elem; send_mutex.lock(); @@ -273,26 +274,34 @@ class threadSafeOpQueue { }; typedef struct { - struct fi_info *prov_info; struct fi_info *all_prov_info; - struct fid_fabric *fabric; - struct fid_domain *domain; - struct fid_av *addresses[NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS]; - nvshmemt_libfabric_endpoint_t *eps; + std::vector prov_infos; + std::vector fabrics; + std::vector domains; + std::vector addresses; + std::vector eps; + nvshmemt_libfabric_domain_name_t *domain_names; int num_domains; nvshmemt_libfabric_provider provider; int log_level; struct nvshmemi_cuda_fn_table *table; - size_t num_sends; - void *send_buf; - size_t num_recvs; - void *recv_buf; - struct fid_mr *mr; struct transport_mem_handle_info_cache *cache; + + /* Required for multi-rail */ + int max_nic_per_pe; + int num_selected_devs; + int num_selected_domains; + int cur_proxy_ep_index; + + /* Required for staged_amo */ + std::vector op_queue; + std::vector mr; + std::vector send_buf; + std::vector recv_buf; + std::vector mr_staged_amo_ack; void **remote_addr_staged_amo_ack; uint64_t *rkey_staged_amo_ack; - struct fid_mr *mr_staged_amo_ack; } nvshmemt_libfabric_state_t; typedef enum { @@ -319,7 +328,7 @@ typedef struct { typedef struct { void *buf; - nvshmemt_libfabric_mem_handle_ep_t hdls[2]; + nvshmemt_libfabric_mem_handle_ep_t hdls[1 + NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE]; } nvshmemt_libfabric_mem_handle_t; typedef struct nvshmemt_libfabric_gdr_send_p_op {