support handling multi-gate options (#71)
msftsw authored Dec 27, 2021
1 parent e7d165f commit ea17ea6
Showing 2 changed files with 31 additions and 29 deletions.
README.md (3 additions & 1 deletion)

@@ -82,7 +82,9 @@ Usage of MOELayer:
 ```
 * Usage of MOELayer Args:
-        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
+        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'},
+                           or a list of dict-type gate descriptions, e.g. [{'type': 'top', 'k': 2}, {'type': 'top', 'k': 2}];
+                           the value of k in top-gating can also be negative (e.g. -2), which indicates that each GPU holds 1/(-k) of one expert's parameters
         model_dim        : the number of channels for MOE's input tensor
         experts          : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
         scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
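For illustration only (not part of the commit): a minimal sketch of the new list form, using the `MOELayer` class from `tutel/impls/moe_layer.py` changed below. The expert-config keys (`count_per_node`, `hidden_size_per_expert`) and all sizes are assumptions for this sketch, and depending on the build a `torch.distributed` backend may need to be initialized first:

```python
import torch
from tutel.impls.moe_layer import MOELayer

# Two independent gating policies over the same expert pool.
moe = MOELayer(
    gate_type=[{'type': 'top', 'k': 1},   # gates[0]: top-1 routing
               {'type': 'top', 'k': 2}],  # gates[1]: top-2 routing
    model_dim=1024,
    experts={'type': 'ffn', 'count_per_node': 2,
             'hidden_size_per_expert': 2048},  # assumed `ffn` expert keys
)

x = torch.randn(8, 1024)
y0 = moe(x)                # default gate_index=0 -> routes via gates[0]
y1 = moe(x, gate_index=1)  # routes the same input via gates[1]
```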
tutel/impls/moe_layer.py (28 additions & 28 deletions)

@@ -148,7 +148,7 @@ def apply_on_expert_fn(self, input, expert_fn, group, sharded_count):
         return result_output, l_loss


-class MegatronLMGate():
+class MegatronLMGate(torch.nn.Module):
     """Megatron-LM Tensor Parallel over MoE Gate Type
     """

@@ -157,6 +157,9 @@ def __init__(
         **kwargs,
     ):
         self.l_zero = None
+        self._modules = dict()
+        self._parameters = dict()
+        self._buffers = dict()

     def named_parameters(self):
         return []
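Editor's gloss on the three dict assignments above: `MegatronLMGate` now subclasses `torch.nn.Module` (so it can live inside the `ModuleList` introduced further down) but never calls `super().__init__()`, so the bookkeeping dicts that `torch.nn.Module.__init__` normally creates must be supplied by hand; module-tree traversals such as `named_parameters()` read them directly. A standalone sketch with a hypothetical `Gate` class:

```python
import torch

class Gate(torch.nn.Module):
    def __init__(self):
        # Mirrors MegatronLMGate: no super().__init__() call, so the dicts
        # that torch.nn.Module's traversal machinery expects must exist
        # before the module is placed in a container.
        self._modules = dict()
        self._parameters = dict()
        self._buffers = dict()
        self.l_zero = None

gates = torch.nn.ModuleList([Gate()])
print(list(gates.named_parameters()))  # [] -- traversal succeeds, no params
```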
@@ -173,15 +176,6 @@ def apply_on_expert_fn(self, input, expert_fn, group, sharded_count):

 class MOELayer(torch.nn.Module):
     """Tutel optimized MOELayer
-    Args:
-        gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
-        model_dim        : the number of channels for MOE's input tensor
-        experts          : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
-        scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
-        result_func      : allow users to specify a lambda function to format the MoE output and aux_loss, e.g. `result_func = lambda output: (output, output.l_aux)`
-        group            : specify the explicit communication group of all_to_all
-        seeds            : a tuple containing a tripple of int to specify manual seed of (shared params, local params, others params after MoE's)
     """

     def __init__(self, gate_type, model_dim: int, experts = None, scan_expert_func = None, result_func = None, group: Optional[Any] = None, seeds = None, **kwargs):
@@ -342,22 +336,28 @@ def to(self, *args, **kwargs):
logging.warning(f"gate_type value `{gate_type}` in tutel.moe_layer has been deprecated, please use gate_type = {{'type': 'top', 'k': {top_k}}} instead.")
gate_type = {'type': 'top', 'k': top_k}

if gate_type['type'] == 'top':
if seeds is not None and seeds[0] is not None:
torch.manual_seed(seeds[0])

if "fp32_gate" in kwargs:
logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
gate_type["fp32_gate"] = kwargs["fp32_gate"]
if not isinstance(gate_type, list):
gate_type = [gate_type]

self.gates = []
for gi, single_gate_type in enumerate(gate_type):
if single_gate_type['type'] == 'top':
if seeds is not None and seeds[0] is not None:
torch.manual_seed(seeds[0] + gi)
if "fp32_gate" in kwargs:
logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
single_gate_type["fp32_gate"] = kwargs["fp32_gate"]

self.gates += [TopKGate(model_dim=model_dim, top_k=single_gate_type['k'], num_global_experts=self.num_global_experts, **single_gate_type)]
elif single_gate_type['type'] == 'megatron':
self.gates += [MegatronLMGate(**single_gate_type)]
assert isinstance(experts, dict), "Gate type `megatron` requires dict-type expert description."
assert self.num_local_experts == 1, "Gate type `megatron` requires `count_per_node` == 1 in expert attributions."
assert experts['type'] == 'ffn', "Gate type `megatron` requires `type` == `ffn` in expert attributions."
else:
raise Exception("Unrecognized gate_type: %s" % single_gate_type)

self.gate = TopKGate(model_dim=model_dim, top_k=gate_type['k'], num_global_experts=self.num_global_experts, **gate_type)
elif gate_type['type'] == 'megatron':
self.gate = MegatronLMGate(**gate_type)
assert isinstance(experts, dict), "Gate type `megatron` requires dict-type expert description."
assert self.num_local_experts == 1, "Gate type `megatron` requires `count_per_node` == 1 in expert attributions."
assert experts['type'] == 'ffn', "Gate type `megatron` requires `type` == `ffn` in expert attributions."
else:
raise Exception("Unrecognized gate_type: %s" % gate_type)
self.gates = ModuleList(self.gates)

if seeds is not None and len(seeds) > 2 and seeds[2] is not None:
torch.manual_seed(seeds[2])
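A note on the `seeds[0] + gi` change above: offsetting the manual seed by the gate index keeps initialization reproducible while giving each top-k gate distinct weights; without the offset, every gate in the list would start from identical draws. A toy illustration with plain tensors (not the actual gate parameters):

```python
import torch

seeds = (42,)  # hypothetical user-supplied seeds tuple
for gi in range(2):
    torch.manual_seed(seeds[0] + gi)  # per-gate offset, as in the diff
    w = torch.randn(3)                # stands in for one gate's projection init
    print(f"gate {gi}: {w}")          # different draws per gate, stable across runs
```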
@@ -375,13 +375,13 @@ def expert_fn(dispatched_input):

     def get_parameter_iterator(self, param_type):
         if param_type == 'gate':
-            return self.gate.named_parameters()
+            return self.gates.named_parameters()
         elif param_type == 'local_experts':
             return self.experts.named_parameters()
         else:
             raise Exception("Specified parameter type is not recognized: %s. Valid `param_type` includes: gate, local_experts." % param_type)

-    def forward(self, input: Tensor, **kwargs: Any):
+    def forward(self, input: Tensor, gate_index=0, **kwargs: Any):
         if self.skip_moe:
             result_output = input
             result_output.l_aux = None
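Two reading notes on this hunk (editor's gloss): `forward` now accepts a `gate_index` selecting which gate routes the current call, and because `self.gates` is a `ModuleList`, the names yielded by `get_parameter_iterator(param_type='gate')` gain a numeric per-gate prefix. A standalone sketch of that prefixing, using `torch.nn.Linear` as a stand-in for the gate modules:

```python
import torch

# ModuleList prefixes each child's parameter names with its index.
gates = torch.nn.ModuleList([torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)])
for name, p in gates.named_parameters():
    print(name, tuple(p.shape))  # 0.weight, 0.bias, 1.weight, 1.bias
```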
@@ -404,7 +404,7 @@ def forward(self, input: Tensor, **kwargs: Any):
             reshaped_input = pad_input

         reshaped_input = reshaped_input.to(next(iter(self.experts.parameters())).dtype)
-        result_output, l_aux = self.gate.apply_on_expert_fn(reshaped_input, self.expert_fn, self.group, sharded_count=self.sharded_count)
+        result_output, l_aux = self.gates[gate_index].apply_on_expert_fn(reshaped_input, self.expert_fn, self.group, sharded_count=self.sharded_count)

         result_output = result_output[:reshaped_input_samples, :]
         result_output = result_output.view(original_shape).to(original_dtype)
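One behavioral consequence of `self.gates[gate_index]` above, stated as a reading of the diff rather than as documentation: each forward pass consults exactly one gate, so the `l_aux` attached to the result reflects only the gate chosen for that call. Callers that alternate gates across steps should accumulate the auxiliary loss per call rather than assuming a single shared gate loss.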
