Skip to content

Commit

Permalink
refactor(api): replace addr+len pairs with MRs
Browse files Browse the repository at this point in the history
Immediately on any top-level NCCL call, construct an immutable
nccl_ofi_mr_input_t on the stack, then pass that to communicator regmr
implementations. Add a flags argument to internal regmr functions such
that the input can be inspected and may add FI_MR_DMABUF if the input
arguments correspond to a file descriptor.

Implement top-level nccl_net_ofi_regMr in terms of
nccl_net_ofi_regMrDmaBuf, simply forwarding arguments alongside an
invalid file descriptor (-1) and a zero offset.

DMA-BUF remains unsupported as of this commit, but only due to not
advertising support back to NCCL/nccom.

Signed-off-by: Nicholas Sielicki <nslick@amazon.com>
  • Loading branch information
aws-nslick committed Sep 22, 2024
1 parent 448b311 commit a0c1dfe
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 88 deletions.
8 changes: 4 additions & 4 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ struct nccl_net_ofi_send_comm {
* @return 0 on success
* non-zero on error
*/
int (*regMr)(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int type,
void **mhandle);
int (*regMr)(nccl_net_ofi_send_comm_t *send_comm, const nccl_ofi_mr_input_t *in, int type,
void **mhandle);

/*
* @brief Deregister memory region on send communicator (both Host and CUDA)
Expand Down Expand Up @@ -429,8 +429,8 @@ struct nccl_net_ofi_recv_comm {
* @return 0 on success
* non-zero on error
*/
int (*regMr)(nccl_net_ofi_recv_comm_t *recv_comm, void *data, size_t size, int type,
void **mhandle);
int (*regMr)(nccl_net_ofi_recv_comm_t *send_comm, const nccl_ofi_mr_input_t *in, int type,
void **mhandle);

/*
* @brief Deregister memory region on recv communicator (both Host and CUDA)
Expand Down
47 changes: 36 additions & 11 deletions src/nccl_ofi_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,20 @@ ncclResult_t nccl_net_ofi_regMr_v7(void *comm, void *data, int size, int type,
ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, size_t size, int type,
void **mhandle)
{
int ret = 0;
return nccl_net_ofi_regMrDmaBuf(comm,
data,
size,
type,
0, /* default value, no offset. */
-1, /* default value, invalid file descriptor. */
mhandle);
}

ncclResult_t nccl_net_ofi_regMrDmaBuf(void* comm, void* data, size_t size,
int type, uint64_t offset,
int fd, void** mhandle)
{
int ret;
/* Retrieve and validate comm */
nccl_net_ofi_comm_t *base_comm =
(nccl_net_ofi_comm_t *)comm;
Expand All @@ -318,16 +330,37 @@ ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, size_t size, int type,
return ncclInternalError;
}

const nccl_ofi_mr_input_t mr_input_dmabuf = {
#if HAVE_DECL_FI_MR_DMABUF
.type = NCCL_OFI_MR_INPUT_DMABUF,
.fi_mr_dmabuf.fd = fd,
.fi_mr_dmabuf.offset = offset,
.fi_mr_dmabuf.len = size,
.fi_mr_dmabuf.base_addr = data
#endif
};

const nccl_ofi_mr_input_t mr_input_iovec = {
.type = NCCL_OFI_MR_INPUT_IOVEC,
.iovec.iov_base = data,
.iovec.iov_len = size
};

const nccl_ofi_mr_input_t *mr_input =
(fd == -1)
? &mr_input_iovec
: &mr_input_dmabuf;

