Skip to content

Commit

Permalink
Add Multiplexed-round-robin scheduler
Browse files Browse the repository at this point in the history
This commit modifies the scheduler algorithm to round-robin the
payload_msg on each QP, or between pairs, triplets, or quadruplets
of QPs, based on the min_stripe_size.
  • Loading branch information
arunkarthik-akkart committed Sep 20, 2024
1 parent 159bfed commit a9cc48c
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 307 deletions.
7 changes: 4 additions & 3 deletions include/nccl_ofi_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,10 @@ OFI_NCCL_PARAM_INT(disable_native_rdma_check, "DISABLE_NATIVE_RDMA_CHECK", 0);
OFI_NCCL_PARAM_INT(disable_gdr_required_check, "DISABLE_GDR_REQUIRED_CHECK", 0);

/*
* Maximum size of a message in bytes before message is multiplexed
*/
OFI_NCCL_PARAM_INT(round_robin_threshold, "ROUND_ROBIN_THRESHOLD", (256 * 1024));
* Minimum data size in bytes before the message is to be striped across multiple
* rails.
*/
OFI_NCCL_PARAM_INT(min_stripe_size, "MIN_STRIPE_SIZE", (210 * 1024));

/*
* Minimum bounce buffers posted per endpoint. The plugin will attempt to post
Expand Down
17 changes: 9 additions & 8 deletions include/nccl_ofi_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ typedef struct nccl_net_ofi_threshold_scheduler {
unsigned int rr_counter;
/* Lock for round robin counter */
pthread_mutex_t rr_lock;
/* Maximum size of a message in bytes before message is
/* Minimum size of the message in bytes before message is
* multiplexed */
size_t rr_threshold;
size_t min_stripe_size;
} nccl_net_ofi_threshold_scheduler_t;

/*
Expand All @@ -109,16 +109,15 @@ void nccl_net_ofi_release_schedule(nccl_net_ofi_scheduler_t *scheduler,
*
* @param num_rails
* Number of rails
* @param rr_threshold
* Maximum size of a message in bytes before message is multiplexed
*
* @param min_stripe_size
* Minimum size of a message in bytes before message is multiplexed
* @return Scheduler, on success
* NULL, on error
* @return 0, on success
* non-zero, on error
*/
int nccl_net_ofi_threshold_scheduler_init(int num_rails,
size_t rr_threshold,
size_t min_stripe_size,
nccl_net_ofi_scheduler_t **scheduler);

