diff --git a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py
index f38f371b7..e969e5d53 100644
--- a/apex/contrib/test/optimizers/test_distributed_fused_lamb.py
+++ b/apex/contrib/test/optimizers/test_distributed_fused_lamb.py
@@ -26,7 +26,7 @@ def forward(self, input_tensor, gt):
         return loss
 
 # A test for distributed fused Lamb optimizer: run several iterations and see if loss decreases
-# There are two instances of the same test because based on `world_size` the optimizer decides what collectives operation to use. 
+# There are two instances of the same test because, based on `world_size`, the optimizer decides which collective operations to use.
 # If torch.distributed.get_world_size() == torch.cuda.device_count() it uses only `all_gather`.
 # If torch.distributed.get_world_size() < torch.cuda.device_count() it uses both `all_gather` and `reduce_scatter`.
 class NcclDistributedFusedLAMB(NcclDistributedTestBase):
@@ -35,17 +35,30 @@ def world_size(self) -> int:
         return torch.cuda.device_count()
 
     @common_utils.parametrize("no_copy", [False, True])
-    @common_utils.parametrize("opt_kwargs", [
-        dict(overlap_reductions=True, dwu_num_blocks=2, dwu_num_chunks=2,
-             fused_norm=False, fuse_scale=False, clip_after_ar=True,
-             full_ar=False),
-        dict(overlap_reductions=False, dwu_num_blocks=1, dwu_num_chunks=1,
-             fused_norm=True, fuse_scale=True, clip_after_ar=False),
-    ])
-    def test_distributed_fused_lamb(self, no_copy, opt_kwargs):
-        if no_copy and 'no_copy' not in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
+    @common_utils.parametrize(
+        "overlap_reductions,dwu_num_blocks,dwu_num_chunks,fused_norm,fuse_scale,clip_after_ar,full_ar",
+        (
+            (True, 2, 2, False, False, True, False),
+            (False, 1, 1, True, True, False, False),
+        ),
+    )
+    def test_distributed_fused_lamb(
+        self,
+        no_copy,
+        overlap_reductions,
+        dwu_num_blocks,
+        dwu_num_chunks,
+        fused_norm,
+        fuse_scale,
+        clip_after_ar,
+        full_ar,
+    ):
+        # `no_copy` can only be exercised when both collectives accept the keyword.
+        supports_no_copy = all(
+            'no_copy' in inspect.getfullargspec(collective).args
+            for collective in (torch.distributed.reduce_scatter, torch.distributed.all_gather)
+        )
+        if no_copy and not supports_no_copy:
             self.skipTest("does not support no_copy")
-        if no_copy and 'no_copy' not in inspect.getfullargspec(torch.distributed.all_gather).args:
-            self.skipTest("does not support no_copy")
 
         assert torch.distributed.is_initialized()
@@ -66,25 +79,29 @@ def test_distributed_fused_lamb(self, no_copy, opt_kwargs):
             {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
 
-        if 'full_ar' not in opt_kwargs:
-            opt_kwargs['full_ar'] = gpu_count == torch.cuda.device_count()
-
-        # Aidyn-A: not sure what parameters are the best for testing purposes, 
-        # setting up whatever I think appropriate. 
+        # Aidyn-A: not sure what parameters are the best for testing purposes,
+        # setting up whatever I think appropriate.
         optimizer = DistributedFusedLAMB(
-            optimizer_grouped_parameters,
-            lr=0.1,
-            betas=(0.9, 0.9),
-            eps=1e-6,
-            max_grad_norm=1.0,
-            dwu_group_size=gpu_count,
-            dwu_num_rs_pg=1,
-            dwu_num_ar_pg=1,
-            dwu_num_ag_pg=1,
-            use_nvlamb=False,
-            set_param_views_to_flat_buffer=False,
-            e5m2_allgather=False,
-            **opt_kwargs
+            optimizer_grouped_parameters,
+            lr=0.1,
+            betas=(0.9, 0.9),
+            eps=1e-6,
+            max_grad_norm=1.0,
+            dwu_group_size=gpu_count,
+            dwu_num_rs_pg=1,
+            dwu_num_ar_pg=1,
+            dwu_num_ag_pg=1,
+            use_nvlamb=False,
+            set_param_views_to_flat_buffer=False,
+            e5m2_allgather=False,
+            overlap_reductions=overlap_reductions,
+            dwu_num_blocks=dwu_num_blocks,
+            dwu_num_chunks=dwu_num_chunks,
+            fused_norm=fused_norm,
+            fuse_scale=fuse_scale,
+            clip_after_ar=clip_after_ar,
+            full_ar=full_ar,
+            **({'no_copy': no_copy} if supports_no_copy else {})
         )
 
         optimizer.set_global_scale(init_scale)
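
Note on the parametrization style adopted above: the sketch below is a minimal, standalone illustration (not part of the patch; ToyParamTest and test_sum are invented names) of how torch.testing._internal.common_utils.parametrize expands a comma-separated argument string plus value tuples, and how stacked decorators cross their parameters, assuming the usual instantiate_parametrized_tests call at module scope.

# Standalone illustration only; ToyParamTest and test_sum are invented names.
from torch.testing._internal import common_utils


class ToyParamTest(common_utils.TestCase):
    @common_utils.parametrize("no_copy", [False, True])
    @common_utils.parametrize("a,b,expected", ((1, 2, 3), (2, 2, 4)))
    def test_sum(self, no_copy, a, b, expected):
        # Stacked parametrize decorators are crossed: 2 no_copy values
        # x 2 (a, b, expected) tuples -> 4 generated test cases.
        self.assertEqual(a + b, expected)


# Materializes one concrete test method per parameter combination.
common_utils.instantiate_parametrized_tests(ToyParamTest)

if __name__ == "__main__":
    common_utils.run_tests()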
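The skip/forward logic around no_copy is ordinary signature probing; the sketch below restates it outside the test, under the assumption that only some PyTorch builds expose a no_copy keyword on these collectives (the helper name _supports_no_copy is hypothetical).

# Standalone illustration only; _supports_no_copy is a hypothetical helper.
import inspect

import torch.distributed as dist


def _supports_no_copy() -> bool:
    # Probe the collective signatures at runtime instead of pinning a torch
    # version: only builds whose reduce_scatter and all_gather take a
    # `no_copy` argument can exercise that path.
    return all(
        'no_copy' in inspect.getfullargspec(collective).args
        for collective in (dist.reduce_scatter, dist.all_gather)
    )

Probing the signature this way lets the same test file either skip the no_copy variants or forward the flag, so it runs unchanged on builds with and without the feature.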