switch (base_comm->type) {
case NCCL_NET_OFI_SEND_COMM:;
nccl_net_ofi_send_comm_t *send_comm =
(nccl_net_ofi_send_comm_t *)base_comm;
ret = send_comm->regMr(send_comm, data, size, type, mhandle);
ret = send_comm->regMr(send_comm, mr_input, type, mhandle);
break;
case NCCL_NET_OFI_RECV_COMM:;
nccl_net_ofi_recv_comm_t *recv_comm =
(nccl_net_ofi_recv_comm_t *)base_comm;
ret = recv_comm->regMr(recv_comm, data, size, type, mhandle);
ret = recv_comm->regMr(recv_comm, mr_input, type, mhandle);
break;
default:
NCCL_OFI_WARN("Unexpected communicator type. Communicator type: %d",
Expand Down Expand Up @@ -373,14 +406,6 @@ ncclResult_t nccl_net_ofi_deregMr(void *comm, void *mhandle)
}


ncclResult_t nccl_net_ofi_regMrDmaBuf(void* comm, void* data, size_t size,
int type, uint64_t offset,
int fd, void** mhandle)
{
return nccl_net_ofi_retval_translate(-ENOTSUP);
}


/*
* @brief Non-blocking accept which returns rComm as NULL
* with an expectation that it will be called again until
Expand Down
79 changes: 40 additions & 39 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -385,24 +385,28 @@ static int write_topo_file(nccl_ofi_topo_t *topo)
* non-zero on error
*/
static int set_mr_req_attr(nccl_ofi_idpool_t *key_pool, int dev_id,
void *data, size_t size, int type,
struct fi_mr_attr *mr_attr, struct iovec *iov)
const nccl_ofi_mr_input_t *in, uint64_t *flags,
int type, struct fi_mr_attr *mr_attr)
{
int ret = 0;

/* Populate IOV vector for memory registration */
iov->iov_base = data;
iov->iov_len = size;

/* Initialize MR attributes */
mr_attr->mr_iov = iov;
mr_attr->iov_count = 1;
mr_attr->access = FI_SEND | FI_RECV;

/* Add FI_WRITE (source of fi_write) and FI_REMOTE_WRITE (target of fi_write)
for RDMA send/recv buffers */
mr_attr->access |= (FI_WRITE | FI_REMOTE_WRITE);

if (in->type == NCCL_OFI_MR_INPUT_IOVEC) {
mr_attr->mr_iov = (struct iovec*)in;
mr_attr->iov_count = 1;
}
#if HAVE_DECL_FI_MR_DMABUF
else if (in->type == NCCL_OFI_MR_INPUT_DMABUF) {
*flags |= FI_MR_DMABUF;
mr_attr->dmabuf = (struct fi_mr_dmabuf*)in;
mr_attr->iov_count = 1;
}
#endif

switch (type) {
case NCCL_PTR_HOST:
mr_attr->access |= FI_READ;
Expand All @@ -414,7 +418,9 @@ static int set_mr_req_attr(nccl_ofi_idpool_t *key_pool, int dev_id,
mr_attr->iface = FI_HMEM_CUDA;

/* Get CUDA device ID */
ret = nccl_net_ofi_get_cuda_device_for_addr(data, &mr_attr->device.cuda);
ret = nccl_net_ofi_get_cuda_device_for_addr(
(void*)nccl_ofi_mr_input_base(in),
&mr_attr->device.cuda);
if (OFI_UNLIKELY(ret != 0)) {
goto exit;
}
Expand Down Expand Up @@ -455,11 +461,11 @@ static int set_mr_req_attr(nccl_ofi_idpool_t *key_pool, int dev_id,
static int register_rail_mr_buffer(struct fid_domain *domain,
struct fid_ep *ep, int dev_id,
int type, struct fi_mr_attr *mr_attr,
struct fid_mr **mr_handle)
uint64_t flags, struct fid_mr **mr_handle)
{
int ret = 0;

ret = fi_mr_regattr(domain, mr_attr, 0, mr_handle);
ret = fi_mr_regattr(domain, mr_attr, flags, mr_handle);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Unable to register memory (type = %d) for device %d. RC: %d, Error: %s",
type, dev_id, ret, fi_strerror(-ret));
Expand Down Expand Up @@ -2590,18 +2596,16 @@ static int dereg_mr_ep(nccl_net_ofi_rdma_mr_handle_t *mr_handle,
}

static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep,
void *data,
size_t size,
const nccl_ofi_mr_input_t *in,
int type,
nccl_net_ofi_rdma_mr_handle_t **mhandle)
{
int ret = 0;
nccl_net_ofi_rdma_mr_handle_t *ret_handle = NULL;
*mhandle = NULL;

struct iovec iov = {};
struct fid_domain *domain;
struct fi_mr_attr mr_attr = {};
uint64_t regattr_flags = 0;

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device =
Expand All @@ -2621,7 +2625,7 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep,
}

/* Create memory registration request */
ret = set_mr_req_attr(key_pool, dev_id, data, size, type, &mr_attr, &iov);
ret = set_mr_req_attr(key_pool, dev_id, in, &regattr_flags, type, &mr_attr);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Could not set registration request attributes, dev: %d",
dev_id);
Expand All @@ -2631,7 +2635,7 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep,
}

