diff --git a/src/host/proxy/proxy.cpp b/src/host/proxy/proxy.cpp index 74270bf..e4925bf 100644 --- a/src/host/proxy/proxy.cpp +++ b/src/host/proxy/proxy.cpp @@ -686,51 +686,23 @@ int process_channel_amo(proxy_state_t *state, proxy_channel_t *ch, int *is_proce } void enforce_cst(proxy_state_t *proxy_state) { -#if defined(NVSHMEM_X86_64) - nvshmemi_state_t *state = proxy_state->nvshmemi_state; -#endif - int status = 0; if (nvshmemi_options.BYPASS_FLUSH) return; - if (proxy_state->is_consistency_api_supported) { - if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering && - CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) { - status = - CUPFN(nvshmemi_cuda_syms, - cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX, - CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER)); - /** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable - consistent access of data on any GPU (and not just self GPU) with - wait_until, quiet, barrier, etc. **/ - if (status != CUDA_SUCCESS) { - NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n"); - } - } - return; - } -#if defined(NVSHMEM_PPC64LE) - status = cudaEventRecord(proxy_state->cuev, proxy_state->stream); - if (unlikely(status != CUDA_SUCCESS)) { - NVSHMEMI_ERROR_EXIT("cuEventRecord() failed in the proxy thread \n"); - } -#elif defined(NVSHMEM_X86_64) - for (int i = 0; i < state->num_initialized_transports; i++) { - if (!((state->transport_bitmap) & (1 << i))) continue; - struct nvshmem_transport *tcurr = state->transports[i]; - if (!tcurr->host_ops.enforce_cst) continue; - - // assuming the transport is connected - IB RC - if (tcurr->attr & NVSHMEM_TRANSPORT_ATTR_CONNECTED) { - status = tcurr->host_ops.enforce_cst(tcurr); - if (status) { - NVSHMEMI_ERROR_PRINT("aborting due to error in progress_cst \n"); - exit(-1); - } + if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering && + CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) { + status = + CUPFN(nvshmemi_cuda_syms, + cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX, + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER)); + /** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable + consistent access of data on any GPU (and not just self GPU) with + wait_until, quiet, barrier, etc. **/ + if (status != CUDA_SUCCESS) { + NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n"); } } -#endif } inline void quiet_ack_channels(proxy_state_t *proxy_state) { diff --git a/src/include/internal/host_transport/transport.h b/src/include/internal/host_transport/transport.h index f3fc7c1..f36b995 100644 --- a/src/include/internal/host_transport/transport.h +++ b/src/include/internal/host_transport/transport.h @@ -148,7 +148,6 @@ struct nvshmem_transport_host_ops { fence_handle fence; quiet_handle quiet; put_signal_handle put_signal; - int (*enforce_cst)(struct nvshmem_transport *transport); int (*enforce_cst_at_target)(struct nvshmem_transport *transport); int (*add_device_remote_mem_handles)(struct nvshmem_transport *transport, int transport_stride, nvshmem_mem_handle_t *mem_handles, uint64_t heap_offset, diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h index aafc312..00d593c 100644 --- a/src/modules/transport/common/env_defs.h +++ b/src/modules/transport/common/env_defs.h @@ -98,6 +98,9 @@ NVSHMEMI_ENV_DEF(DISABLE_LOCAL_ONLY_PROXY, bool, false, NVSHMEMI_ENV_CAT_TRANSPO NVSHMEMI_ENV_DEF(LIBFABRIC_PROVIDER, string, "cxi", NVSHMEMI_ENV_CAT_TRANSPORT, "Set the feature set provider for the libfabric transport: cxi, efa, verbs") +NVSHMEMI_ENV_DEF(LIBFABRIC_MAX_NIC_PER_PE, int, 16, NVSHMEMI_ENV_CAT_TRANSPORT, + "Set the maximum number of NIC's per PE to use for libfabric provider") + #if defined(NVSHMEM_IBGDA_SUPPORT) || defined(NVSHMEM_ENV_ALL) /** GPU-initiated communication **/ NVSHMEMI_ENV_DEF(IBGDA_ENABLE_MULTI_PORT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, diff --git a/src/modules/transport/ibdevx/ibdevx.cpp b/src/modules/transport/ibdevx/ibdevx.cpp index 4a76fde..02e217d 100644 --- a/src/modules/transport/ibdevx/ibdevx.cpp +++ b/src/modules/transport/ibdevx/ibdevx.cpp @@ -1440,46 +1440,6 @@ int nvshmemt_ibdevx_amo(struct nvshmem_transport *tcurr, int pe, void *curetptr, return status; } -int nvshmemt_ibdevx_enforce_cst_at_target(struct nvshmem_transport *tcurr) { - nvshmemt_ib_common_state_t ibdevx_state = (nvshmemt_ib_common_state_t)tcurr->state; - struct ibdevx_ep *ep = (struct ibdevx_ep *)ibdevx_state->cst_ep; - struct ibdevx_rw_wqe *wqe; - - int status = 0; - - uintptr_t wqe_bb_idx_64 = ep->wqe_bb_idx; - uint32_t wqe_bb_idx_32 = ep->wqe_bb_idx; - size_t wqe_size; - - wqe = (struct ibdevx_rw_wqe *)((char *)ep->wq_buf + - ((wqe_bb_idx_64 % get_ibdevx_qp_depth(ibdevx_state)) - << NVSHMEMT_IBDEVX_WQE_BB_SHIFT)); - wqe_size = sizeof(struct ibdevx_rw_wqe); - memset(wqe, 0, sizeof(struct ibdevx_rw_wqe)); - - wqe->ctrl.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; - wqe->ctrl.qpn_ds = - htobe32((uint32_t)(wqe_size / NVSHMEMT_IBDEVX_MLX5_SEND_WQE_DS) | ep->qpid << 8); - wqe->ctrl.opmod_idx_opcode = htobe32(MLX5_OPCODE_RDMA_READ | (wqe_bb_idx_32 << 8)); - - wqe->raddr.raddr = htobe64((uintptr_t)local_dummy_mr.mr->addr); - wqe->raddr.rkey = htobe32(local_dummy_mr.rkey); - - wqe->data.data_seg.byte_count = htobe32((uint32_t)4); - wqe->data.data_seg.lkey = htobe32(local_dummy_mr.lkey); - wqe->data.data_seg.addr = htobe64((uintptr_t)local_dummy_mr.mr->addr); - - assert(wqe_size <= MLX5_SEND_WQE_BB); - ep->wqe_bb_idx++; - nvshmemt_ibdevx_post_send(ep, (void *)wqe, 1); - - status = nvshmemt_ib_common_check_poll_avail(tcurr, ep, NVSHMEMT_IB_COMMON_WAIT_ALL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "check_poll failed \n"); - -out: - return status; -} - // Using common fence and quiet functions from transport_ib_common int nvshmemt_ibdevx_ep_create(struct ibdevx_ep **ep, int devid, nvshmem_transport_t t, @@ -1932,7 +1892,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_ibdevx_finalize; transport->host_ops.show_info = nvshmemt_ibdevx_show_info; transport->host_ops.progress = nvshmemt_ibdevx_progress; - transport->host_ops.enforce_cst = nvshmemt_ibdevx_enforce_cst_at_target; transport->host_ops.put_signal = nvshmemt_put_signal; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp index 4c69cce..15eb8d0 100644 --- a/src/modules/transport/ibgda/ibgda.cpp +++ b/src/modules/transport/ibgda/ibgda.cpp @@ -4915,7 +4915,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.amo = NULL; transport->host_ops.fence = NULL; transport->host_ops.quiet = NULL; - transport->host_ops.enforce_cst = NULL; transport->host_ops.add_device_remote_mem_handles = nvshmemt_ibgda_add_device_remote_mem_handles; transport->host_ops.put_signal = NULL; diff --git a/src/modules/transport/ibrc/ibrc.cpp b/src/modules/transport/ibrc/ibrc.cpp index bcea10f..9cb02a3 100644 --- a/src/modules/transport/ibrc/ibrc.cpp +++ b/src/modules/transport/ibrc/ibrc.cpp @@ -1802,7 +1802,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.progress = nvshmemt_ibrc_progress; transport->host_ops.put_signal = nvshmemt_put_signal; - transport->host_ops.enforce_cst = nvshmemt_ibrc_enforce_cst_at_target; #if !defined(NVSHMEM_PPC64LE) && !defined(NVSHMEM_AARCH64) if (!use_gdrcopy) #endif diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 927c079..eb5cd43 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -65,11 +65,13 @@ static bool use_gdrcopy = false; #define NVSHMEM_STAGED_AMO_WIREDATA_SIZE \ sizeof(nvshmemt_libfabric_gdr_op_ctx_t) - sizeof(struct fi_context2) - sizeof(fi_addr_t) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + static bool use_staged_atomics = false; -threadSafeOpQueue nvshmemtLibfabricOpQueue; +static bool use_auto_progress = false; std::mutex gdrRecvMutex; -int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport); +int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index); int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, fi_addr_t *addr); @@ -79,6 +81,27 @@ static nvshmemt_libfabric_imm_cq_data_hdr_t nvshmemt_get_write_with_imm_hdr(uint NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT); } +static inline nvshmemt_libfabric_endpoint_t *nvshmemt_libfabric_get_next_ep( + nvshmemt_libfabric_state_t *state, int qp_index) { + int selected_ep; + + if (qp_index == NVSHMEMX_QP_HOST) { + selected_ep = 0; + } else { + /* + * Return the current EP, and increment the next EP in round robin fashion + * between 1 and state->num_selected_domains - 1. state->cur_proxy_ep_index + * is initialized to 1. This round-robin goes through the proxy EP's and + * ignores the host EP. + */ + selected_ep = state->cur_proxy_ep_index; + state->cur_proxy_ep_index = (state->cur_proxy_ep_index + 1) % state->num_selected_domains; + if (!state->cur_proxy_ep_index) state->cur_proxy_ep_index = 1; + } + + return state->eps[selected_ep]; +} + static void nvshmemt_libfabric_put_signal_ack_completion(nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry) { uint32_t seq_num = entry->data & NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK; @@ -96,6 +119,7 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo fi_addr_t *addr) { int status = 0; nvshmemt_libfabric_gdr_op_ctx_t *op; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; /* Write w/imm doesn't have op->op_context, must be checked first */ if (entry->flags & FI_REMOTE_CQ_DATA) { @@ -120,16 +144,16 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo op->src_addr = *addr; if (entry->flags & FI_SEND) { - nvshmemtLibfabricOpQueue.putToSend(op); + state->op_queue[ep->domain_index]->putToSend(op); } else if (entry->flags & FI_RMA) { /* inlined p ops or atomic responses */ - nvshmemtLibfabricOpQueue.putToSend(op); + state->op_queue[ep->domain_index]->putToSend(op); } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { /* Must happen after entry->flags & FI_SEND to avoid send completions */ status = nvshmemt_libfabric_put_signal_completion(transport, ep, entry, addr); } else if (entry->flags & FI_RECV) { op->ep = ep; - nvshmemtLibfabricOpQueue.putToRecv(op); + state->op_queue[ep->domain_index]->putToRecv(op); } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, "Found an invalid message type in an ep completion.\n"); @@ -139,82 +163,119 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo return status; } -static int nvshmemt_libfabric_progress(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; +static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, + nvshmemt_libfabric_endpoint_t *ep) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + char buf[MAX_COMPLETIONS_PER_CQ_POLL * sizeof(struct fi_cq_data_entry)]; + fi_addr_t src_addr[MAX_COMPLETIONS_PER_CQ_POLL]; + fi_addr_t *addr; + ssize_t qstatus; + struct fi_cq_data_entry *entry; + uint64_t cnt; int status; - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - uint64_t cnt = fi_cntr_readerr(libfabric_state->eps[i].counter); - - if (cnt > 0) { - NVSHMEMI_WARN_PRINT("Nonzero error count progressing EP %d (%" PRIu64 ")\n", i, cnt); - - struct fi_cq_err_entry err; - memset(&err, 0, sizeof(struct fi_cq_err_entry)); - ssize_t nerr = fi_cq_readerr(libfabric_state->eps[i].cq, &err, 0); + cnt = fi_cntr_readerr(ep->counter); + if (cnt > 0) { + NVSHMEMI_WARN_PRINT("Nonzero error count progressing EP (%" PRIu64 ")\n", cnt); + struct fi_cq_err_entry err; + memset(&err, 0, sizeof(struct fi_cq_err_entry)); + ssize_t nerr = fi_cq_readerr(ep->cq, &err, 0); + + if (nerr > 0) { + char str[100] = "\0"; + const char *err_str = fi_cq_strerror(ep->cq, err.prov_errno, err.err_data, str, 100); + NVSHMEMI_WARN_PRINT( + "CQ reported error (%d): %s\n\tProvider error: %s\n\tSupplemental error " + "info: %s\n", + err.err, fi_strerror(err.err), err_str ? err_str : "none", + strlen(str) ? str : "none"); + } else if (nerr == -FI_EAGAIN) { + NVSHMEMI_WARN_PRINT("fi_cq_readerr returned -FI_EAGAIN\n"); + } else { + NVSHMEMI_WARN_PRINT("fi_cq_readerr returned %zd: %s\n", nerr, fi_strerror(-1 * nerr)); + } + return NVSHMEMX_ERROR_INTERNAL; + } - if (nerr > 0) { - char str[100] = "\0"; - const char *err_str = fi_cq_strerror(libfabric_state->eps[i].cq, err.prov_errno, - err.err_data, str, 100); - NVSHMEMI_WARN_PRINT( - "CQ reported error (%d): %s\n\tProvider error: %s\n\tSupplemental error " - "info: %s\n", - err.err, fi_strerror(err.err), err_str ? err_str : "none", - strlen(str) ? str : "none"); - } else if (nerr == -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned -FI_EAGAIN\n"); + do { + qstatus = fi_cq_readfrom(ep->cq, buf, MAX_COMPLETIONS_PER_CQ_POLL, src_addr); + /* Note - EFA provider does not support selective completions */ + if (qstatus > 0) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + entry = (struct fi_cq_data_entry *)buf; + addr = src_addr; + for (int i = 0; i < qstatus; i++, entry++, addr++) { + status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, addr); + if (status) return NVSHMEMX_ERROR_INTERNAL; + } } else { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned %zd: %s\n", nerr, - fi_strerror(-1 * nerr)); + NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP\n", qstatus); } - return NVSHMEMX_ERROR_INTERNAL; } + } while (qstatus > 0); + if (qstatus < 0 && qstatus != -FI_EAGAIN) { + NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, fi_strerror(qstatus * -1)); + return NVSHMEMX_ERROR_INTERNAL; + } - { - char buf[MAX_COMPLETIONS_PER_CQ_POLL * sizeof(struct fi_cq_data_entry)]; - fi_addr_t src_addr[MAX_COMPLETIONS_PER_CQ_POLL]; - ssize_t qstatus; - nvshmemt_libfabric_endpoint_t *ep = &libfabric_state->eps[i]; - do { - qstatus = fi_cq_readfrom(ep->cq, buf, MAX_COMPLETIONS_PER_CQ_POLL, src_addr); - /* Note - EFA provider does not support selective completions */ - if (qstatus > 0) { - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - struct fi_cq_data_entry *entry = (struct fi_cq_data_entry *)buf; - fi_addr_t *addr = src_addr; - for (int i = 0; i < qstatus; i++, entry++, addr++) { - status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, - addr); - if (status) return NVSHMEMX_ERROR_INTERNAL; - } - } else { - NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP %d\n", qstatus, i); - } - } - } while (qstatus > 0); - if (qstatus < 0 && qstatus != -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, - fi_strerror(qstatus * -1)); - return NVSHMEMX_ERROR_INTERNAL; - } + return 0; +} + +static int nvshmemt_libfabric_auto_progress(nvshmem_transport_t transport, int qp_index) { + int status; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + int end_iter; + + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else { + end_iter = state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } + + for (int i = qp_index; i < end_iter; i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, state->eps[i]); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", + status); + } + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + if (gdrRecvMutex.try_lock()) { + status = nvshmemt_libfabric_gdr_process_amos(transport, qp_index); + gdrRecvMutex.unlock(); } } - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { +out: + return status; +} + +static int nvshmemt_libfabric_auto_proxy_progress(nvshmem_transport_t transport) { + return nvshmemt_libfabric_auto_progress(transport, NVSHMEMT_LIBFABRIC_PROXY_EP_IDX); +} + +static int nvshmemt_libfabric_manual_progress(nvshmem_transport_t transport) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + int status; + for (size_t i = 0; i < state->eps.size(); i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, state->eps[i]); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", + status); + } + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { if (gdrRecvMutex.try_lock()) { - status = nvshmemt_libfabric_gdr_process_amos(transport); + status = nvshmemt_libfabric_gdr_process_amos(transport, NVSHMEMX_QP_ALL); gdrRecvMutex.unlock(); - if (status) { - return NVSHMEMX_ERROR_INTERNAL; - } } } - return 0; +out: + return status; } -static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t *num_retries) { +static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t *num_retries, + int qp_index) { if (likely(*status == 0)) { return 0; } @@ -227,7 +288,10 @@ static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t return 0; } (*num_retries)++; - *status = nvshmemt_libfabric_progress(transport); + if (use_auto_progress) + *status = nvshmemt_libfabric_auto_progress(transport, qp_index); + else + *status = nvshmemt_libfabric_manual_progress(transport); } if (*status != 0) { @@ -247,11 +311,13 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t uint64_t num_retries = 0; int status; uint64_t imm_data = 0; + uint64_t rkey_index = pe * libfabric_state->num_selected_domains + ep->domain_index; do { - resp_op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + resp_op = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[ep->domain_index] + ->getNextSend(); status = resp_op == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, ep->domain_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable allocate buffer for atomic ack.\n"); @@ -261,10 +327,11 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | sequence_count; do { - status = fi_writedata(ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr), imm_data, - dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], - libfabric_state->rkey_staged_amo_ack[pe], &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries)); + status = fi_writedata( + ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr[ep->domain_index]), imm_data, + dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], + libfabric_state->rkey_staged_amo_ack[rkey_index], &resp_op->ofi_context); + } while (try_again(transport, &status, &num_retries, ep->domain_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to write atomic ack.\n"); ep->submitted_ops++; @@ -354,9 +421,11 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op if (received_op->op > NVSHMEMI_AMO_END_OF_NONFETCH) { do { - resp_op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + resp_op = + (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[op->ep->domain_index] + ->getNextSend(); status = resp_op == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, op->ep->domain_index)); num_retries = 0; NVSHMEMI_NULL_ERROR_JMP(resp_op, status, NVSHMEMX_ERROR_INTERNAL, out, @@ -368,8 +437,9 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op do { status = fi_send(op->ep->endpoint, (void *)resp_op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), op->src_addr, &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries)); + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), op->src_addr, + &resp_op->ofi_context); + } while (try_again(transport, &status, &num_retries, op->ep->domain_index)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to respond to atomic request.\n"); op->ep->submitted_ops++; @@ -427,27 +497,44 @@ int nvshmemt_libfabric_gdr_process_ack(nvshmem_transport_t transport, return 0; } -int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport) { +int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *op; int status = 0; + int end_iter; - op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextRecv(); - while (op) { - if (op->type == NVSHMEMT_LIBFABRIC_SEND) { - status = nvshmemt_libfabric_gdr_process_amo(transport, op); - } else { - status = nvshmemt_libfabric_gdr_process_ack(transport, op); - } + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else if (qp_index == NVSHMEMX_QP_ALL) { + qp_index = 0; + end_iter = libfabric_state->eps.size(); + } else { + end_iter = libfabric_state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to process atomic.\n"); - status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), FI_ADDR_UNSPEC, &op->ofi_context); - if (status) { - NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to re-post recv.\n"); + for (int i = qp_index; i < end_iter; i++) { + op = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[i]->getNextRecv(); + while (op) { + if (op->type == NVSHMEMT_LIBFABRIC_SEND) { + status = nvshmemt_libfabric_gdr_process_amo(transport, op); + } else { + status = nvshmemt_libfabric_gdr_process_ack(transport, op); + } + + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), FI_ADDR_UNSPEC, + &op->ofi_context); + if (status) { + NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to re-post recv.\n"); + } + op = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[i]->getNextRecv(); } - op = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextRecv(); } + out: return status; } @@ -536,7 +623,8 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, "Error in put_signal_completion gdrcopy signaling operation.\n"); ep->proxy_put_signal_comp_map->erase(iter); status = fi_recv(ep->endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), FI_ADDR_UNSPEC, &op->ofi_context); + fi_mr_desc(libfabric_state->mr[ep->domain_index]), FI_ADDR_UNSPEC, + &op->ofi_context); } out: @@ -544,46 +632,66 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, } static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; - nvshmemt_libfabric_endpoint_t *ep; - int is_proxy = qp_index != NVSHMEMX_QP_HOST; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; + uint64_t completed; + bool all_nics_quieted; int status = 0; + int end_iter; - if (is_proxy) { - ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; } else { - ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; + end_iter = state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; } - if (likely(libfabric_state->prov_info->domain_attr->control_progress == FI_PROGRESS_MANUAL) || - (libfabric_state->prov_info->domain_attr->data_progress == FI_PROGRESS_MANUAL) || - (use_staged_atomics == true) -#ifdef NVSHMEM_USE_GDRCOPY - || (use_gdrcopy == true) -#endif - ) { - uint64_t submitted, completed; - for (;;) { - completed = fi_cntr_read(ep->counter); - submitted = ep->submitted_ops; - if (completed + ep->completed_staged_atomics == submitted) - break; - else { - if (nvshmemt_libfabric_progress(tcurr)) { + if (use_staged_atomics) { + if (use_auto_progress) { + for (;;) { + all_nics_quieted = true; + for (int i = qp_index; i < end_iter; i++) { + completed = fi_cntr_read(state->eps[i]->counter) + + state->eps[i]->completed_staged_atomics; + if (state->eps[i]->submitted_ops != completed) { + all_nics_quieted = false; + if (nvshmemt_libfabric_auto_progress(tcurr, qp_index)) { + status = NVSHMEMX_ERROR_INTERNAL; + break; + } + } + } + + if (status || all_nics_quieted) break; + } + } else { + for (;;) { + all_nics_quieted = true; + for (int i = qp_index; i < end_iter; i++) { + completed = fi_cntr_read(state->eps[i]->counter) + + state->eps[i]->completed_staged_atomics; + if (state->eps[i]->submitted_ops != completed) all_nics_quieted = false; + } + if (all_nics_quieted) break; + + /* FI_PROGRESS_MANUAL requires progress on every endpoint */ + if (nvshmemt_libfabric_manual_progress(tcurr)) { status = NVSHMEMX_ERROR_INTERNAL; break; } } } } else { - status = fi_cntr_wait(ep->counter, ep->submitted_ops, NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS); - if (status) { - /* note - Status is negative for this function in error cases but - * fi_strerror only accepts positive values. - */ - NVSHMEMI_ERROR_PRINT("Error in quiet operation (%d): %s.\n", status, - fi_strerror(status * -1)); - status = NVSHMEMX_ERROR_INTERNAL; + for (int i = qp_index; i < end_iter; i++) { + status = fi_cntr_wait(state->eps[i]->counter, state->eps[i]->submitted_ops, + NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS); + if (status) { + /* note - Status is negative for this function in error cases but + * fi_strerror only accepts positive values. + */ + NVSHMEMI_ERROR_PRINT("Error in quiet operation (%d): %s.\n", status, + fi_strerror(status * -1)); + status = NVSHMEMX_ERROR_INTERNAL; + } } } @@ -604,47 +712,43 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int is_proxy, - uint32_t *imm_data) { + rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data, + nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle; nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; struct iovec p_op_l_iov; struct fi_msg_rma p_op_msg; struct fi_rma_iov p_op_r_iov; - nvshmemt_libfabric_endpoint_t *ep; size_t op_size; uint64_t num_retries = 0; int status = 0; int target_ep; - int ep_idx = 0; void *context = NULL; memset(&p_op_l_iov, 0, sizeof(struct iovec)); memset(&p_op_msg, 0, sizeof(struct fi_msg_rma)); memset(&p_op_r_iov, 0, sizeof(struct fi_rma_iov)); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } + /* put_signal passes in EP to ensure that both operations go through same EP */ + if (!ep) ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { nvshmemt_libfabric_gdr_op_ctx_t *gdr_ctx; do { - gdr_ctx = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + gdr_ctx = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[ep->domain_index] + ->getNextSend(); status = gdr_ctx == NULL ? -EAGAIN : 0; - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); NVSHMEMI_NULL_ERROR_JMP(gdr_ctx, status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to get context buffer for put request.\n"); context = &gdr_ctx->ofi_context; } - remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[ep_idx]; - local_handle = &((nvshmemt_libfabric_mem_handle_t *)local->handle)->hdls[ep_idx]; + remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[ep->domain_index]; + local_handle = &((nvshmemt_libfabric_mem_handle_t *)local->handle)->hdls[ep->domain_index]; + op_size = bytesdesc.elembytes * bytesdesc.nelems; if (verb.desc == NVSHMEMI_OP_P) { @@ -656,9 +760,9 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, do { p_buf->p_op.value = *(uint64_t *)local->ptr; status = fi_write(ep->endpoint, &p_buf->p_op.value, op_size, - fi_mr_desc(libfabric_state->mr), target_ep, + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, (uintptr_t)remote->ptr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); } else { p_op_msg.msg_iov = &p_op_l_iov; p_op_msg.desc = NULL; // Local buffer is on the stack @@ -670,7 +774,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, p_op_l_iov.iov_base = local->ptr; p_op_l_iov.iov_len = op_size; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & + FI_MR_VIRT_ADDR) p_op_r_iov.addr = (uintptr_t)remote->ptr; else p_op_r_iov.addr = (uintptr_t)remote->offset; @@ -682,11 +787,11 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, */ do { status = fi_writemsg(ep->endpoint, &p_op_msg, FI_INJECT); - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); } } else if (verb.desc == NVSHMEMI_OP_PUT) { uintptr_t remote_addr; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) remote_addr = (uintptr_t)remote->ptr; else remote_addr = (uintptr_t)remote->offset; @@ -699,12 +804,12 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, else status = fi_write(ep->endpoint, local->ptr, op_size, local_handle->local_desc, target_ep, remote_addr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); } else if (verb.desc == NVSHMEMI_OP_G || verb.desc == NVSHMEMI_OP_GET) { assert( !imm_data); // Write w/ imm not suppored with NVSHMEMI_OP_G/GET on Libfabric transport uintptr_t remote_addr; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) remote_addr = (uintptr_t)remote->ptr; else remote_addr = (uintptr_t)remote->offset; @@ -712,7 +817,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, do { status = fi_read(ep->endpoint, local->ptr, op_size, local_handle->local_desc, target_ep, remote_addr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, "Invalid RMA operation specified.\n"); @@ -731,33 +836,29 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int is_proxy) { - return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, is_proxy, NULL); + rma_bytesdesc_t bytesdesc, int qp_index) { + return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL, + NULL); } static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, - amo_bytesdesc_t bytesdesc, int is_proxy) { + amo_bytesdesc_t bytesdesc, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_endpoint_t *ep; nvshmemt_libfabric_gdr_op_ctx_t *amo; uint64_t num_retries = 0; - int target_ep, ep_idx; + int target_ep; int status = 0; - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; do { - amo = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + amo = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[ep->domain_index] + ->getNextSend(); status = amo == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { NVSHMEMI_ERROR_PRINT("Unable to retrieve AMO operation."); @@ -777,8 +878,9 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p num_retries = 0; do { status = fi_send(ep->endpoint, (void *)amo, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), target_ep, &amo->ofi_context); - } while (try_again(transport, &status, &num_retries)); + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, + &amo->ofi_context); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { NVSHMEMI_ERROR_PRINT("Received an error when trying to post an AMO operation.\n"); @@ -793,7 +895,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, amo_bytesdesc_t bytesdesc, - int is_proxy) { + int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_mem_handle_ep_t *remote_handle = NULL, *local_handle = NULL; nvshmemt_libfabric_endpoint_t *ep; @@ -807,7 +909,6 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v uint64_t num_retries = 0; int target_ep; int status = 0; - int ep_idx; memset(&amo_msg, 0, sizeof(struct fi_msg_atomic)); memset(&fi_local_iov, 0, sizeof(struct fi_ioc)); @@ -815,19 +916,14 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v memset(&fi_ret_iov, 0, sizeof(struct fi_ioc)); memset(&fi_remote_iov, 0, sizeof(struct fi_rma_ioc)); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; remote_handle = - &((nvshmemt_libfabric_mem_handle_t *)remote->remote_memdesc.handle)->hdls[ep_idx]; + &((nvshmemt_libfabric_mem_handle_t *)remote->remote_memdesc.handle)->hdls[ep->domain_index]; if (verb.desc > NVSHMEMI_AMO_END_OF_NONFETCH) { - local_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->ret_handle)->hdls[ep_idx]; + local_handle = + &((nvshmemt_libfabric_mem_handle_t *)remote->ret_handle)->hdls[ep->domain_index]; } if (bytesdesc.elembytes == 8) { @@ -894,7 +990,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v amo_msg.addr = target_ep; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) fi_remote_iov.addr = (uintptr_t)remote->remote_memdesc.ptr; else fi_remote_iov.addr = (uintptr_t)remote->remote_memdesc.offset; @@ -929,7 +1025,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v status = fi_fetch_atomicmsg(ep->endpoint, &amo_msg, &fi_ret_iov, &local_handle->local_desc, 1, FI_INJECT); } - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) goto out; // Status set by try_again @@ -945,31 +1041,26 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, - amo_bytesdesc_t bytesdesc, int is_proxy, - uint32_t sequence_count, uint16_t num_writes) { + amo_bytesdesc_t bytesdesc, int qp_index, + uint32_t sequence_count, uint16_t num_writes, + nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_endpoint_t *ep; + nvshmemt_libfabric_gdr_op_ctx_t *context; nvshmemt_libfabric_gdr_signal_op_t *signal; uint64_t num_retries = 0; - int target_ep, ep_idx; + int target_ep; int status = 0; - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; static_assert(sizeof(nvshmemt_libfabric_gdr_op_ctx) >= sizeof(nvshmemt_libfabric_gdr_signal_op_t)); do { - context = (nvshmemt_libfabric_gdr_op_ctx_t *)nvshmemtLibfabricOpQueue.getNextSend(); + context = (nvshmemt_libfabric_gdr_op_ctx_t *)libfabric_state->op_queue[ep->domain_index] + ->getNextSend(); status = context == NULL ? -EAGAIN : 0; - } while (try_again(transport, &status, &num_retries)); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { NVSHMEMI_ERROR_PRINT("Unable to retrieve signal operation buffer."); @@ -987,8 +1078,9 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in num_retries = 0; do { status = fi_send(ep->endpoint, (void *)signal, sizeof(nvshmemt_libfabric_gdr_signal_op_t), - fi_mr_desc(libfabric_state->mr), target_ep, &context->ofi_context); - } while (try_again(transport, &status, &num_retries)); + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, + &context->ofi_context); + } while (try_again(transport, &status, &num_retries, qp_index)); if (status) { NVSHMEMI_ERROR_PRINT("Received an error when trying to post a signal operation.\n"); @@ -1006,31 +1098,23 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v std::vector &write_local, std::vector &write_bytes_desc, amo_verb_t sig_verb, amo_memdesc_t *sig_target, - amo_bytesdesc_t sig_bytes_desc, int is_proxy) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + amo_bytesdesc_t sig_bytes_desc, int qp_index) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; int status; uint32_t sequence_count = 0; - int ep_idx; - - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - - nvshmemt_libfabric_endpoint_t &ep = libfabric_state->eps[ep_idx]; + nvshmemt_libfabric_endpoint_t *ep = nvshmemt_libfabric_get_next_ep(state, qp_index); /* Get sequence number for this put-signal, with retry */ uint64_t num_retries = 0; do { - int32_t seq_num = ep.put_signal_seq_counter.next_seq_num(); + int32_t seq_num = ep->put_signal_seq_counter.next_seq_num(); if (seq_num < 0) { status = -EAGAIN; } else { sequence_count = seq_num; status = 0; } - } while (try_again(tcurr, &status, &num_retries)); + } while (try_again(tcurr, &status, &num_retries, qp_index)); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT("Error in nvshmemt_put_signal_unordered while waiting for category\n"); @@ -1042,7 +1126,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v for (size_t i = 0; i < write_remote.size(); i++) { status = nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i], - write_bytes_desc[i], is_proxy, &sequence_count); + write_bytes_desc[i], qp_index, &sequence_count, ep); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT( "Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i); @@ -1051,8 +1135,9 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v } assert(use_staged_atomics == true); - status = nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, - is_proxy, sequence_count, (uint16_t)write_remote.size()); + status = + nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, + qp_index, sequence_count, (uint16_t)write_remote.size(), ep); out: if (status) { NVSHMEMI_ERROR_PRINT( @@ -1063,76 +1148,12 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v return status; } -static int nvshmemt_libfabric_enforce_cst(struct nvshmem_transport *tcurr) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; - uint64_t num_retries = 0; - int status; - int target_ep; - int mype = tcurr->my_pe; - -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - if (libfabric_state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { - int temp; - nvshmemt_libfabric_memhandle_info_t *mem_handle_info; - - mem_handle_info = - (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx( - libfabric_state->cache, 0); - if (!mem_handle_info) { - goto skip; - } - gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr, - sizeof(int)); - } - } - -skip: -#endif - - target_ep = mype * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - do { - struct fi_msg_rma msg; - struct iovec l_iov; - struct fi_rma_iov r_iov; - void *desc = libfabric_state->local_mr_desc[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - uint64_t flags = 0; - - memset(&msg, 0, sizeof(struct fi_msg_rma)); - memset(&l_iov, 0, sizeof(struct iovec)); - memset(&r_iov, 0, sizeof(struct fi_rma_iov)); - - l_iov.iov_base = libfabric_state->local_mem_ptr; - l_iov.iov_len = 8; - - r_iov.addr = 0; // Zero offset - r_iov.len = 8; - r_iov.key = libfabric_state->local_mr_key[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - - msg.msg_iov = &l_iov; - msg.desc = &desc; - msg.iov_count = 1; - msg.rma_iov = &r_iov; - msg.rma_iov_count = 1; - msg.context = NULL; - msg.data = 0; - - if (libfabric_state->prov_info->caps & FI_FENCE) flags |= FI_FENCE; - - status = - fi_readmsg(libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX].endpoint, &msg, flags); - } while (try_again(tcurr, &status, &num_retries)); - - libfabric_state->eps[target_ep].submitted_ops++; - return status; -} - static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handle, nvshmem_transport_t t) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)t->state; nvshmemt_libfabric_mem_handle_t *fabric_handle; void *curr_ptr; - int max_reg, status = 0; + int status = 0; assert(mem_handle != NULL); fabric_handle = (nvshmemt_libfabric_mem_handle_t *)mem_handle; @@ -1162,18 +1183,10 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl } } - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_ENDPOINT) - max_reg = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - else - max_reg = 1; - - for (int i = 0; i < max_reg; i++) { - if (libfabric_state->local_mr[i] == fabric_handle->hdls[i].mr) - libfabric_state->local_mr[i] = NULL; - + for (size_t i = 0; i < libfabric_state->domains.size(); i++) { int status = fi_close(&fabric_handle->hdls[i].mr->fid); if (status) { - NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %d (%d): %s\n", i, status, + NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %zu (%d): %s\n", i, status, fi_strerror(status * -1)); } } @@ -1182,6 +1195,7 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl return status; } +static_assert(sizeof(nvshmemt_libfabric_mem_handle_t) < sizeof(nvshmem_mem_handle_t)); static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, void *buf, size_t length, nvshmem_transport_t t, bool local_only) { @@ -1217,6 +1231,7 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v assert(mem_handle != NULL); fabric_handle = (nvshmemt_libfabric_mem_handle_t *)mem_handle; + fabric_handle->buf = buf; status = cudaPointerGetAttributes(&attr, buf); if (status != cudaSuccess) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, @@ -1247,40 +1262,15 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v mr_attr.iface = FI_HMEM_SYSTEM; } - fabric_handle->buf = buf; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_ENDPOINT) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = - fi_mr_regattr(libfabric_state->domain, &mr_attr, 0, &fabric_handle->hdls[i].mr); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Error registering memory region: %s\n", - fi_strerror(status * -1)); - - status = - fi_mr_bind(fabric_handle->hdls[i].mr, &libfabric_state->eps[i].endpoint->fid, 0); - - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Error binding MR to EP %d: %s\n", i, fi_strerror(status * -1)); - - status = fi_mr_enable(fabric_handle->hdls[i].mr); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error enabling MR: %s\n", - fi_strerror(status * -1)); - - fabric_handle->hdls[i].key = fi_mr_key(fabric_handle->hdls[i].mr); - fabric_handle->hdls[i].local_desc = fi_mr_desc(fabric_handle->hdls[i].mr); - } - } else { + for (size_t i = 0; i < libfabric_state->domains.size(); i++) { struct fid_mr *mr; - - status = fi_mr_regattr(libfabric_state->domain, &mr_attr, 0, &mr); + status = fi_mr_regattr(libfabric_state->domains[i], &mr_attr, 0, &mr); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error registering memory region: %s\n", fi_strerror(status * -1)); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - fabric_handle->hdls[i].mr = mr; - fabric_handle->hdls[i].key = fi_mr_key(mr); - fabric_handle->hdls[i].local_desc = fi_mr_desc(mr); - } + fabric_handle->hdls[i].mr = mr; + fabric_handle->hdls[i].key = fi_mr_key(mr); + fabric_handle->hdls[i].local_desc = fi_mr_desc(mr); } if (!local_only && libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { @@ -1350,15 +1340,6 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v } while (curr_ptr < (char *)buf + length); } - if (libfabric_state->local_mr[0] == NULL && !local_only) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - libfabric_state->local_mr[i] = fabric_handle->hdls[i].mr; - libfabric_state->local_mr_key[i] = fabric_handle->hdls[i].key; - libfabric_state->local_mr_desc[i] = fabric_handle->hdls[i].local_desc; - } - libfabric_state->local_mem_ptr = buf; - } - out: if (status) { if (handle_info) { @@ -1429,173 +1410,186 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)t->state; nvshmemt_libfabric_ep_name_t *all_ep_names = NULL; nvshmemt_libfabric_ep_name_t *local_ep_names = NULL; - struct fi_info *current_fabric; + struct fi_info *current_info; + struct fid_fabric *fabric; + struct fid_domain *domain; + struct fid_av *address; + struct fid_mr *mr; struct fi_av_attr av_attr; struct fi_cq_attr cq_attr; struct fi_cntr_attr cntr_attr; size_t ep_namelen = NVSHMEMT_LIBFABRIC_EP_LEN; int status = 0; int total_num_eps; - size_t num_recvs_per_pe = 0; + size_t num_recvs_per_ep = 0; int n_pes = t->n_pes; - - if (state->eps) { - NVSHMEMI_WARN_PRINT( - "Device already selected. libfabric only supports one NIC per PE and doesn't support " - "additional QPs.\n"); - goto out_already_connected; + size_t num_sends; + size_t num_recvs; + size_t elem_size; + uint64_t flags; + state->num_selected_devs = MIN(num_selected_devs, state->max_nic_per_pe); + + if (state->eps.size()) { + NVSHMEMI_ERROR_PRINT("PE has previously called connect_endpoints()\n"); + return NVSHMEMX_ERROR_INTERNAL; } - state->eps = (nvshmemt_libfabric_endpoint_t *)calloc(NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS, - sizeof(nvshmemt_libfabric_endpoint_t)); - NVSHMEMI_NULL_ERROR_JMP(state->eps, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate EPs."); - - current_fabric = state->all_prov_info; - do { - if (!strncmp(current_fabric->nic->device_attr->name, - state->domain_names[selected_dev_ids[0]].name, - NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { - break; - } - current_fabric = current_fabric->next; - } while (current_fabric != NULL); - NVSHMEMI_NULL_ERROR_JMP(current_fabric, status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to find the selected fabric.\n"); - - state->prov_info = fi_dupinfo(current_fabric); - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA && - strcmp(state->prov_info->fabric_attr->name, "efa-direct")) + if (state->num_selected_devs > NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE) { + state->num_selected_devs = NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE; NVSHMEMI_WARN_PRINT( - "Libfabric transport is using efa fabric instead of efa-direct, " - "use libfabric v2.1.0 or newer for improved performance\n"); - - status = fi_fabric(state->prov_info->fabric_attr, &state->fabric, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate fabric: %d: %s\n", status, fi_strerror(status * -1)); - - status = fi_domain(state->fabric, state->prov_info, &state->domain, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate domain: %d: %s\n", status, fi_strerror(status * -1)); - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - state->num_sends = current_fabric->tx_attr->size * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - state->num_recvs = current_fabric->rx_attr->size * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - size_t elem_size = sizeof(nvshmemt_libfabric_gdr_op_ctx_t); - - num_recvs_per_pe = state->num_recvs / NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - - state->recv_buf = calloc(state->num_sends + state->num_recvs, elem_size); - NVSHMEMI_NULL_ERROR_JMP(state->recv_buf, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate EFA msg buffer.\n"); - state->send_buf = (char *)state->recv_buf + (elem_size * state->num_recvs); - - status = fi_mr_reg(state->domain, state->recv_buf, - (state->num_sends + state->num_recvs) * elem_size, FI_SEND | FI_RECV, 0, - 0, 0, &state->mr, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to register EFA msg buffer: %d: %s\n", status, - fi_strerror(status * -1)); - - nvshmemtLibfabricOpQueue.putToSendBulk((char *)state->send_buf, elem_size, - state->num_sends); - } - - t->max_op_len = state->prov_info->ep_attr->max_msg_size; - av_attr.type = FI_AV_TABLE; - av_attr.rx_ctx_bits = 0; - av_attr.count = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * n_pes; - av_attr.ep_per_node = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - av_attr.name = NULL; - av_attr.map_addr = NULL; - av_attr.flags = 0; - - /* Note - This is needed because EFA will only bind AVs to EPs on a 1:1 basis. - * If EFA ever lifts this requirement, we can reduce the number of AVs required. - */ - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_av_open(state->domain, &av_attr, &state->addresses[i], NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate address vector: %d: %s\n", status, - fi_strerror(status * -1)); + "PE selected %d devices, but the libfabric transport only supports a max of %d " + "devices. Continuing using %d devices.\n", + state->num_selected_devs, NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE, + NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE); } + state->num_selected_domains = state->num_selected_devs + 1; - INFO(state->log_level, "Selected provider %s, fabric %s, nic %s, hmem %s", - state->prov_info->fabric_attr->prov_name, state->prov_info->fabric_attr->name, - state->prov_info->nic->device_attr->name, state->prov_info->caps & FI_HMEM ? "yes" : "no"); - - assert(state->eps); + /* Initialize configuration which only need to be set once */ + t->max_op_len = UINT64_MAX; /* Set as sential value */ + state->cur_proxy_ep_index = 1; memset(&cq_attr, 0, sizeof(struct fi_cq_attr)); - memset(&cntr_attr, 0, sizeof(struct fi_cntr_attr)); - - state->prov_info->ep_attr->tx_ctx_cnt = 0; - state->prov_info->caps = FI_RMA; - if ((state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) || - (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS)) { - state->prov_info->caps |= FI_ATOMIC; - } else { - state->prov_info->caps |= FI_MSG; - state->prov_info->caps |= FI_SOURCE; - } - state->prov_info->tx_attr->op_flags = 0; - state->prov_info->tx_attr->mode = 0; - state->prov_info->rx_attr->mode = 0; - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - state->prov_info->mode = FI_CONTEXT2; - } else { - state->prov_info->mode = 0; - } - - state->prov_info->tx_attr->op_flags = FI_DELIVERY_COMPLETE; - - cntr_attr.events = FI_CNTR_EVENTS_COMP; - cntr_attr.wait_obj = FI_WAIT_UNSPEC; - cntr_attr.wait_set = NULL; - cntr_attr.flags = 0; - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { cq_attr.size = 16; /* CQ is only used to capture error events */ cq_attr.format = FI_CQ_FORMAT_UNSPEC; cq_attr.wait_obj = FI_WAIT_NONE; - } - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { cq_attr.format = FI_CQ_FORMAT_DATA; cq_attr.wait_obj = FI_WAIT_NONE; cq_attr.size = 32768; } - local_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS, + memset(&av_attr, 0, sizeof(struct fi_av_attr)); + av_attr.type = FI_AV_TABLE; + av_attr.count = state->num_selected_domains * n_pes; + + memset(&cntr_attr, 0, sizeof(struct fi_cntr_attr)); + cntr_attr.events = FI_CNTR_EVENTS_COMP; + cntr_attr.wait_obj = FI_WAIT_UNSPEC; + + /* Find fabric info for each selected device */ + for (int dev_idx = 0; dev_idx < state->num_selected_devs; dev_idx++) { + current_info = state->all_prov_info; + do { + if (!strncmp(current_info->nic->device_attr->name, + state->domain_names[selected_dev_ids[dev_idx]].name, + NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { + break; + } + current_info = current_info->next; + } while (current_info != NULL); + NVSHMEMI_NULL_ERROR_JMP(current_info, status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to find fabric for device %d.\n", dev_idx); + + /* + * Create two domains (host/proxy domain) for the first NIC. + */ + if (state->prov_infos.size() == 0) state->prov_infos.push_back(current_info); + + state->prov_infos.push_back(current_info); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA && + strcmp(current_info->fabric_attr->name, "efa-direct")) + NVSHMEMI_WARN_PRINT( + "Libfabric transport is using efa fabric instead of efa-direct, " + "use libfabric v2.1.0 or newer for improved performance\n"); + } + + /* Allocate out of band AV name exchange buffers */ + local_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(state->num_selected_domains, sizeof(nvshmemt_libfabric_ep_name_t)); NVSHMEMI_NULL_ERROR_JMP(local_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); - total_num_eps = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * n_pes; + total_num_eps = n_pes * state->num_selected_domains; all_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(total_num_eps, sizeof(nvshmemt_libfabric_ep_name_t)); NVSHMEMI_NULL_ERROR_JMP(all_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_endpoint(state->domain, state->prov_info, &state->eps[i].endpoint, NULL); + /* Create Resources For Each Selected Device */ + for (size_t i = 0; i < state->prov_infos.size(); i++) { + INFO(state->log_level, + "Selected provider %s, fabric %s, nic %s, hmem %s multi-rail %zu/%d\n", + state->prov_infos[i]->fabric_attr->prov_name, state->prov_infos[i]->fabric_attr->name, + state->prov_infos[i]->nic->device_attr->name, + state->prov_infos[i]->caps & FI_HMEM ? "yes" : "no", i + 1, num_selected_devs); + + if (state->prov_infos[i]->ep_attr->max_msg_size < t->max_op_len) + t->max_op_len = state->prov_infos[i]->ep_attr->max_msg_size; + + status = fi_fabric(state->prov_infos[i]->fabric_attr, &fabric, NULL); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to allocate endpoint: %d: %s\n", status, + "Failed to allocate fabric: %d: %s\n", status, + fi_strerror(status * -1)); + state->fabrics.push_back(fabric); + + status = fi_domain(fabric, state->prov_infos[i], &domain, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to allocate domain: %d: %s\n", status, + fi_strerror(status * -1)); + state->domains.push_back(domain); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + num_sends = state->prov_infos[i]->tx_attr->size; + num_recvs = state->prov_infos[i]->rx_attr->size; + elem_size = sizeof(nvshmemt_libfabric_gdr_op_ctx_t); + num_recvs_per_ep = num_recvs; + + state->recv_buf.push_back(calloc(num_sends + num_recvs, elem_size)); + NVSHMEMI_NULL_ERROR_JMP(state->recv_buf[i], status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate EFA msg buffer.\n"); + state->send_buf.push_back((char *)state->recv_buf[i] + (elem_size * num_recvs)); + + status = fi_mr_reg(domain, state->recv_buf[i], (num_sends + num_recvs) * elem_size, + FI_SEND | FI_RECV | FI_WRITE, 0, 0, 0, &mr, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to register EFA msg buffer: %d: %s\n", status, + fi_strerror(status * -1)); + state->mr.push_back(mr); + + state->op_queue.push_back(new threadSafeOpQueue); + state->op_queue[i]->putToSendBulk((char *)state->send_buf[i], elem_size, num_sends); + } + + status = fi_av_open(domain, &av_attr, &address, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to allocate address vector: %d: %s\n", status, fi_strerror(status * -1)); + state->addresses.push_back(address); + + /* Create nvshmemt_libfabric_endpoint_t resources */ + state->eps.push_back( + (nvshmemt_libfabric_endpoint_t *)calloc(1, sizeof(nvshmemt_libfabric_endpoint_t))); + NVSHMEMI_NULL_ERROR_JMP(state->eps[i], status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to alloc libfabric_tx_progress_group struct.\n"); + state->eps[i]->domain_index = i; /* Initialize per-endpoint proxy_put_signal_comp_map */ - state->eps[i].proxy_put_signal_comp_map = + state->eps[i]->proxy_put_signal_comp_map = new std::unordered_map>(); - state->eps[i].put_signal_seq_counter.reset(); - state->eps[i].completed_staged_atomics = 0; + state->eps[i]->put_signal_seq_counter.reset(); + state->eps[i]->completed_staged_atomics = 0; + + status = fi_cq_open(domain, &cq_attr, &state->eps[i]->cq, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to open completion queue for endpoint: %d: %s\n", status, + fi_strerror(status * -1)); + + status = fi_cntr_open(domain, &cntr_attr, &state->eps[i]->counter, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to open counter for endpoint: %d: %s\n", status, + fi_strerror(status * -1)); + status = fi_endpoint(domain, state->prov_infos[i], &state->eps[i]->endpoint, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to allocate endpoint: %d: %s\n", status, + fi_strerror(status * -1)); /* FI_OPT_CUDA_API_PERMITTED was introduced in libfabric 1.18.0 */ if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { bool prohibit_cuda_api = false; - status = fi_setopt(&state->eps[i].endpoint->fid, FI_OPT_ENDPOINT, + status = fi_setopt(&state->eps[i]->endpoint->fid, FI_OPT_ENDPOINT, FI_OPT_CUDA_API_PERMITTED, &prohibit_cuda_api, sizeof(bool)); if (status == -FI_ENOPROTOOPT) { NVSHMEMI_WARN_PRINT( @@ -1609,112 +1603,90 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele } } - status = fi_cq_open(state->domain, &cq_attr, &state->eps[i].cq, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open completion queue for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - - status = fi_cntr_open(state->domain, &cntr_attr, &state->eps[i].counter, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open counter for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - - status = fi_ep_bind(state->eps[i].endpoint, &state->addresses[i]->fid, 0); + /* Bind Resources To EP */ + status = fi_ep_bind(state->eps[i]->endpoint, &state->addresses[i]->fid, 0); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to bind endpoint to address vector: %d: %s\n", status, fi_strerror(status * -1)); - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, - FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV); - } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, - FI_SELECTIVE_COMPLETION | FI_TRANSMIT); - } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) + flags = FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV; + else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) + flags = FI_SELECTIVE_COMPLETION | FI_TRANSMIT; + else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) /* EFA is documented as not supporting FI_SELECTIVE_COMPLETION */ - status = - fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, FI_TRANSMIT | FI_RECV); - } else { + flags = FI_TRANSMIT | FI_RECV; + else { NVSHMEMI_ERROR_PRINT( "Invalid provider identified. This should be impossible. " "Possible memory corruption in the state pointer?"); status = NVSHMEMX_ERROR_INTERNAL; goto out; } + + status = fi_ep_bind(state->eps[i]->endpoint, &state->eps[i]->cq->fid, flags); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to bind endpoint to completion queue: %d: %s\n", status, fi_strerror(status * -1)); -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].counter->fid, - FI_READ | FI_WRITE | FI_SEND); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - } else -#endif - { - int flags = FI_READ | FI_WRITE; - if (use_staged_atomics) { - flags |= FI_SEND; - } - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].counter->fid, flags); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - } + flags = FI_READ | FI_WRITE; + if (use_staged_atomics) flags |= FI_SEND; + status = fi_ep_bind(state->eps[i]->endpoint, &state->eps[i]->counter->fid, flags); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to bind endpoint to completion counter: %d: %s\n", status, + fi_strerror(status * -1)); - status = fi_enable(state->eps[i].endpoint); + status = fi_enable(state->eps[i]->endpoint); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to enable endpoint: %d: %s\n", status, fi_strerror(status * -1)); if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - for (size_t j = 0; j < num_recvs_per_pe; j++) { - nvshmemt_libfabric_gdr_op_ctx_t *op; - op = (nvshmemt_libfabric_gdr_op_ctx_t *)state->recv_buf; - op = op + ((num_recvs_per_pe * i) + j); + nvshmemt_libfabric_gdr_op_ctx_t *op; + op = (nvshmemt_libfabric_gdr_op_ctx_t *)state->recv_buf[i]; + for (size_t j = 0; j < num_recvs_per_ep; j++, op++) { assert(op != NULL); - status = fi_recv(state->eps[i].endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(state->mr), FI_ADDR_UNSPEC, &op->ofi_context); + status = fi_recv(state->eps[i]->endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(state->mr[i]), FI_ADDR_UNSPEC, &op->ofi_context); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to post recv to ep. Error: %d: %s\n", status, + fi_strerror(status * -1)); } - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to post recv to ep. Error: %d: %s\n", status, - fi_strerror(status * -1)); } - status = fi_getname(&state->eps[i].endpoint->fid, local_ep_names[i].name, &ep_namelen); + status = fi_getname(&state->eps[i]->endpoint->fid, local_ep_names[i].name, &ep_namelen); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to get name for endpoint: %d: %s\n", status, fi_strerror(status * -1)); - if (ep_namelen > NVSHMEMT_LIBFABRIC_EP_LEN) { + if (ep_namelen > NVSHMEMT_LIBFABRIC_EP_LEN) NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Name of EP is too long."); - } } + /* Perform out of band address exchange */ status = t->boot_handle->allgather( local_ep_names, all_ep_names, - NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * sizeof(nvshmemt_libfabric_ep_name_t), t->boot_handle); + state->num_selected_domains * sizeof(nvshmemt_libfabric_ep_name_t), t->boot_handle); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather endpoint names.\n"); /* We need to insert one at a time since each buffer is larger than the address. */ - for (int j = 0; j < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; j++) { - for (int i = 0; i < total_num_eps; i++) { - status = fi_av_insert(state->addresses[j], &all_ep_names[i], 1, NULL, 0, NULL); + for (int i = 0; i < state->num_selected_domains; i++) { + for (int j = 0; j < total_num_eps; j++) { + status = fi_av_insert(state->addresses[i], &all_ep_names[j], 1, NULL, 0, NULL); if (status < 1) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to insert ep names in address vector: %d: %s\n", status, fi_strerror(status * -1)); } - status = NVSHMEMX_SUCCESS; } } + /* Out of bounds exchange a pre-registered write w/imm target for staged_amo acks */ if (use_staged_atomics) { state->remote_addr_staged_amo_ack = (void **)calloc(sizeof(void *), t->n_pes); + state->rkey_staged_amo_ack = + (uint64_t *)calloc(sizeof(uint64_t), t->n_pes * state->num_selected_domains); NVSHMEMI_NULL_ERROR_JMP(state->remote_addr_staged_amo_ack, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate remote address array for staged atomic ack.\n"); @@ -1723,13 +1695,15 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to allocate CUDA memory for staged atomic ack.\n"); - status = fi_mr_reg(state->domain, state->remote_addr_staged_amo_ack[t->my_pe], sizeof(int), - FI_REMOTE_WRITE, 0, 0, 0, &state->mr_staged_amo_ack, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to register EFA msg buffer: %d: %s\n", status, - fi_strerror(status * -1)); - state->rkey_staged_amo_ack = (uint64_t *)calloc(sizeof(uint64_t), t->n_pes); - state->rkey_staged_amo_ack[t->my_pe] = fi_mr_key(state->mr_staged_amo_ack); + for (size_t i = 0; i < state->domains.size(); i++) { + status = fi_mr_reg(state->domains[i], state->remote_addr_staged_amo_ack[t->my_pe], + sizeof(int), FI_REMOTE_WRITE, 0, 0, 0, &mr, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to register EFA msg buffer: %d: %s\n", status, + fi_strerror(status * -1)); + state->rkey_staged_amo_ack[t->my_pe * state->num_selected_domains + i] = fi_mr_key(mr); + state->mr_staged_amo_ack.push_back(mr); + } status = t->boot_handle->allgather(&state->remote_addr_staged_amo_ack[t->my_pe], state->remote_addr_staged_amo_ack, sizeof(void *), @@ -1737,9 +1711,10 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather remote addresses.\n"); - status = - t->boot_handle->allgather(&state->rkey_staged_amo_ack[t->my_pe], - state->rkey_staged_amo_ack, sizeof(uint64_t), t->boot_handle); + status = t->boot_handle->allgather( + &state->rkey_staged_amo_ack[t->my_pe * state->num_selected_domains], + state->rkey_staged_amo_ack, sizeof(uint64_t) * state->num_selected_domains, + t->boot_handle); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather remote keys.\n"); } @@ -1752,30 +1727,28 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele free(state->remote_addr_staged_amo_ack); } if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); - if (state->mr_staged_amo_ack) fi_close(&state->mr_staged_amo_ack->fid); - if (state->eps) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - if (state->eps[i].proxy_put_signal_comp_map) - delete state->eps[i].proxy_put_signal_comp_map; - if (state->eps[i].endpoint) { - fi_close(&state->eps[i].endpoint->fid); - state->eps[i].endpoint = NULL; - } - if (state->eps[i].cq) { - fi_close(&state->eps[i].cq->fid); - state->eps[i].cq = NULL; - } - if (state->eps[i].counter) { - fi_close(&state->eps[i].counter->fid); - state->eps[i].counter = NULL; - } + for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) + fi_close(&state->mr_staged_amo_ack[i]->fid); + for (size_t i = 0; i < state->eps.size(); i++) { + if (state->eps[i]->proxy_put_signal_comp_map) + delete state->eps[i]->proxy_put_signal_comp_map; + if (state->eps[i]->endpoint) { + fi_close(&state->eps[i]->endpoint->fid); + state->eps[i]->endpoint = NULL; + } + if (state->eps[i]->cq) { + fi_close(&state->eps[i]->cq->fid); + state->eps[i]->cq = NULL; } - free(state->eps); - state->eps = NULL; + if (state->eps[i]->counter) { + fi_close(&state->eps[i]->counter->fid); + state->eps[i]->counter = NULL; + } + free(state->eps[i]); + state->eps[i] = NULL; } } -out_already_connected: free(local_ep_names); free(all_ep_names); @@ -1783,12 +1756,12 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele } static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state; + nvshmemt_libfabric_state_t *state; int status; assert(transport); - libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; + state = (nvshmemt_libfabric_state_t *)transport->state; if (transport->device_pci_paths) { for (int i = 0; i < transport->n_devices; i++) { @@ -1800,19 +1773,19 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { size_t mem_handle_cache_size; nvshmemt_libfabric_memhandle_info_t *handle_info = NULL, *previous_handle_info = NULL; - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - mem_handle_cache_size = nvshmemt_mem_handle_cache_get_size(libfabric_state->cache); + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + mem_handle_cache_size = nvshmemt_mem_handle_cache_get_size(state->cache); for (size_t i = 0; i < mem_handle_cache_size; i++) { handle_info = (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx( - libfabric_state->cache, i); + state->cache, i); if (handle_info && handle_info != previous_handle_info) { free(handle_info); } previous_handle_info = handle_info; } - nvshmemt_mem_handle_cache_fini(libfabric_state->cache); + nvshmemt_mem_handle_cache_fini(state->cache); #ifdef NVSHMEM_USE_GDRCOPY if (use_gdrcopy) { nvshmemt_gdrcopy_ftable_fini(&gdrcopy_ftable, &gdr_desc, &gdrcopy_handle); @@ -1820,95 +1793,96 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { #endif } - if (libfabric_state->prov_info) { - fi_freeinfo(libfabric_state->prov_info); - } + /* + * Since fi_dupinfo() is not called, we don't need to clean + * we do not need to clean prov_infos + */ + if (state->all_prov_info) fi_freeinfo(state->all_prov_info); - if (libfabric_state->eps) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - if (libfabric_state->eps[i].proxy_put_signal_comp_map) - delete libfabric_state->eps[i].proxy_put_signal_comp_map; - if (libfabric_state->eps[i].endpoint) { - status = fi_close(&libfabric_state->eps[i].endpoint->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric endpoint.: %d: %s\n", status, - fi_strerror(status * -1)); - } + for (size_t i = 0; i < state->eps.size(); i++) { + if (state->eps[i]->proxy_put_signal_comp_map) + delete state->eps[i]->proxy_put_signal_comp_map; + if (state->eps[i]->endpoint) { + status = fi_close(&state->eps[i]->endpoint->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric endpoint.: %d: %s\n", status, + fi_strerror(status * -1)); } - if (libfabric_state->eps[i].cq) { - status = fi_close(&libfabric_state->eps[i].cq->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric cq: %d: %s\n", status, - fi_strerror(status * -1)); - } + } + if (state->eps[i]->cq) { + status = fi_close(&state->eps[i]->cq->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric cq: %d: %s\n", status, + fi_strerror(status * -1)); } - if (libfabric_state->eps[i].counter) { - status = fi_close(&libfabric_state->eps[i].counter->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric counter: %d: %s\n", status, - fi_strerror(status * -1)); - } + } + if (state->eps[i]->counter) { + status = fi_close(&state->eps[i]->counter->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric counter: %d: %s\n", status, + fi_strerror(status * -1)); } } - free(libfabric_state->eps); + free(state->eps[i]); } - if (libfabric_state->remote_addr_staged_amo_ack) { - if (libfabric_state->remote_addr_staged_amo_ack[transport->my_pe]) - cudaFree(libfabric_state->remote_addr_staged_amo_ack[transport->my_pe]); - free(libfabric_state->remote_addr_staged_amo_ack); + if (state->remote_addr_staged_amo_ack) { + if (state->remote_addr_staged_amo_ack[transport->my_pe]) + cudaFree(state->remote_addr_staged_amo_ack[transport->my_pe]); + free(state->remote_addr_staged_amo_ack); } - if (libfabric_state->rkey_staged_amo_ack) free(libfabric_state->rkey_staged_amo_ack); - if (libfabric_state->mr_staged_amo_ack) { - status = fi_close(&libfabric_state->mr_staged_amo_ack->fid); + if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); + for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) { + status = fi_close(&state->mr_staged_amo_ack[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close staged atomic ack MR: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->mr) { - status = fi_close(&libfabric_state->mr->fid); + for (size_t i = 0; i < state->mr.size(); i++) { + status = fi_close(&state->mr[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric MR: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->recv_buf) free(libfabric_state->recv_buf); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_close(&libfabric_state->addresses[i]->fid); + for (size_t i = 0; i < state->recv_buf.size(); i++) free(state->recv_buf[i]); + + for (size_t i = 0; i < state->addresses.size(); i++) { + status = fi_close(&state->addresses[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric address vector: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->domain) { - status = fi_close(&libfabric_state->domain->fid); + for (size_t i = 0; i < state->domains.size(); i++) { + status = fi_close(&state->domains[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric domain: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->fabric) { - status = fi_close(&libfabric_state->fabric->fid); + for (size_t i = 0; i < state->fabrics.size(); i++) { + status = fi_close(&state->fabrics[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric: %d: %s\n", status, fi_strerror(status * -1)); } } - free(libfabric_state); - + free(state); free(transport); return 0; } -static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state) { - struct fi_info info; +static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state, + struct nvshmemi_options_s *options) { + struct fi_info hints; struct fi_tx_attr tx_attr; struct fi_rx_attr rx_attr; struct fi_ep_attr ep_attr; @@ -1916,65 +1890,88 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr struct fi_fabric_attr fabric_attr; struct fid_nic nic; struct fi_av_attr av_attr; - struct fi_info *returned_fabrics, *current_fabric; + struct fi_info *all_infos, *current_info; int num_fabrics_returned = 0; int status = 0; memset(&ep_attr, 0, sizeof(struct fi_ep_attr)); memset(&av_attr, 0, sizeof(struct fi_av_attr)); - memset(&info, 0, sizeof(struct fi_info)); + memset(&hints, 0, sizeof(struct fi_info)); memset(&tx_attr, 0, sizeof(struct fi_tx_attr)); memset(&rx_attr, 0, sizeof(struct fi_rx_attr)); memset(&domain_attr, 0, sizeof(struct fi_domain_attr)); memset(&fabric_attr, 0, sizeof(struct fi_fabric_attr)); memset(&nic, 0, sizeof(struct fid_nic)); - info.tx_attr = &tx_attr; - info.rx_attr = &rx_attr; - info.ep_attr = &ep_attr; - info.domain_attr = &domain_attr; - info.fabric_attr = &fabric_attr; - info.nic = &nic; + hints.tx_attr = &tx_attr; + hints.rx_attr = &rx_attr; + hints.ep_attr = &ep_attr; + hints.domain_attr = &domain_attr; + hints.fabric_attr = &fabric_attr; + hints.nic = &nic; - info.addr_format = FI_FORMAT_UNSPEC; - info.caps = FI_RMA | FI_HMEM; + hints.addr_format = FI_FORMAT_UNSPEC; + hints.caps = FI_RMA | FI_HMEM; if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) { - info.caps |= FI_ATOMIC; + hints.caps |= FI_ATOMIC; domain_attr.mr_mode = FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY; } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { /* TODO: Use FI_FENCE to optimize put_with_signal */ - info.caps |= FI_FENCE | FI_ATOMIC; + hints.caps |= FI_FENCE | FI_ATOMIC; domain_attr.mr_mode = FI_MR_ENDPOINT | FI_MR_ALLOCATED | FI_MR_PROV_KEY; } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { domain_attr.mr_mode = FI_MR_LOCAL | FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_HMEM; - info.caps |= FI_MSG; - info.caps |= FI_SOURCE; + hints.caps |= FI_MSG; + hints.caps |= FI_SOURCE; } if (use_staged_atomics) { - info.mode |= FI_CONTEXT2; + hints.mode |= FI_CONTEXT2; } - /* Be thread safe at the level of the endpoint completion context. */ - domain_attr.threading = FI_THREAD_SAFE; - + ep_attr.type = FI_EP_RDM; /* Reliable datagrams */ /* Require completion RMA completion at target for correctness of quiet */ - info.tx_attr->op_flags = FI_DELIVERY_COMPLETE; + hints.tx_attr->op_flags = FI_DELIVERY_COMPLETE; - ep_attr.type = FI_EP_RDM; // Reliable datagrams + /* nvshmemt_libfabric_auto_progress relaxes threading requirement */ + domain_attr.threading = FI_THREAD_COMPLETION; + hints.domain_attr->data_progress = FI_PROGRESS_AUTO; status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), NULL, - NULL, 0, &info, &returned_fabrics); + NULL, 0, &hints, &all_infos); + + /* + * 1. Ensure that at least one fabric was returned + * 2. Make sure returned fabric matches the name of selected provider + * + * This has an assumption that the provided fabric option + * options.LIBFABRIC_PROVIDER will be a substr of the returned fabric + * name + */ + if (!status && strstr(all_infos->fabric_attr->name, options->LIBFABRIC_PROVIDER)) { + use_auto_progress = true; + } else { + fi_freeinfo(all_infos); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "No providers matched fi_getinfo query: %d: %s\n", status, - fi_strerror(status * -1)); - state->all_prov_info = returned_fabrics; - for (current_fabric = returned_fabrics; current_fabric != NULL; - current_fabric = current_fabric->next) { + /* + * Fallback to FI_PROGRESS_MANUAL path + * nvshmemt_libfabric_slow_progress requires FI_THREAD_SAFE + */ + domain_attr.threading = FI_THREAD_SAFE; + hints.domain_attr->data_progress = FI_PROGRESS_MANUAL; + status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), + NULL, NULL, 0, &hints, &all_infos); + + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "No providers matched fi_getinfo query: %d: %s\n", status, + fi_strerror(status * -1)); + } + + state->all_prov_info = all_infos; + for (current_info = all_infos; current_info != NULL; current_info = current_info->next) { num_fabrics_returned++; } @@ -1985,53 +1982,51 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr /* Only select unique devices. */ state->num_domains = 0; - for (current_fabric = returned_fabrics; current_fabric != NULL; - current_fabric = current_fabric->next) { - if (!current_fabric->nic) { + for (current_info = all_infos; current_info != NULL; current_info = current_info->next) { + if (!current_info->nic) { INFO(state->log_level, "Interface did not return NIC structure to fi_getinfo. Skipping.\n"); continue; } - if (!current_fabric->tx_attr) { + if (!current_info->tx_attr) { INFO(state->log_level, "Interface did not return TX_ATTR structure to fi_getinfo. Skipping.\n"); continue; } TRACE(state->log_level, "fi_getinfo returned provider %s, fabric %s, nic %s", - current_fabric->fabric_attr->prov_name, current_fabric->fabric_attr->name, - current_fabric->nic->device_attr->name); + current_info->fabric_attr->prov_name, current_info->fabric_attr->name, + current_info->nic->device_attr->name); if (state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - if (current_fabric->tx_attr->inject_size < NVSHMEMT_LIBFABRIC_INJECT_BYTES) { + if (current_info->tx_attr->inject_size < NVSHMEMT_LIBFABRIC_INJECT_BYTES) { INFO(state->log_level, "Disabling interface due to insufficient inject data size. reported %lu, " "expected " "%u", - current_fabric->tx_attr->inject_size, NVSHMEMT_LIBFABRIC_INJECT_BYTES); + current_info->tx_attr->inject_size, NVSHMEMT_LIBFABRIC_INJECT_BYTES); continue; } } - if ((current_fabric->domain_attr->mr_mode & FI_MR_PROV_KEY) == 0) { + if ((current_info->domain_attr->mr_mode & FI_MR_PROV_KEY) == 0) { INFO(state->log_level, "Disabling interface due to FI_MR_PROV_KEY support"); continue; } for (int i = 0; i <= state->num_domains; i++) { - if (!strncmp(current_fabric->nic->device_attr->name, state->domain_names[i].name, + if (!strncmp(current_info->nic->device_attr->name, state->domain_names[i].name, NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { break; } else if (i == state->num_domains) { - size_t name_len = strlen(current_fabric->nic->device_attr->name); + size_t name_len = strlen(current_info->nic->device_attr->name); if (name_len >= NVSHMEMT_LIBFABRIC_DOMAIN_LEN) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to copy domain name for libfabric transport."); } (void)strncpy(state->domain_names[state->num_domains].name, - current_fabric->nic->device_attr->name, - NVSHMEMT_LIBFABRIC_DOMAIN_LEN); + current_info->nic->device_attr->name, NVSHMEMT_LIBFABRIC_DOMAIN_LEN); state->num_domains++; break; } @@ -2053,8 +2048,6 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr nvshmemt_libfabric_finalize(t); } - free(info.fabric_attr->name); - return status; } @@ -2091,9 +2084,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.quiet = nvshmemt_libfabric_quiet; transport->host_ops.finalize = nvshmemt_libfabric_finalize; transport->host_ops.show_info = nvshmemt_libfabric_show_info; - transport->host_ops.progress = nvshmemt_libfabric_progress; - transport->host_ops.enforce_cst = nvshmemt_libfabric_enforce_cst; - transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; transport->is_successfully_initialized = true; @@ -2106,6 +2096,7 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, "Unable to initialize env options."); libfabric_state->log_level = nvshmemt_common_get_log_level(&options); + libfabric_state->max_nic_per_pe = options.LIBFABRIC_MAX_NIC_PER_PE; if (strcmp(options.LIBFABRIC_PROVIDER, "verbs") == 0) { libfabric_state->provider = NVSHMEMT_LIBFABRIC_PROVIDER_VERBS; @@ -2209,12 +2200,17 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, #undef NVSHMEMI_SET_ENV_VAR /* Prepare fabric state information. */ - status = nvshmemi_libfabric_init_state(transport, libfabric_state); + status = nvshmemi_libfabric_init_state(transport, libfabric_state, &options); if (status) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out_clean, "Failed to initialize the libfabric state.\n"); } + if (use_auto_progress) + transport->host_ops.progress = nvshmemt_libfabric_auto_proxy_progress; + else + transport->host_ops.progress = nvshmemt_libfabric_manual_progress; + *t = transport; out: if (status) { diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index f2b5931..c37729e 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -30,12 +30,9 @@ #define NVSHMEMT_LIBFABRIC_DOMAIN_LEN 32 #define NVSHMEMT_LIBFABRIC_PROVIDER_LEN 32 #define NVSHMEMT_LIBFABRIC_EP_LEN 128 - -/* one EP for all proxy ops, one for host ops */ -#define NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS 2 +/* Constrainted by memhandle size */ +#define NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE 16 #define NVSHMEMT_LIBFABRIC_PROXY_EP_IDX 1 -#define NVSHMEMT_LIBFABRIC_HOST_EP_IDX 0 - #define NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS 20 /* Maximum size of inject data. Currently @@ -184,6 +181,7 @@ typedef struct { nvshmemt_libfabric_endpoint_seq_counter_t put_signal_seq_counter; std::unordered_map> *proxy_put_signal_comp_map; + int domain_index; } nvshmemt_libfabric_endpoint_t; typedef enum { @@ -211,6 +209,10 @@ class threadSafeOpQueue { std::deque recv; public: + threadSafeOpQueue() = default; + threadSafeOpQueue(const threadSafeOpQueue &) = delete; + threadSafeOpQueue &operator=(const threadSafeOpQueue &) = delete; + void *getNextSend() { void *elem; send_mutex.lock(); @@ -272,31 +274,34 @@ class threadSafeOpQueue { }; typedef struct { - struct fi_info *prov_info; struct fi_info *all_prov_info; - struct fid_fabric *fabric; - struct fid_domain *domain; - struct fid_av *addresses[NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS]; - nvshmemt_libfabric_endpoint_t *eps; - /* local_mr is used only for consistency ops. */ - struct fid_mr *local_mr[2]; - uint64_t local_mr_key[2]; - void *local_mr_desc[2]; - void *local_mem_ptr; + std::vector prov_infos; + std::vector fabrics; + std::vector domains; + std::vector addresses; + std::vector eps; + nvshmemt_libfabric_domain_name_t *domain_names; int num_domains; nvshmemt_libfabric_provider provider; int log_level; struct nvshmemi_cuda_fn_table *table; - size_t num_sends; - void *send_buf; - size_t num_recvs; - void *recv_buf; - struct fid_mr *mr; struct transport_mem_handle_info_cache *cache; + + /* Required for multi-rail */ + int max_nic_per_pe; + int num_selected_devs; + int num_selected_domains; + int cur_proxy_ep_index; + + /* Required for staged_amo */ + std::vector op_queue; + std::vector mr; + std::vector send_buf; + std::vector recv_buf; + std::vector mr_staged_amo_ack; void **remote_addr_staged_amo_ack; uint64_t *rkey_staged_amo_ack; - struct fid_mr *mr_staged_amo_ack; } nvshmemt_libfabric_state_t; typedef enum { @@ -323,7 +328,7 @@ typedef struct { typedef struct { void *buf; - nvshmemt_libfabric_mem_handle_ep_t hdls[2]; + nvshmemt_libfabric_mem_handle_ep_t hdls[1 + NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE]; } nvshmemt_libfabric_mem_handle_t; typedef struct nvshmemt_libfabric_gdr_send_p_op { diff --git a/src/modules/transport/ucx/ucx.cpp b/src/modules/transport/ucx/ucx.cpp index 271ed69..4959d0b 100644 --- a/src/modules/transport/ucx/ucx.cpp +++ b/src/modules/transport/ucx/ucx.cpp @@ -1180,67 +1180,6 @@ int nvshmemt_ucx_finalize(nvshmem_transport_t transport) { return 0; } -int nvshmemt_ucx_enforce_cst_at_target(struct nvshmem_transport *tcurr) { - transport_ucx_state_t *ucx_state = (transport_ucx_state_t *)tcurr->state; - nvshmemt_ucx_mem_handle_info_t *mem_handle_info; - - mem_handle_info = - (nvshmemt_ucx_mem_handle_info_t *)nvshmemt_mem_handle_cache_get_by_idx(ucx_state->cache, 0); - - if (!mem_handle_info) return 0; -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - int temp; - gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr, - sizeof(int)); - return 0; - } -#endif - int mype = tcurr->my_pe; - int ep_index = (ucx_state->ep_count * mype + ucx_state->proxy_ep_idx); - ucp_ep_h ep = ucx_state->endpoints[ep_index]; - ucp_request_param_t param; - ucs_status_ptr_t ucs_ptr_rc = NULL; - ucs_status_t ucs_rc; - nvshmemt_ucx_mem_handle_t *mem_handle; - ucp_rkey_h rkey; - int local_int; - - mem_handle = mem_handle_info->mem_handle; - if (unlikely(mem_handle->ep_rkey_host == NULL)) { - ucs_rc = ucp_ep_rkey_unpack(ep, mem_handle->rkey_packed_buf, &mem_handle->ep_rkey_host); - if (ucs_rc != UCS_OK) { - NVSHMEMI_ERROR_EXIT("Unable to unpack rkey in UCS transport! Exiting.\n"); - } - } - rkey = mem_handle->ep_rkey_host; - - param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK; - param.cb.send = nvshmemt_ucx_send_request_cb; - - ucs_ptr_rc = - ucp_get_nbx(ep, &local_int, sizeof(int), (uint64_t)mem_handle_info->ptr, rkey, ¶m); - - /* Wait for completion of get. */ - if (ucs_ptr_rc != NULL) { - if (UCS_PTR_IS_ERR(ucs_ptr_rc)) { - NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } else { - do { - ucs_rc = ucp_request_check_status(ucs_ptr_rc); - ucp_worker_progress(ucx_state->worker_context); - } while (ucs_rc == UCS_INPROGRESS); - if (ucs_rc != UCS_OK) { - NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } - } - } - - return 0; -} - int nvshmemt_ucx_show_info(struct nvshmem_transport *transport, int style) { NVSHMEMI_ERROR_PRINT("UCX show info not implemented"); return 0; @@ -1446,7 +1385,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_ucx_finalize; transport->host_ops.show_info = nvshmemt_ucx_show_info; transport->host_ops.progress = nvshmemt_ucx_progress; - transport->host_ops.enforce_cst = nvshmemt_ucx_enforce_cst_at_target; transport->host_ops.enforce_cst_at_target = NULL; transport->host_ops.put_signal = nvshmemt_put_signal; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED;