
Commit

add more comment in helloworld_ddp example (#205)
msftsw authored May 15, 2023
1 parent d61df8d commit 9016428
Showing 1 changed file with 5 additions and 1 deletion.
tutel/examples/helloworld_ddp.py (5 additions, 1 deletion)
@@ -60,7 +60,6 @@
class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._ddp_params_and_buffers_to_ignore = list()

        self._moe_layer = tutel_moe.moe_layer(
            gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate},
@@ -81,15 +80,20 @@ def forward(self, input):
        result = F.log_softmax(torch.sum(result, dim=2), dim=1)
        return result

    # Important setting 1: skip handling of expert parameters by PyTorch DDP
    def add_param_to_skip_allreduce(self, param_name):
        if not hasattr(self, '_ddp_params_and_buffers_to_ignore'):
            self._ddp_params_and_buffers_to_ignore = list()
        self._ddp_params_and_buffers_to_ignore.append(param_name)


model = ExampleModel().to(device)

# Important setting 2: iterate over all expert parameter objects and add their names to the skip list from setting 1
for name, param in model.named_parameters():
    if hasattr(param, 'skip_allreduce'):
        model.add_param_to_skip_allreduce(name)

if torch.distributed.is_initialized():
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

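Note: both settings rely on a DistributedDataParallel behavior that is not specific to Tutel. DDP reads the _ddp_params_and_buffers_to_ignore attribute on the module it wraps and excludes the listed parameter names from gradient all-reduce, so the attribute must be populated before the wrap. The following is a minimal, self-contained sketch of the same pattern; ToyModel, its expert_weight parameter, and the manually set skip_allreduce flag are illustrative stand-ins for the MoE layer in this example, not part of any library API.

import torch


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for an expert parameter: it is updated locally on each
        # rank, so DDP must not average its gradient across ranks.
        self.expert_weight = torch.nn.Parameter(torch.randn(16, 16))
        self.expert_weight.skip_allreduce = True  # illustrative flag, set manually here
        # Ordinary parameter that DDP synchronizes as usual.
        self.shared = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.shared(x @ self.expert_weight)

    # Same idea as "Important setting 1": record parameter names that DDP
    # should leave out of its all-reduce buckets.
    def add_param_to_skip_allreduce(self, param_name):
        if not hasattr(self, '_ddp_params_and_buffers_to_ignore'):
            self._ddp_params_and_buffers_to_ignore = list()
        self._ddp_params_and_buffers_to_ignore.append(param_name)


model = ToyModel()

# Same idea as "Important setting 2": collect every flagged parameter by
# name before handing the module to DistributedDataParallel.
for name, param in model.named_parameters():
    if hasattr(param, 'skip_allreduce'):
        model.add_param_to_skip_allreduce(name)

# DDP picks up the ignore list at wrap time, so it must already be set here.
if torch.distributed.is_initialized():
    model = torch.nn.parallel.DistributedDataParallel(model)

In helloworld_ddp.py itself the skip_allreduce flag comes from the expert parameters created by tutel_moe.moe_layer, so only the two marked steps are needed in user code.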
