From a9cc48c3e71776732618d6815f7301877e70f73e Mon Sep 17 00:00:00 2001 From: Arun Karthik Date: Tue, 10 Sep 2024 09:56:21 +0000 Subject: [PATCH] Add Multiplexed-round-robin scheduler This commit modifies the scheduler algorithm to round-robin the payload_msg on each QP or betwen pairs of QP's or triplets of QP's or quadruplets of QP's based on the min_stripe_size. --- include/nccl_ofi_param.h | 7 +- include/nccl_ofi_scheduler.h | 17 +- src/nccl_ofi_rdma.c | 8 +- src/nccl_ofi_scheduler.c | 91 ++++------ tests/unit/scheduler.c | 312 +++++++++-------------------------- 5 files changed, 128 insertions(+), 307 deletions(-) diff --git a/include/nccl_ofi_param.h b/include/nccl_ofi_param.h index 7ae619c35..0cc2155ad 100644 --- a/include/nccl_ofi_param.h +++ b/include/nccl_ofi_param.h @@ -179,9 +179,10 @@ OFI_NCCL_PARAM_INT(disable_native_rdma_check, "DISABLE_NATIVE_RDMA_CHECK", 0); OFI_NCCL_PARAM_INT(disable_gdr_required_check, "DISABLE_GDR_REQUIRED_CHECK", 0); /* - * Maximum size of a message in bytes before message is multiplexed - */ -OFI_NCCL_PARAM_INT(round_robin_threshold, "ROUND_ROBIN_THRESHOLD", (256 * 1024)); + * Minimum data size in bytes before the message is to be striped across multiple + * rails. +*/ +OFI_NCCL_PARAM_INT(min_stripe_size, "MIN_STRIPE_SIZE", (210 * 1024)); /* * Minimum bounce buffers posted per endpoint. The plugin will attempt to post diff --git a/include/nccl_ofi_scheduler.h b/include/nccl_ofi_scheduler.h index e7e7a828c..292dd85f9 100644 --- a/include/nccl_ofi_scheduler.h +++ b/include/nccl_ofi_scheduler.h @@ -93,9 +93,9 @@ typedef struct nccl_net_ofi_threshold_scheduler { unsigned int rr_counter; /* Lock for round robin counter */ pthread_mutex_t rr_lock; - /* Maximum size of a message in bytes before message is + /* Minimum size of the message in bytes before message is * multiplexed */ - size_t rr_threshold; + size_t min_stripe_size; } nccl_net_ofi_threshold_scheduler_t; /* @@ -109,16 +109,15 @@ void nccl_net_ofi_release_schedule(nccl_net_ofi_scheduler_t *scheduler, * * @param num_rails * Number of rails - * @param rr_threshold - * Maximum size of a message in bytes before message is multiplexed - * + * @param min_stripe_size + * Minimum size of a message in bytes before message is multiplexed * @return Scheduler, on success * NULL, on error * @return 0, on success * non-zero, on error */ int nccl_net_ofi_threshold_scheduler_init(int num_rails, - size_t rr_threshold, + size_t min_stripe_size, nccl_net_ofi_scheduler_t **scheduler); /* @@ -127,9 +126,11 @@ int nccl_net_ofi_threshold_scheduler_init(int num_rails, * A mininal stripe size `max_stripe_size' is calculated (multiple of * `align') that is sufficient to assign the whole message. Rails are * filled from low id to large id. The last rail may get assigned less - * data. + * data. The number of rails are calculated based on the ratio of + * (`data_size` / `min_stripe_size`) */ -void nccl_net_ofi_set_multiplexing_schedule(size_t size, +int nccl_net_ofi_set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler, + size_t size, int num_rails, size_t align, nccl_net_ofi_schedule_t *schedule); diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index ee61a885e..50b0845e1 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -6940,7 +6940,7 @@ nccl_net_ofi_rdma_device_release(nccl_net_ofi_device_t *base_device) static nccl_net_ofi_rdma_device_t * nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, int dev_id, struct fi_info *info_list, - nccl_ofi_topo_t *topo, size_t rr_threshold) + nccl_ofi_topo_t *topo, size_t min_strip_size) { int ret; @@ -6983,7 +6983,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, /* Create scheduler */ ret = nccl_net_ofi_threshold_scheduler_init(length, - rr_threshold, + min_strip_size, &device->scheduler); if (ret != 0) { goto error; @@ -7196,7 +7196,7 @@ static inline int nccl_net_ofi_rdma_plugin_complete_init(nccl_net_ofi_plugin_t * nccl_net_ofi_rdma_device_t *device = nccl_net_ofi_rdma_device_create(&rdma_plugin->base, dev_id, info_list, rdma_plugin->topo, - ofi_nccl_round_robin_threshold()); + ofi_nccl_min_stripe_size()); if (device == NULL) { NCCL_OFI_WARN("Device creation failed"); return -ENOMEM; @@ -7318,7 +7318,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, } if (ofi_nccl_eager_max_size() < 0 || - ofi_nccl_eager_max_size() > ofi_nccl_round_robin_threshold()) { + ofi_nccl_eager_max_size() > ofi_nccl_min_stripe_size()) { NCCL_OFI_WARN("Invalid value for EAGER_MAX_SIZE"); ret = ncclInvalidArgument; goto error; diff --git a/src/nccl_ofi_scheduler.c b/src/nccl_ofi_scheduler.c index bac026e80..f158e0428 100644 --- a/src/nccl_ofi_scheduler.c +++ b/src/nccl_ofi_scheduler.c @@ -22,10 +22,31 @@ static inline size_t sizeof_schedule(int num_rails) + num_rails * sizeof(nccl_net_ofi_xfer_info_t); } -void nccl_net_ofi_set_multiplexing_schedule(size_t size, int num_rails, +int nccl_net_ofi_set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler, + size_t size, + int num_rails, size_t align, nccl_net_ofi_schedule_t *schedule) { + int ret = 0; + + /* Number of stripes is atleast 1 for zero-sized messages and at most equal to num of rails */ + int num_stripes = (int) NCCL_OFI_MAX(1, NCCL_OFI_MIN(NCCL_OFI_DIV_CEIL(size, scheduler->min_stripe_size), num_rails)); + if (OFI_UNLIKELY(num_rails == 0)) { + return -1; + } + + assert(num_stripes <= num_rails); + + int rail_id; + nccl_net_ofi_mutex_lock(&scheduler->rr_lock); + + /* Retieve and increment multiplex-round-robin counter; wrap around if required */ + rail_id = scheduler->rr_counter; + scheduler->rr_counter = (scheduler->rr_counter + num_stripes) % num_rails; + + nccl_net_ofi_mutex_unlock(&scheduler->rr_lock); + /* Number of bytes left to assign */ size_t left = size; /* Offset into message */ @@ -33,69 +54,21 @@ void nccl_net_ofi_set_multiplexing_schedule(size_t size, int num_rails, /* Maximum size of a stripe */ size_t max_stripe_size = 0; - schedule->num_xfer_infos = 0; + schedule->num_xfer_infos = num_stripes; - if (OFI_UNLIKELY(num_rails == 0)) return; - - max_stripe_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(size, num_rails), align) * align; + max_stripe_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(size, num_stripes), align) * align; /* Compute stripes and assign to rails */ - for (int rail_id = 0; rail_id != num_rails && left > 0; ++rail_id) { + for (int stripe_idx = 0; stripe_idx < num_stripes; ++stripe_idx) { size_t stripe_size = NCCL_OFI_MIN(left, max_stripe_size); - schedule->rail_xfer_infos[rail_id].rail_id = rail_id; - schedule->rail_xfer_infos[rail_id].offset = offset; - schedule->rail_xfer_infos[rail_id].msg_size = stripe_size; + schedule->rail_xfer_infos[stripe_idx].rail_id = rail_id; + schedule->rail_xfer_infos[stripe_idx].offset = offset; + schedule->rail_xfer_infos[stripe_idx].msg_size = stripe_size; - schedule->num_xfer_infos++; offset += stripe_size; left -= stripe_size; - } -} - -/* - * @brief Assign message round-robin - */ -static inline int set_round_robin_schedule(nccl_net_ofi_threshold_scheduler_t *scheduler, - size_t size, - int num_rails, - nccl_net_ofi_schedule_t *schedule) -{ - int rail_id; - - nccl_net_ofi_mutex_lock(&scheduler->rr_lock); - - /* Retieve and increment round-robin counter; wrap around if required */ - rail_id = (scheduler->rr_counter)++; - scheduler->rr_counter = scheduler->rr_counter == num_rails ? 0 : scheduler->rr_counter; - - nccl_net_ofi_mutex_unlock(&scheduler->rr_lock); - - schedule->num_xfer_infos = 1; - schedule->rail_xfer_infos[0].rail_id = rail_id; - schedule->rail_xfer_infos[0].offset = 0; - schedule->rail_xfer_infos[0].msg_size = size; - - return 0; -} - -/* - * @brief Assign message round-robin or multiplex message depending on its size - * - * Messages larger than `threshold' are multiplexed. Smaller messages are assigned round-robin. - */ -static inline int set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler, - size_t size, - int num_rails, - size_t align, - nccl_net_ofi_schedule_t *schedule) -{ - int ret = 0; - if (size > scheduler->rr_threshold) { - nccl_net_ofi_set_multiplexing_schedule(size, num_rails, - align, schedule); - } else { - ret = set_round_robin_schedule(scheduler, size, num_rails, schedule); + rail_id = (rail_id + 1) % num_rails; } return ret; } @@ -146,7 +119,7 @@ static nccl_net_ofi_schedule_t *get_threshold_schedule(nccl_net_ofi_scheduler_t NCCL_OFI_WARN("Failed to allocate schedule"); return NULL; } - ret = set_schedule_by_threshold(scheduler, size, num_rails, align, + ret = nccl_net_ofi_set_schedule_by_threshold(scheduler, size, num_rails, align, schedule); if (OFI_UNLIKELY(ret)) { nccl_net_ofi_release_schedule(scheduler_p, schedule); @@ -238,7 +211,7 @@ int scheduler_init(int num_rails, nccl_net_ofi_scheduler_t *scheduler) } int nccl_net_ofi_threshold_scheduler_init(int num_rails, - size_t rr_threshold, + size_t min_stripe_size, nccl_net_ofi_scheduler_t **scheduler_p) { int ret = 0; @@ -261,7 +234,7 @@ int nccl_net_ofi_threshold_scheduler_init(int num_rails, scheduler->base.get_schedule = get_threshold_schedule; scheduler->base.fini = threshold_scheduler_fini; scheduler->rr_counter = 0; - scheduler->rr_threshold = rr_threshold; + scheduler->min_stripe_size = min_stripe_size; ret = nccl_net_ofi_mutex_init(&scheduler->rr_lock, NULL); if (ret) { diff --git a/tests/unit/scheduler.c b/tests/unit/scheduler.c index 1d535423d..b811afca3 100644 --- a/tests/unit/scheduler.c +++ b/tests/unit/scheduler.c @@ -14,22 +14,6 @@ #include "test-common.h" #include "nccl_ofi_scheduler.h" -int create_multiplexed(size_t size, - int num_rails, - size_t align, - nccl_net_ofi_schedule_t **schedule_p) -{ - nccl_net_ofi_schedule_t *schedule = (nccl_net_ofi_schedule_t *)malloc( - sizeof(nccl_net_ofi_schedule_t) + num_rails * sizeof(nccl_net_ofi_xfer_info_t)); - if (!schedule) { - NCCL_OFI_WARN("Could not allocate schedule"); - return -ENOMEM; - } - nccl_net_ofi_set_multiplexing_schedule(size, num_rails, align, schedule); - *schedule_p = schedule; - return 0; -} - int verify_xfer_info(nccl_net_ofi_xfer_info_t *xfer, nccl_net_ofi_xfer_info_t *ref_xfer, int xfer_id) { int ret = ref_xfer->rail_id != xfer->rail_id @@ -72,250 +56,101 @@ int verify_schedule(nccl_net_ofi_schedule_t *schedule, nccl_net_ofi_schedule_t * return ret; } -int test_multiplexing_schedule() +int test_threshold_scheduler() { + int num_rails = 4; nccl_net_ofi_schedule_t *schedule = NULL; nccl_net_ofi_schedule_t *ref_schedule = (nccl_net_ofi_schedule_t *)malloc( - sizeof(nccl_net_ofi_schedule_t) + 3 * sizeof(nccl_net_ofi_xfer_info_t)); + sizeof(nccl_net_ofi_schedule_t) + num_rails * sizeof(nccl_net_ofi_xfer_info_t)); if (!ref_schedule) { NCCL_OFI_WARN("Could not allocate schedule"); return -ENOMEM; } - size_t size; - int num_rails; - size_t align; + size_t msg_size = 0; + size_t align = 128; int ret = 0; + size_t min_stripe_size = 4096; - size = 1; - num_rails = 0; - align = 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 0; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); + nccl_net_ofi_scheduler_t *scheduler; + if (nccl_net_ofi_threshold_scheduler_init(num_rails, min_stripe_size, &scheduler)) { + NCCL_OFI_WARN("Failed to initialize threshold scheduler"); free(ref_schedule); return ret; } - free(schedule); /************************/ /* Test one rail */ /************************/ - /* No data */ - size = 0; - num_rails = 1; - align = 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 0; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); - - /* Data size = align - 1 */ - size = 1; - num_rails = 1; - align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = size; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); - - /* Data size = align */ - size = 2; - num_rails = 1; - align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); + /* Verify that message with less than `min_stripe_size' bytes are assigned round-robin */ + schedule = scheduler->get_schedule(scheduler, min_stripe_size - 1, num_rails); + if (!schedule) { + NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); return ret; } + msg_size = min_stripe_size - 1; ref_schedule->num_xfer_infos = 1; ref_schedule->rail_xfer_infos[0].rail_id = 0; ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = size; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); free(ref_schedule); return ret; } - free(schedule); + nccl_net_ofi_release_schedule(scheduler, schedule); - /* Data size = align + 1 */ - size = 3; - num_rails = 1; - align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); + schedule = scheduler->get_schedule(scheduler, min_stripe_size - 1, num_rails); + if (!schedule) { + NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); - return ret; + return -1; } ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; + ref_schedule->rail_xfer_infos[0].rail_id = 1; ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = size; + ref_schedule->rail_xfer_infos[0].msg_size = min_stripe_size - 1; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); free(ref_schedule); return ret; } - free(schedule); + nccl_net_ofi_release_schedule(scheduler, schedule); /************************/ - /* Test three rail */ + /* Test two rails */ /************************/ - /* No data */ - size = 0; - num_rails = 3; - align = 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 0; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); - - /* Data size = 4 * align - 1 */ - num_rails = 3; - align = 3; - size = 4 * align - 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); + /* Verify that messages with greater than the `min_stripe_size' but less than 2x `min_stripe_size` + * bytes are assigned 2 rail multiplexing */ + schedule = scheduler->get_schedule(scheduler, min_stripe_size + 1, num_rails); + msg_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(min_stripe_size + 1, 2), align) * align; + if (!schedule) { + NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); - return ret; + return -1; } ref_schedule->num_xfer_infos = 2; - ref_schedule->rail_xfer_infos[0].rail_id = 0; + ref_schedule->rail_xfer_infos[0].rail_id = 2; ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = 2 * align; - ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = 2 * align; - ref_schedule->rail_xfer_infos[1].msg_size = 2 * align - 1; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; - /* Data size = 4 * align */ - num_rails = 3; - align = 3; - size = 4 * align; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 2; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = 2 * align; - ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = 2 * align; - ref_schedule->rail_xfer_infos[1].msg_size = 2 * align; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); + ref_schedule->rail_xfer_infos[1].rail_id = 3; + ref_schedule->rail_xfer_infos[1].offset = msg_size; + ref_schedule->rail_xfer_infos[1].msg_size = min_stripe_size + 1 - msg_size; - /* Data size = 4 * align + 1 */ - num_rails = 3; - align = 3; - size = 4 * align + 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 3; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = 2 * align; - ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = 2 * align; - ref_schedule->rail_xfer_infos[1].msg_size = 2 * align; - ref_schedule->rail_xfer_infos[2].rail_id = 2; - ref_schedule->rail_xfer_infos[2].offset = 4 * align; - ref_schedule->rail_xfer_infos[2].msg_size = 1; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); free(ref_schedule); return ret; } - free(schedule); - - free(ref_schedule); - - return 0; -} - -int test_threshold_scheduler() -{ - nccl_net_ofi_schedule_t *schedule; - int num_rails = 2; - int ret = 0; - size_t rr_threshold = 8192; - nccl_net_ofi_schedule_t *ref_schedule = (nccl_net_ofi_schedule_t *)malloc( - sizeof(nccl_net_ofi_schedule_t) + num_rails * sizeof(nccl_net_ofi_xfer_info_t)); - nccl_net_ofi_scheduler_t *scheduler; - if (nccl_net_ofi_threshold_scheduler_init(num_rails, rr_threshold, &scheduler)) { - NCCL_OFI_WARN("Failed to initialize threshold scheduler"); - free(ref_schedule); - return -1; - } + nccl_net_ofi_release_schedule(scheduler, schedule); - /* Verify that message with more than `rr_threshold' bytes is multiplexed */ - schedule = scheduler->get_schedule(scheduler, rr_threshold + 1, num_rails); + schedule = scheduler->get_schedule(scheduler, min_stripe_size + 1, num_rails); if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); @@ -324,29 +159,12 @@ int test_threshold_scheduler() ref_schedule->num_xfer_infos = 2; ref_schedule->rail_xfer_infos[0].rail_id = 0; ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold / 2 + 128; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; + ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = rr_threshold / 2 + 128; - ref_schedule->rail_xfer_infos[1].msg_size = rr_threshold / 2- 127; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - nccl_net_ofi_release_schedule(scheduler, schedule); + ref_schedule->rail_xfer_infos[1].offset = msg_size; + ref_schedule->rail_xfer_infos[1].msg_size = min_stripe_size + 1 - msg_size; - /* Verify that three messages with `rr_threshold' bytes are assigned round robin */ - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); - if (!schedule) { - NCCL_OFI_WARN("Failed to get schedule"); - free(ref_schedule); - return -1; - } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); @@ -355,16 +173,29 @@ int test_threshold_scheduler() } nccl_net_ofi_release_schedule(scheduler, schedule); - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); + /************************/ + /* Test three rails */ + /************************/ + + /* Verify that messages with greater than the 2x `min_stripe_size' but less than or equal to + * 3x `min_stripe_size` bytes are assigned 3 rail multiplexing */ + schedule = scheduler->get_schedule(scheduler, (min_stripe_size * 2) + 1, num_rails); + msg_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL((min_stripe_size * 2) + 1, 3), align) * align; if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); return -1; } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 1; + ref_schedule->num_xfer_infos = 3; + ref_schedule->rail_xfer_infos[0].rail_id = 2; ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; + ref_schedule->rail_xfer_infos[1].rail_id = 3; + ref_schedule->rail_xfer_infos[1].offset = msg_size; + ref_schedule->rail_xfer_infos[1].msg_size = msg_size; + ref_schedule->rail_xfer_infos[2].rail_id = 0; + ref_schedule->rail_xfer_infos[2].offset = 2 * msg_size; + ref_schedule->rail_xfer_infos[2].msg_size = (min_stripe_size * 2) + 1 - (2 * msg_size); ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); @@ -373,16 +204,31 @@ int test_threshold_scheduler() } nccl_net_ofi_release_schedule(scheduler, schedule); - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); + /************************/ + /* Test four rails */ + /************************/ + + /* Verify that messages with greater than the 3x `min_stripe_size' are assigned 4 rail multiplexing */ + schedule = scheduler->get_schedule(scheduler, (min_stripe_size * 3) + 1, num_rails); + msg_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL((min_stripe_size * 3) + 1, 4), align) * align; if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); return -1; } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; + ref_schedule->num_xfer_infos = 4; + ref_schedule->rail_xfer_infos[0].rail_id = 1; ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; + ref_schedule->rail_xfer_infos[1].rail_id = 2; + ref_schedule->rail_xfer_infos[1].offset = msg_size; + ref_schedule->rail_xfer_infos[1].msg_size = msg_size; + ref_schedule->rail_xfer_infos[2].rail_id = 3; + ref_schedule->rail_xfer_infos[2].offset = 2 * msg_size; + ref_schedule->rail_xfer_infos[2].msg_size = msg_size; + ref_schedule->rail_xfer_infos[3].rail_id = 0; + ref_schedule->rail_xfer_infos[3].offset = 3 * msg_size; + ref_schedule->rail_xfer_infos[3].msg_size = (min_stripe_size * 3) + 1 - (3 * msg_size); ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); @@ -406,7 +252,7 @@ int main(int argc, char *argv[]) ofi_log_function = logger; system_page_size = 4096; - ret = test_multiplexing_schedule() || test_threshold_scheduler(); + ret = test_threshold_scheduler(); /** Success!? **/ return ret;