/*
Expand All @@ -127,9 +126,11 @@ int nccl_net_ofi_threshold_scheduler_init(int num_rails,
* A minimal stripe size `max_stripe_size' is calculated (multiple of
* `align') that is sufficient to assign the whole message. Rails are
* filled from low id to large id. The last rail may get assigned less
* data.
* data. The number of rails used is calculated based on the ratio of
* (`data_size` / `min_stripe_size`), capped at `num_rails`.
*/
void nccl_net_ofi_set_multiplexing_schedule(size_t size,
int nccl_net_ofi_set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler,
size_t size,
int num_rails,
size_t align,
nccl_net_ofi_schedule_t *schedule);
Expand Down
8 changes: 4 additions & 4 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -6940,7 +6940,7 @@ nccl_net_ofi_rdma_device_release(nccl_net_ofi_device_t *base_device)
static nccl_net_ofi_rdma_device_t *
nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin,
int dev_id, struct fi_info *info_list,
nccl_ofi_topo_t *topo, size_t rr_threshold)
nccl_ofi_topo_t *topo, size_t min_strip_size)
{
int ret;

Expand Down Expand Up @@ -6983,7 +6983,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin,

/* Create scheduler */
ret = nccl_net_ofi_threshold_scheduler_init(length,
rr_threshold,
min_strip_size,
&device->scheduler);
if (ret != 0) {
goto error;
Expand Down Expand Up @@ -7196,7 +7196,7 @@ static inline int nccl_net_ofi_rdma_plugin_complete_init(nccl_net_ofi_plugin_t *
nccl_net_ofi_rdma_device_t *device =
nccl_net_ofi_rdma_device_create(&rdma_plugin->base, dev_id,
info_list, rdma_plugin->topo,
ofi_nccl_round_robin_threshold());
ofi_nccl_min_stripe_size());
if (device == NULL) {
NCCL_OFI_WARN("Device creation failed");
return -ENOMEM;
Expand Down Expand Up @@ -7318,7 +7318,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
}

if (ofi_nccl_eager_max_size() < 0 ||
ofi_nccl_eager_max_size() > ofi_nccl_round_robin_threshold()) {
ofi_nccl_eager_max_size() > ofi_nccl_min_stripe_size()) {
NCCL_OFI_WARN("Invalid value for EAGER_MAX_SIZE");
ret = ncclInvalidArgument;
goto error;
Expand Down
91 changes: 32 additions & 59 deletions src/nccl_ofi_scheduler.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,80 +22,53 @@ static inline size_t sizeof_schedule(int num_rails)
+ num_rails * sizeof(nccl_net_ofi_xfer_info_t);
}

void nccl_net_ofi_set_multiplexing_schedule(size_t size, int num_rails,
int nccl_net_ofi_set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler,
size_t size,
int num_rails,
size_t align,
nccl_net_ofi_schedule_t *schedule)
{
int ret = 0;

/* Number of stripes is at least 1 for zero-sized messages and at most equal to num of rails */
int num_stripes = (int) NCCL_OFI_MAX(1, NCCL_OFI_MIN(NCCL_OFI_DIV_CEIL(size, scheduler->min_stripe_size), num_rails));
if (OFI_UNLIKELY(num_rails == 0)) {
return -1;
}

assert(num_stripes <= num_rails);

int rail_id;
nccl_net_ofi_mutex_lock(&scheduler->rr_lock);

/* Retrieve and increment multiplex-round-robin counter; wrap around if required */
rail_id = scheduler->rr_counter;
scheduler->rr_counter = (scheduler->rr_counter + num_stripes) % num_rails;

nccl_net_ofi_mutex_unlock(&scheduler->rr_lock);

/* Number of bytes left to assign */
size_t left = size;
/* Offset into message */
size_t offset = 0;
/* Maximum size of a stripe */
size_t max_stripe_size = 0;

schedule->num_xfer_infos = 0;
schedule->num_xfer_infos = num_stripes;

if (OFI_UNLIKELY(num_rails == 0)) return;

max_stripe_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(size, num_rails), align) * align;
max_stripe_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(size, num_stripes), align) * align;

/* Compute stripes and assign to rails */
for (int rail_id = 0; rail_id != num_rails && left > 0; ++rail_id) {
for (int stripe_idx = 0; stripe_idx < num_stripes; ++stripe_idx) {
size_t stripe_size = NCCL_OFI_MIN(left, max_stripe_size);

schedule->rail_xfer_infos[rail_id].rail_id = rail_id;
schedule->rail_xfer_infos[rail_id].offset = offset;
schedule->rail_xfer_infos[rail_id].msg_size = stripe_size;
schedule->rail_xfer_infos[stripe_idx].rail_id = rail_id;
schedule->rail_xfer_infos[stripe_idx].offset = offset;
schedule->rail_xfer_infos[stripe_idx].msg_size = stripe_size;

schedule->num_xfer_infos++;
offset += stripe_size;
left -= stripe_size;
}
}

/*
* @brief Assign message round-robin
*/
static inline int set_round_robin_schedule(nccl_net_ofi_threshold_scheduler_t *scheduler,
size_t size,
int num_rails,
nccl_net_ofi_schedule_t *schedule)
{
int rail_id;

nccl_net_ofi_mutex_lock(&scheduler->rr_lock);

/* Retieve and increment round-robin counter; wrap around if required */
rail_id = (scheduler->rr_counter)++;
scheduler->rr_counter = scheduler->rr_counter == num_rails ? 0 : scheduler->rr_counter;

nccl_net_ofi_mutex_unlock(&scheduler->rr_lock);

schedule->num_xfer_infos = 1;
schedule->rail_xfer_infos[0].rail_id = rail_id;
schedule->rail_xfer_infos[0].offset = 0;
schedule->rail_xfer_infos[0].msg_size = size;

return 0;
}

/*
* @brief Assign message round-robin or multiplex message depending on its size
*
* Messages larger than `threshold' are multiplexed. Smaller messages are assigned round-robin.
*/
static inline int set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler,
size_t size,
int num_rails,
size_t align,
nccl_net_ofi_schedule_t *schedule)
{
int ret = 0;
if (size > scheduler->rr_threshold) {
nccl_net_ofi_set_multiplexing_schedule(size, num_rails,
align, schedule);
} else {
ret = set_round_robin_schedule(scheduler, size, num_rails, schedule);
rail_id = (rail_id + 1) % num_rails;
}
return ret;
}
Expand Down Expand Up @@ -146,7 +119,7 @@ static nccl_net_ofi_schedule_t *get_threshold_schedule(nccl_net_ofi_scheduler_t
NCCL_OFI_WARN("Failed to allocate schedule");
return NULL;
}
ret = set_schedule_by_threshold(scheduler, size, num_rails, align,
ret = nccl_net_ofi_set_schedule_by_threshold(scheduler, size, num_rails, align,
schedule);
if (OFI_UNLIKELY(ret)) {
nccl_net_ofi_release_schedule(scheduler_p, schedule);
Expand Down Expand Up @@ -238,7 +211,7 @@ int scheduler_init(int num_rails, nccl_net_ofi_scheduler_t *scheduler)
}

int nccl_net_ofi_threshold_scheduler_init(int num_rails,
size_t rr_threshold,
size_t min_stripe_size,
nccl_net_ofi_scheduler_t **scheduler_p)
{
int ret = 0;
Expand All @@ -261,7 +234,7 @@ int nccl_net_ofi_threshold_scheduler_init(int num_rails,
scheduler->base.get_schedule = get_threshold_schedule;
scheduler->base.fini = threshold_scheduler_fini;
scheduler->rr_counter = 0;
scheduler->rr_threshold = rr_threshold;
scheduler->min_stripe_size = min_stripe_size;

ret = nccl_net_ofi_mutex_init(&scheduler->rr_lock, NULL);
if (ret) {
Expand Down
Loading

0 comments on commit a9cc48c

Please sign in to comment.