ret = register_rail_mr_buffer(ep->control_rail.domain, ep->control_rail.ofi_ep,
-1, type, &mr_attr,
-1, type, &mr_attr, regattr_flags,
&ret_handle->control_mr);
if (OFI_UNLIKELY(ret != 0)) {
free(ret_handle);
Expand All @@ -2646,7 +2650,7 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep,
domain = get_domain_from_endpoint(ep, rail_id);

ret = register_rail_mr_buffer(domain, rail->ofi_ep,
dev_id, type, &mr_attr,
dev_id, type, &mr_attr, regattr_flags,
&ret_handle->mr[rail_id]);
if (OFI_UNLIKELY(ret != 0)) {
if (dereg_mr_ep(ret_handle, key_pool, NULL) != 0) {
Expand Down Expand Up @@ -2678,8 +2682,7 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep,
* @return Memory registration handle
*/
static int reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep,
void *data,
size_t size,
const nccl_ofi_mr_input_t *in,
int type,
nccl_ofi_mr_cache_t *mr_cache,
nccl_net_ofi_rdma_mr_handle_t **mhandle)
Expand All @@ -2695,19 +2698,14 @@ static int reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep,
assert(device != NULL);

nccl_ofi_idpool_t *key_pool = &device->key_pool;

// XXX: removed in next commit.
nccl_ofi_mr_input_t mr_input = {};
if (mr_cache) {
// XXX: removed in next commit.
nccl_ofi_mr_input_fill(data, size, &mr_input);
/*
* MR cache is locked between lookup and insert, to be sure we
* insert a missing entry
*/
nccl_net_ofi_mutex_lock(&mr_cache->lock);
ret_handle = (nccl_net_ofi_rdma_mr_handle_t *)
nccl_ofi_mr_cache_lookup_entry(mr_cache, &mr_input);
nccl_ofi_mr_cache_lookup_entry(mr_cache, in);

if (ret_handle) {
/* Cache hit */
Expand All @@ -2716,14 +2714,14 @@ static int reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep,
/* Cache miss */
}

ret = reg_mr_on_device(ep, data, size, type, &ret_handle);
ret = reg_mr_on_device(ep, in, type, &ret_handle);
if (OFI_UNLIKELY(ret != 0)) {
goto exit;
}

if (mr_cache) {
ret = nccl_ofi_mr_cache_insert_entry(mr_cache,
&mr_input,
in,
ret_handle);
if (OFI_UNLIKELY(ret != 0)) {
if (dereg_mr_ep(ret_handle, key_pool, NULL) != 0) {
Expand Down Expand Up @@ -2805,34 +2803,37 @@ static int reg_internal_mr_ep(nccl_net_ofi_rdma_ep_t *ep, void *data,
assert(NCCL_OFI_IS_PTR_ALIGNED(data, system_page_size));
assert(NCCL_OFI_IS_ALIGNED(size, system_page_size));

return reg_mr_ep(ep, data, size, type, NULL, mhandle);
nccl_ofi_mr_input_t in = {};
nccl_ofi_mr_input_fill(data, size, &in);

return reg_mr_ep(ep, &in, type, NULL, mhandle);
}

static int reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, void *data,
size_t size, int type, void **mhandle)
static int reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm,
const nccl_ofi_mr_input_t *in,
int type, void **mhandle)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)send_comm->base.ep;
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
assert(device != NULL);

return reg_mr_ep(ep,
data,
size,
in,
type,
device->base.mr_cache,
(nccl_net_ofi_rdma_mr_handle_t **)mhandle);
}

static int reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm, void *data,
size_t size, int type, void **mhandle)
static int reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm,
const nccl_ofi_mr_input_t *in,
int type, void **mhandle)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)recv_comm->base.ep;
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
assert(device != NULL);

return reg_mr_ep(ep,
data,
size,
in,
type,
device->base.mr_cache,
(nccl_net_ofi_rdma_mr_handle_t **)mhandle);
Expand Down
Loading

0 comments on commit a0c1dfe

Please sign in to comment.