diff --git a/tutel/examples/helloworld_ddp.py b/tutel/examples/helloworld_ddp.py
index 3ac8ad7f..86e80488 100755
--- a/tutel/examples/helloworld_ddp.py
+++ b/tutel/examples/helloworld_ddp.py
@@ -60,7 +60,6 @@ class ExampleModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self._ddp_params_and_buffers_to_ignore = list()
 
         self._moe_layer = tutel_moe.moe_layer(
             gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate},
@@ -81,15 +80,20 @@ def forward(self, input):
         result = F.log_softmax(torch.sum(result, dim=2), dim=1)
         return result
 
+    # Important setting 1: skip handling of expert parameters by PyTorch DDP
     def add_param_to_skip_allreduce(self, param_name):
+        if not hasattr(self, '_ddp_params_and_buffers_to_ignore'):
+            self._ddp_params_and_buffers_to_ignore = list()
         self._ddp_params_and_buffers_to_ignore.append(param_name)
 
 model = ExampleModel().to(device)
 
+# Important setting 2: iterate over all expert parameter objects and add their names to the list from setting 1
 for name, param in model.named_parameters():
     if hasattr(param, 'skip_allreduce'):
         model.add_param_to_skip_allreduce(name)
+
 if torch.distributed.is_initialized():
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
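
A brief note on why the two settings above work together: PyTorch DDP looks for a `_ddp_params_and_buffers_to_ignore` attribute on the wrapped module and excludes the listed parameter names from gradient allreduce, so expert parameters stay local to each rank. The sketch below illustrates the same pattern outside the example; `ToyExpertModel`, its layer sizes, and the gloo/torchrun launch are illustrative assumptions, not taken from the patch.

# Minimal standalone sketch of the skip-allreduce pattern (assumptions noted above).
import torch
import torch.distributed as dist


class ToyExpertModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = torch.nn.Linear(16, 16)   # replicated; DDP allreduces its gradients
        self.expert = torch.nn.Linear(16, 16)   # expert-parallel; must not be allreduced
        for p in self.expert.parameters():
            # same marker that tutel's scan_expert_func attaches to expert parameters
            p.skip_allreduce = True

    def forward(self, x):
        return self.expert(self.shared(x))

    def add_param_to_skip_allreduce(self, param_name):
        if not hasattr(self, '_ddp_params_and_buffers_to_ignore'):
            self._ddp_params_and_buffers_to_ignore = list()
        self._ddp_params_and_buffers_to_ignore.append(param_name)


if __name__ == '__main__':
    # e.g. launched with: torchrun --nproc_per_node=2 this_script.py
    dist.init_process_group('gloo')
    model = ToyExpertModel()
    for name, param in model.named_parameters():
        if hasattr(param, 'skip_allreduce'):
            model.add_param_to_skip_allreduce(name)   # collects 'expert.weight', 'expert.bias'
    ddp_model = torch.nn.parallel.DistributedDataParallel(model)
    # DDP now synchronizes gradients of 'shared.*' only; 'expert.*' is left untouched.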