From 4c052654765f4bd43afcc4b766ccc3bd2731b4e8 Mon Sep 17 00:00:00 2001 From: sheetalarkadam <100380551+sheetalarkadam@users.noreply.github.com> Date: Wed, 22 Nov 2023 11:31:49 -0800 Subject: [PATCH] Sync changes in internal moe to public repo (#186) * update moe code with internal Microsoft code. Main changes include Extend Distributed Grid to enable expert relocation Extend Distributed Grid to enable expert relocation, Simplify interface. Use options dict for MixtureOfExperts and Layers, expert weight calculation optimization Expert weight calculation optimization. Avoid calculation on for padding used for unused expert capacity * sync tests from internal Microsoft repo. Added new tests for nccl_output_processor,nccl and sccl tests, compression and coverage. Added readme on running tests * restructure directory to install with setup.py and maintain same as internal --------- Co-authored-by: Sheetal Kadam --- README.md | 5 +- ort_moe/README.md | 15 + ort_moe/VERSION_NUMBER | 1 + ort_moe/{ => ort_moe}/__init__.py | 0 ort_moe/{ => ort_moe}/collectives.py | 64 ++- ort_moe/{ => ort_moe}/custom_ops.py | 24 +- ort_moe/{ => ort_moe}/experts.py | 12 +- ort_moe/{ => ort_moe}/gate_logs.py | 0 ort_moe/{ => ort_moe}/grids.py | 264 +++++++----- ort_moe/{ => ort_moe}/layers.py | 93 ++--- ort_moe/{ => ort_moe}/loss_functions.py | 0 ort_moe/{ => ort_moe}/moe.py | 111 +++-- ort_moe/ort_moe/moe_config.py | 171 ++++++++ ort_moe/{ => ort_moe}/topKgate.py | 206 ++++++---- ort_moe/{ => ort_moe}/utils.py | 10 +- ort_moe/setup.py | 42 ++ ort_moe/tests/README.md | 23 ++ ort_moe/tests/__init__.py | 2 + ort_moe/tests/nccl_output_processor.py | 31 ++ ort_moe/tests/pytest.ini | 4 + ort_moe/tests/run_all.sh | 18 + ort_moe/tests/test_compression.py | 49 +++ ort_moe/tests/test_coverage.py | 31 ++ ort_moe/tests/test_grid.py | 57 ++- ort_moe/tests/test_moe.py | 480 +++++++++++++--------- ort_moe/tests/test_nccl.py | 10 + ort_moe/tests/test_sccl_without_import.py | 10 + ort_moe/tests/test_top2gating.py | 99 ++--- ort_moe/tests/test_uni_image.py | 32 ++ 29 files changed, 1326 insertions(+), 538 deletions(-) create mode 100644 ort_moe/README.md create mode 100644 ort_moe/VERSION_NUMBER rename ort_moe/{ => ort_moe}/__init__.py (100%) rename ort_moe/{ => ort_moe}/collectives.py (66%) rename ort_moe/{ => ort_moe}/custom_ops.py (67%) rename ort_moe/{ => ort_moe}/experts.py (92%) rename ort_moe/{ => ort_moe}/gate_logs.py (100%) rename ort_moe/{ => ort_moe}/grids.py (85%) rename ort_moe/{ => ort_moe}/layers.py (77%) rename ort_moe/{ => ort_moe}/loss_functions.py (100%) rename ort_moe/{ => ort_moe}/moe.py (77%) create mode 100644 ort_moe/ort_moe/moe_config.py rename ort_moe/{ => ort_moe}/topKgate.py (80%) rename ort_moe/{ => ort_moe}/utils.py (98%) create mode 100644 ort_moe/setup.py create mode 100644 ort_moe/tests/README.md create mode 100644 ort_moe/tests/nccl_output_processor.py create mode 100644 ort_moe/tests/pytest.ini create mode 100755 ort_moe/tests/run_all.sh create mode 100644 ort_moe/tests/test_compression.py create mode 100644 ort_moe/tests/test_coverage.py create mode 100644 ort_moe/tests/test_nccl.py create mode 100644 ort_moe/tests/test_sccl_without_import.py create mode 100644 ort_moe/tests/test_uni_image.py diff --git a/README.md b/README.md index 2a4a9d02..c6ee8c10 100644 --- a/README.md +++ b/README.md @@ -84,8 +84,7 @@ Build MoE ```bash cd ort_moe -pip install build # Install PyPA build -python -m build +python setup.py install ``` ## Install for Inference @@ -290,4 +289,4 @@ Please refer to our [contributing 
guide](CONTRIBUTING.md) for more information o ## License -This project has an MIT license, as found in the [LICENSE](LICENSE) file. +This project has an MIT license, as found in the [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/ort_moe/README.md b/ort_moe/README.md new file mode 100644 index 00000000..1b5e70f4 --- /dev/null +++ b/ort_moe/README.md @@ -0,0 +1,15 @@ +# Introduction +Mixture Of Experts (MoE) implementation in PyTorch +This repo contains following components +#moe_module +moe_module contains PyTorch implementation of MoE, Experts and Gates and Transformer layers. This module is used by 1P workloads to inject MoE layers in their model. We aim to implement moe_module such that it is amenable to variety of model distribution techniques for large scale (100B+ param) training. +#Proxy Models +We have implemented proxy of Gshard, Switch Transformer etc.. models using moe_module in moe_models.py. These models serves two purposes. First to simple standalone proxy model for approximate performance analysis and characterization of scaling efforts. Second is to evaluate flexibility of moe_module interface in variety of situations. We encourage contribution of new proxy models. +#Trainer +We have extended NLP trainer from Pytorch tutorial in baseline_nlp.py to run proxy models and collect performance data. We use WIkiText as a dataset, which is obviously not representative of real world 1P workload scenarios. The trainer allows us to experiment with variety of PyTorch packages/techniques such as DeepSpeed, Apex, torch DistributedDataParallel etc.. We welcome contributions to incorporate Pipeline Parallelism and Megatron-style training. +# ITP Scripts +We have scripts available in experiments/itp folder to easily launch jobs on ITP cluster. We have two ready made experiments available, Switch-CA and Switch-CB. They are two variants of Switch Transformer model scaled to 100B parameter size. + +# Updates +0.1.7 : Adapt new name - ort_moe + diff --git a/ort_moe/VERSION_NUMBER b/ort_moe/VERSION_NUMBER new file mode 100644 index 00000000..0ea3a944 --- /dev/null +++ b/ort_moe/VERSION_NUMBER @@ -0,0 +1 @@ +0.2.0 diff --git a/ort_moe/__init__.py b/ort_moe/ort_moe/__init__.py similarity index 100% rename from ort_moe/__init__.py rename to ort_moe/ort_moe/__init__.py diff --git a/ort_moe/collectives.py b/ort_moe/ort_moe/collectives.py similarity index 66% rename from ort_moe/collectives.py rename to ort_moe/ort_moe/collectives.py index ea8a24d9..5d30ebe0 100644 --- a/ort_moe/collectives.py +++ b/ort_moe/ort_moe/collectives.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- +import os import torch import torch.distributed as dist @@ -44,6 +45,21 @@ def backward(ctx, *grad_output): output = (grad_output_chunks[ctx.rank]).contiguous() return None, output, None +def compressed_all_to_all(output, input, group=None): + + world_size = dist.get_world_size(group) + rank = torch.distributed.get_rank(group) + + ts_in = torch.tensor_split(input, world_size) + compressed_a2a_input, _ = dg_compress(ts_in) + ts_out = torch.tensor_split(output, world_size) + + for i in range(world_size): + if i != rank: + torch.distributed.send(compressed_a2a_input[i], i) + torch.distributed.recv(compressed_a2a_input[i], i) + dg_decompress(compressed_a2a_input, ts_out) + # Based on https://github.com/pytorch/pytorch/pull/40762 class AllToAll(torch.autograd.Function): @staticmethod @@ -65,7 +81,10 @@ def forward(ctx, group, input, max_len_cpu): # type: ignore input = torch.nn.functional.pad(input, pad=(0,0,0,max_len_cpu-input.shape[-2]), mode = 'constant', value=0.0) input = input.contiguous() output = torch.empty_like(input) - dist.all_to_all_single(output, input, group=group) + if os.environ.get('ORT_MOE_COMPRESS_ALLToALL') is not None: + compressed_all_to_all(output, input, group=group) + else: + dist.all_to_all_single(output, input, group=group) return output @staticmethod @@ -102,3 +121,46 @@ def backward(ctx, *grad_output): dist.all_reduce(grad_output[0], group = ctx.group) ret_val = grad_output[0].contiguous() return ret_val, None + +_DIETGPU_SCRATCH_PAD = None +_DIETGPU_LIBRARY_LOADED = False +def dg_load_library(): + global _DIETGPU_LIBRARY_LOADED + if _DIETGPU_LIBRARY_LOADED is True: + return + + dg_lib = os.environ.get('DIETGPU_LIB_PATH') + torch.ops.load_library(dg_lib) + _DIETGPU_LIBRARY_LOADED = True + return + +def get_tensor_list_device(the_list): + if isinstance(the_list, torch.Tensor): + return the_list.device + + for l in the_list: + if isinstance(l, torch.Tensor): + return l.device + + return None + +def dg_get_scratch_pad(device): + global _DIETGPU_SCRATCH_PAD + if _DIETGPU_SCRATCH_PAD is None: + _DIETGPU_SCRATCH_PAD = torch.empty([64*1024*1024], dtype=torch.uint8, device=device) + return _DIETGPU_SCRATCH_PAD + +def dg_compress(input_list): + dg_load_library() + + output, output_size, _ = torch.ops.dietgpu.compress_data(True, input_list, dg_get_scratch_pad(get_tensor_list_device(input_list))) + compressed_output_list = [] + for size, t in zip(output_size, [*output]): + truncated_t = t.narrow(0, 0, size.item()).clone() + compressed_output_list.append(truncated_t) + + return compressed_output_list, output_size + +def dg_decompress(input_list, output_list): + dg_load_library() + torch.ops.dietgpu.decompress_data(True, input_list, output_list, dg_get_scratch_pad(get_tensor_list_device(input_list))) diff --git a/ort_moe/custom_ops.py b/ort_moe/ort_moe/custom_ops.py similarity index 67% rename from ort_moe/custom_ops.py rename to ort_moe/ort_moe/custom_ops.py index 94e260a1..a0f881e4 100644 --- a/ort_moe/custom_ops.py +++ b/ort_moe/ort_moe/custom_ops.py @@ -4,12 +4,13 @@ # -------------------------------------------------------------------------- import torch +from .moe_config import moe_config # The switch to decided whether to use torch.einsum (when this flag is true) or use rewrited-version. # switch can be bubbled up in future USE_EINSUM = True -def einsum(rule, a, b): +def om_einsum(rule, a, b): """ The rewrite of torch.einsum for some specific cases. 
The rewrites are on par or more performant upon the benchmark we tested @@ -44,3 +45,24 @@ def einsum(rule, a, b): return torch.bmm(a, b.transpose(1, 2)).reshape(s, m) else: return torch.einsum(rule, a, b) + +def om_cumsum(mask, dim, options = None): + """ + The rewrite of torch.cumsum to use tutel cumsum if desired. + Args: + tensor (torch.Tensor): the input tensor of cumsum op + dim (int): the dimension of cumsum op + options (moe_config): the options to decide whether to use tutel cumsum + """ + if mask.device.type == 'cpu' or options is None: + return torch.cumsum(mask, dim) - 1 + + moe_options = None + if isinstance(options, moe_config): moe_options = options + else: moe_options = moe_config(options) + + if moe_options.enable_tutel_cumsum(): + from tutel.jit_kernels.gating import fast_cumsum_sub_one + return fast_cumsum_sub_one(mask, dim) + + return torch.cumsum(mask, dim) - 1 diff --git a/ort_moe/experts.py b/ort_moe/ort_moe/experts.py similarity index 92% rename from ort_moe/experts.py rename to ort_moe/ort_moe/experts.py index 12ede479..89b8844b 100644 --- a/ort_moe/experts.py +++ b/ort_moe/ort_moe/experts.py @@ -22,18 +22,20 @@ class FFNExpert(nn.Module): """ def __init__(self, d_model, dim_feedforward, dgrid, activation_fn = nn.functional.relu, expert_dropout = 0.0): super().__init__() - self.mp_size = dgrid.get_expert_slicing_world_size() + self.mp_size = 1 + if dgrid is not None: + self.mp_size = dgrid.get_expert_slicing_world_size() self.linear1 = nn.Linear(d_model, dim_feedforward//self.mp_size, bias=False) self.linear2 = nn.Linear(dim_feedforward//self.mp_size, d_model, bias=False) self.activation_fn = activation_fn self.expert_dropout_rate = expert_dropout def forward(self, x: torch.tensor): - x = self.linear1(x) + x = self.linear1(x.float()) x = self.activation_fn(x) if self.expert_dropout_rate > 0: x = F.dropout(x, p=self.expert_dropout_rate, training=self.training) - x = self.linear2(x) + x = self.linear2(x.float()) return x class MergedFFNExpert(nn.Module): @@ -73,11 +75,11 @@ def forward(self, x: torch.tensor): x = x.transpose(0, 1) #gecm --> egcm input_shape = x.shape reshaped_x = x.reshape(input_shape[0], -1, input_shape[-1]) #egcm --> e,gxc,m - out1 = torch.bmm(reshaped_x, self.weight1) #e, gxc, f + out1 = torch.bmm(reshaped_x.float(), self.weight1) #e, gxc, f out1 = self.activation_fn(out1) if self.expert_dropout_rate > 0: out1 = F.dropout(out1, p=self.expert_dropout_rate, training=self.training) - out2 = torch.bmm(out1, self.weight2) #e, gxc, m + out2 = torch.bmm(out1.float(), self.weight2) #e, gxc, m out2 = out2.reshape(input_shape) out2 = out2.transpose(0, 1) #egcm --> gecm return out2 \ No newline at end of file diff --git a/ort_moe/gate_logs.py b/ort_moe/ort_moe/gate_logs.py similarity index 100% rename from ort_moe/gate_logs.py rename to ort_moe/ort_moe/gate_logs.py diff --git a/ort_moe/grids.py b/ort_moe/ort_moe/grids.py similarity index 85% rename from ort_moe/grids.py rename to ort_moe/ort_moe/grids.py index 260f6932..cf541da5 100644 --- a/ort_moe/grids.py +++ b/ort_moe/ort_moe/grids.py @@ -8,7 +8,75 @@ import torch import torch.distributed as dist -class DistributionGrid: + +class BaseGrid: + '''BaseGrid class provides an abstract class in the case that no distributed backend is available. 
+ ''' + def __init__(self): + self._EXPERT_REPLICA_GROUP = None + self._EXPERT_PARALLEL_GROUP = None + self._EXPERT_SLICING_GROUP = None + self._export_relocation_map = {} + + def get_expert_replica_group(self): + return self._EXPERT_REPLICA_GROUP + + ''' + Interface to support Expert Parallelism. + + In this distribution technique, experts in each layer are evenly distributed among available ranks + while non-expert parameters are replicated on each rank. + ''' + def get_expert_parallel_group(self): + return self._EXPERT_PARALLEL_GROUP + + def get_expert_parallel_world_size(self): + if self.get_expert_parallel_group() is not None: + return dist.get_world_size(group=self.get_expert_parallel_group()) + return 1 + + def get_expert_parallel_rank(self): + if self.get_expert_parallel_group() is not None: + return dist.get_rank(group=self.get_expert_parallel_group()) + return 0 + + ''' + Interface to support Expert Slicing. + + In this distribution technique, experts in each layer are sliced on hidden dimension and + sharded across available ranks. The non-expert parameters are replicated on each rank. + ''' + def get_expert_slicing_group(self): + return self._EXPERT_SLICING_GROUP + + ''' + Simplified common interface for Expert Parallel and Expert Slicing + ''' + def get_expert_group(self): + if self._EXPERT_PARALLEL_GROUP is None: + return self._EXPERT_SLICING_GROUP + else: + return self._EXPERT_PARALLEL_GROUP + + def get_expert_world_size(self): + if self.get_expert_slicing_group() is not None: + return dist.get_world_size(group=self.get_expert_slicing_group()) + elif self.get_expert_parallel_group() is not None: + return dist.get_world_size(group=self.get_expert_parallel_group()) + return 1 + + def get_expert_rank(self): + if self.get_expert_parallel_group() is not None: + return dist.get_rank(group=self.get_expert_parallel_group()) + elif self.get_expert_slicing_group() is not None: + return dist.get_rank(group=self.get_expert_slicing_group()) + return 0 + + def get_mpi_group_for_expert_group(self): + return None + + +class DistributionGrid(BaseGrid): '''DistributionGrid provides simple interface to create and manage process groups for various distributed training configurations. @@ -17,11 +85,11 @@ class DistributionGrid: [2] Create expert parallel grid where experts are evenly distributed among available ranks. dgrid = DistributionGrid(expert_parallel_group_size = ) - + [3] Create expert slicing grid where each expert are evenly sharded among avaiable ranks. 
dgrid = DistributionGrid(expert_slicing_group_size = ) - [4] Create replicas of expert parallel or expert slicing distributions + [4] Create replicas of expert parallel or expert slicing distributions dgrid = DistributionGrid(expert_parallel_group_size (or expert_slicing_group_size) = ), expert_parallel_replica_group_size = ) @@ -32,23 +100,22 @@ class DistributionGrid: data_parallel_group_size: number of data parallel copies of the model expert_parallel_group_size: number of GPUs sharing experts of single layer expert_slicing_group_size: number of GPUs experts are sharded onto of single layer - expert_parallel_replica_group_size: number of data parallel copies of experts + expert_parallel_replica_group_size: number of data parallel copies of experts num_of_nodes_in_pipeline: number of nodes used in pipeline parallel mode num_of_pipeline_stage: number of pipeline stages options: Various grid options ''' - def __init__(self, data_parallel_group_size = None, expert_parallel_group_size = None, + def __init__(self, data_parallel_group_size = None, expert_parallel_group_size = None, expert_parallel_replica_group_size = None, expert_slicing_group_size = None, - num_of_nodes_in_pipeline = None, num_of_pipeline_stage = None, options = None): + num_of_nodes_in_pipeline = None, num_of_pipeline_stage = None, options = {}): #print("==> initialize dgrid") + super().__init__() self._DATA_PARALLEL_GROUP = None - self._EXPERT_PARALLEL_GROUP = None - self._EXPERT_REPLICA_GROUP = None - self._EXPERT_SLICING_GROUP = None # base rank for broadcasting initialized weights self._EXPERT_REPLICA_GROUP_BCAST_SRC_RANK = None self._ep_rank_list = None self._es_rank_list = None + self._dp_rank_list = None self._MPI_EP_GROUP = None self._MPI_ES_GROUP = None # Used to wrap MoE layer by FullyShardedDataParallel for the experts, which are not sharded @@ -70,13 +137,13 @@ def __init__(self, data_parallel_group_size = None, expert_parallel_group_size = self._options = options assert not(expert_slicing_group_size is not None and expert_parallel_group_size is not None), \ - "Cannot have both expert slicing and expert parallel" + "Cannot have both expert slicing and expert parallel" is_expert_slicing = expert_slicing_group_size is not None expert_group_size = expert_slicing_group_size if is_expert_slicing else expert_parallel_group_size # Standard data parallel distribution without any separate distribution for experts if data_parallel_group_size is not None: - + assert expert_parallel_group_size is None, \ "Standard Data Parallelism with Expert Parallelism is not supported" assert expert_slicing_group_size is None, \ @@ -98,7 +165,7 @@ def __init__(self, data_parallel_group_size = None, expert_parallel_group_size = "Expert Parallelism/Slicing with Expert Replicas is not supported" assert expert_replica_group_size is None, \ "Pipeline Parallelism with Expert Repliacs is not supported" - + self._initialize_pipeline_parallel(num_of_nodes_in_pipeline, num_of_pipeline_stage) return @@ -197,16 +264,16 @@ def _initialize_expert_parallel_or_slicing_group(self, ranks_count, is_expert_sl def _initialize_expert_parallel_or_expert_slicing_replica_groups(self, expert_ranks, dp_ranks, is_expert_slicing = False): ''' Initialize new process groups for exper parallel replicas. - Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we - want 2-data parallel copies of each experts. - + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + want 2-data parallel copies of each experts. 
+ This function will create 8 expert replica groups, 2 expert parallel groups as: 8 expert replica groups: [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15] 2 expert parallel groups or expert slicing group: [g0, g1, g2, g3, g4, g5, g6, g7], [g8, g9, g10, g11, g12, g13, g14, g15] - Note that for efficiency of all2all, the caller should make sure adjacent ranks + Note that for efficiency of all2all, the caller should make sure adjacent ranks are on the same expert parallel group. Args: @@ -219,15 +286,16 @@ def _initialize_expert_parallel_or_expert_slicing_replica_groups(self, expert_ra assert dist.get_world_size() % dp_ranks == 0, "Only support the n-way expert parallel replicas when world size is dividable by n" assert dp_ranks > 1, "There should be at least 2 replicas" - row_major_list = True - if self._options is not None: - rank_schedule = self._options.get("rank_schedule") - if rank_schedule == "row_major": - row_major_list = True - elif rank_schedule == "column_major": - row_major_list = False - else: - assert False, "Selected grid rank schedule is not supported" + + rank_schedule = self._options.get("rank_schedule", "row_major") + if rank_schedule == "row_major": + row_major_list = True + elif rank_schedule == "column_major": + row_major_list = False + else: + assert False, "Selected grid rank schedule is not supported" + is_replica_in_same_node = self._options.get("is_replica_in_same_node", False) + assert row_major_list or not is_replica_in_same_node, "column_major with is_replica_in_same_node == True is not supported" def rearrange(rlist, eranks): rearranged =[] @@ -240,24 +308,28 @@ def rearrange(rlist, eranks): expert_rank_list = [] if row_major_list is True: - expert_rank_list = [rank_list[i:i+expert_ranks] for i in range(0, len(rank_list), expert_ranks)] - dp_rank_list = [rank_list[i::expert_ranks] for i in range(0, expert_ranks)] + if is_replica_in_same_node: + expert_rank_list = [rank_list[i::dp_ranks] for i in range(0, dp_ranks)] + self._dp_rank_list = [rank_list[i:i+dp_ranks] for i in range(0, len(rank_list), dp_ranks)] + else: + expert_rank_list = [rank_list[i:i+expert_ranks] for i in range(0, len(rank_list), expert_ranks)] + self._dp_rank_list = [rank_list[i::expert_ranks] for i in range(0, expert_ranks)] else: expert_rank_list = [rearranged_list[i:i+expert_ranks] for i in range(0, len(rearranged_list), expert_ranks)] - dp_rank_list = [rearranged_list[i::expert_ranks] for i in range(0, expert_ranks)] + self._dp_rank_list = [rearranged_list[i::expert_ranks] for i in range(0, expert_ranks)] # [1] Setup process group for expert parallel distribution local_rank = dist.get_rank() exp_group,_ = self._build_process_group(expert_rank_list, local_rank) if is_expert_slicing: - self._es_rank_list = expert_rank_list + self._es_rank_list = expert_rank_list self._EXPERT_SLICING_GROUP = exp_group else: - self._ep_rank_list = expert_rank_list + self._ep_rank_list = expert_rank_list self._EXPERT_PARALLEL_GROUP = exp_group - + # [2] Setup process group for expert parallel replicas - self._EXPERT_REPLICA_GROUP, self._EXPERT_REPLICA_GROUP_BCAST_SRC_RANK = self._build_process_group(dp_rank_list, local_rank) + self._EXPERT_REPLICA_GROUP, self._EXPERT_REPLICA_GROUP_BCAST_SRC_RANK = self._build_process_group(self._dp_rank_list, local_rank) self._DATA_PARALLEL_GROUP = dist.group.WORLD @@ -303,7 +375,7 @@ def _build_process_group(self, rank_lists, rank): pg = None min_rank = None # Each rank in main group need to go through each new_group() function even if it 
doesn't belong to that group - # https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#new_group + # https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#new_group for r in rank_lists: tmp = dist.new_group(r) if rank in r: @@ -311,26 +383,6 @@ def _build_process_group(self, rank_lists, rank): min_rank = min(r) return pg, min_rank - - ''' - Interface to support Expert Parallelism. - - In this distribution technique, experts in each layer are evenly distributed among available ranks - while non-expert parameters are replicated on each rank. - ''' - def get_expert_parallel_group(self): - return self._EXPERT_PARALLEL_GROUP - - def get_expert_parallel_world_size(self): - if self.get_expert_parallel_group() is not None: - return dist.get_world_size(group=self.get_expert_parallel_group()) - return 1 - - def get_expert_parallel_rank(self): - if self.get_expert_parallel_group() is not None: - return dist.get_rank(group=self.get_expert_parallel_group()) - return 0 - def get_moe_group(self): ''' This group is used by FullyShardedDataParallel (FSDP) to wrap MoE layers. @@ -348,15 +400,6 @@ def get_moe_group(self): self._MOE_GROUP,_ = self._build_process_group(moe_groups, local_rank) return self._MOE_GROUP - ''' - Interface to support Expert Slicing. - - In this distribution technique, experts in each layer are sliced on hidden dimension and - sharded across available ranks. The non-expert parameters are replicated on each rank. - ''' - def get_expert_slicing_group(self): - return self._EXPERT_SLICING_GROUP - def get_expert_slicing_world_size(self): if self.get_expert_slicing_group() is not None: return dist.get_world_size(group=self.get_expert_slicing_group()) @@ -374,22 +417,22 @@ def get_expert_slicing_rank(self): ''' def get_expert_parallel_replica_group(self): return None if self.get_expert_parallel_group() is None else self.get_expert_replica_group() - + def get_expert_parallel_replica_world_size(self): return 1 if self.get_expert_parallel_group() is None else self.get_expert_replica_world_size() - + + def expert_parallel_replica_group_member_rank_lists(self): + return self._dp_rank_list + def get_expert_parallel_replica_rank(self): return 0 if self.get_expert_parallel_group() is None else self.get_expert_replica_rank() - + def get_expert_parallel_replica_src_rank(self): return None if self.get_expert_parallel_group() is None else self.get_expert_replica_src_rank() - + def get_mpi_group_for_expert_parallel_group(self): return None if self.get_expert_parallel_group() is None else self.get_mpi_group_for_expert_group() - def get_expert_replica_group(self): - return self._EXPERT_REPLICA_GROUP - def get_expert_replica_world_size(self): if self.get_expert_replica_group() is not None: return dist.get_world_size(group=self.get_expert_replica_group()) @@ -403,30 +446,7 @@ def get_expert_replica_rank(self): def get_expert_replica_src_rank(self): return self._EXPERT_REPLICA_GROUP_BCAST_SRC_RANK - ''' - Simplified common interface for Expert Parallel and Expert Slicing - ''' - def get_expert_group(self): - if self._EXPERT_PARALLEL_GROUP is None: - return self._EXPERT_SLICING_GROUP - else: - return self._EXPERT_PARALLEL_GROUP - - def get_expert_world_size(self): - if self.get_expert_slicing_group() is not None: - return dist.get_world_size(group=self.get_expert_slicing_group()) - elif self.get_expert_parallel_group() is not None: - return dist.get_world_size(group=self.get_expert_parallel_group()) - return 1 - - def get_expert_rank(self): - if 
self.get_expert_parallel_group() is not None: - return dist.get_rank(group=self.get_expert_parallel_group()) - elif self.get_expert_slicing_group() is not None: - return dist.get_rank(group=self.get_expert_slicing_group()) - return 0 - - # Create MPI group for correspending expert parallel group or expert slicing gorup. + # Create MPI group for correspending expert parallel group or expert slicing gorup. def get_mpi_group_for_expert_group(self): is_expert_slicing = self._es_rank_list is not None if self._MPI_EP_GROUP is not None: @@ -437,7 +457,7 @@ def get_mpi_group_for_expert_group(self): return None expert_group = dist.group.WORLD for g in expert_rank_list: - tmp = MPI.COMM_WORLD.Create_group(MPI.COMM_WORLD.group.Incl(g)) + tmp = MPI.COMM_WORLD.Create_group(MPI.COMM_WORLD.group.Incl(g)) if MPI.COMM_WORLD.Get_rank() in g: expert_group = tmp if is_expert_slicing: @@ -485,7 +505,7 @@ def get_first_pipeline_stage_device(self): def get_last_pipeline_stage_device(self): # FIXME : Return torch device here ? return self._NUM_PIPELINE_STAGES - 1 - + ''' Helper routines to help map experts from one distribution scheme to another. ''' @@ -510,6 +530,11 @@ def map_expert_id_local_to_global(self, total_experts, eparm_id): ep_rank = self.get_expert_parallel_rank() # Calculate global id of the expert gid = eparm_id + int(total_experts/ep_world_size) * ep_rank + + # If gid was relocated then use the original global id + if gid in self._export_relocation_map: + gid = self._export_relocation_map[gid] + if self.get_expert_parallel_replica_group() is None: return gid else: @@ -523,7 +548,7 @@ def map_expert_id_global_to_local(self, total_experts, global_eparm_id): global_eparam_id : Global ID of the expert parameter that will be mapped to a local ID based on the grid's distribution strategy. Return Values: - nrank, nid : Local rank and ID of this global expert parameter will be mapped to + nrank, nid : Local rank and ID of this global expert parameter will be mapped to ''' assert self.get_expert_parallel_group() is not None, "Unsupported expert mapping configuration" assert self.get_expert_slicing_group() is None, "Unsupported expert mapping configuration" @@ -531,17 +556,58 @@ def map_expert_id_global_to_local(self, total_experts, global_eparm_id): assert self.get_num_of_pipeline_stages() == 0, "Unsupported expert mapping configuration" assert global_eparm_id < total_experts, 'Global expert id out of range' + # If global param id was relocated then use the original global id + if global_eparm_id in self._export_relocation_map: + global_eparm_id = self._export_relocation_map[global_eparm_id] + # Map global id to the grid - ep_world_size = self.get_expert_parallel_world_size() + ep_world_size = self.get_expert_parallel_world_size() assert total_experts % ep_world_size == 0, 'Experts can not be evenly divided among ranks' ep_rank_size = int(total_experts / ep_world_size) - nrank = math.floor(global_eparm_id / ep_rank_size) + nrank = math.floor(global_eparm_id / ep_rank_size) nid = global_eparm_id - ep_rank_size * nrank if self.get_expert_parallel_replica_group() is None: - return nrank, nid + return nrank, nid # expert parallel replica result = [(nrank, nid)] for i in range(1, self.get_expert_parallel_replica_world_size()): result += [(nrank + i * self.get_expert_parallel_world_size(), nid)] return result + + def exchange_expert_location(self, global_expert_id1, global_expert_id2): + ''' + Exchange expert location across ranks. This can be used to balance token routing. 
+ Args: + global_expert_id1, global_expert_id2 : Two experts that will be exchanged across ranks. + ''' + if global_expert_id1 == global_expert_id2: + return + self._export_relocation_map[global_expert_id1] = global_expert_id2 + self._export_relocation_map[global_expert_id2] = global_expert_id1 + return + + def remove_expert_relocation(self, id): + ''' + Remove expert relocation entry from the relocation map. + Args: + id : Expert id that will be removed from the relocation map. + ''' + if id in self._export_relocation_map: + id2 = self._export_relocation_map[id] + del self._export_relocation_map[id] + del self._export_relocation_map[id2] + + def get_get_relocation_id(self, id): + ''' + If the expert is relocated then return its relocated expert id. + Args: + id : Expert id whose relocated id will be returned. + ''' + return self._export_relocation_map.get(id, None) + + def get_expert_relocation_map(self): + return self._export_relocation_map + + def set_expert_relocation_map(self, map): + self._export_relocation_map = map \ No newline at end of file diff --git a/ort_moe/layers.py b/ort_moe/ort_moe/layers.py similarity index 77% rename from ort_moe/layers.py rename to ort_moe/ort_moe/layers.py index f778756a..f5fa5a51 100644 --- a/ort_moe/layers.py +++ b/ort_moe/ort_moe/layers.py @@ -12,6 +12,7 @@ from ort_moe.topKgate import TopKGate from ort_moe.moe import MixtureOfExpertsFunc from ort_moe.utils import fsdp_wrap +from ort_moe.moe_config import moe_config class TransformerMoEEncoderLayer(nn.Module): r"""TransformerMoEEncoderLayer is made up of muti headded attention, and gated collection @@ -26,27 +27,23 @@ class TransformerMoEEncoderLayer(nn.Module): nexprts: the number of experts(default=64). balance_ratio: The scaling ratio for the loss_aux gate: the gating function (default=top2k). - fp16_mode : True if FP16 mode is enabled. Default is 'False' expertslist: List of experts of type nn.ModuleList. - merged_expert: Whether the experts are mergedFFN experts distribution_grid: DistributionGrid object providing interface to query torch.distributed process groups - use_fsdp : Use FullyShardedDataParallel to shard the layer. Default is 'False' - flatten_parameters : Flatten sharded paratmers when use_fsdp is True. 
Default is 'True' - apex_opt_level : Default 'None' - + options: See moe_config.py Examples:: >>> moe_layer = nn.TransformerMoEEncoderLayer(d_model=512, nhead=8) >>> src = torch.rand(10, 32, 512) >>> out = moe_layer(src) """ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", - nexperts=64, balance_ratio = 0.01, gate=None, fp16_mode=False, - expertslist=None, merged_expert=True, use_mpi4py=False, distribution_grid=None, - use_fsdp=False, flatten_parameters=True, apex_opt_level=None): + nexperts=64, balance_ratio = 0.01, gate=None, + expertslist=None, distribution_grid=None, + options={}): super(TransformerMoEEncoderLayer, self).__init__() + self.options = moe_config(options) if not gate: #default is top1 - gate = TopKGate(d_model, nexperts, balance_ratio=balance_ratio, fp16_mode = fp16_mode, k = 1, dgrid=distribution_grid) + gate = TopKGate(d_model, nexperts, balance_ratio=balance_ratio, k = 1, dgrid=distribution_grid, options = options) # attention self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) @@ -56,7 +53,7 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation else: self.n_local_experts = nexperts - if not merged_expert: + if not self.options.enable_merged_experts(): if expertslist is None: experts = nn.ModuleList() @@ -71,14 +68,14 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation # mixer of experts self.moe = MixtureOfExpertsFunc(gate, experts, is_encoder=True, - fp16_mode=fp16_mode, use_mpi4py=use_mpi4py, distribution_grid=distribution_grid) + distribution_grid=distribution_grid, options=options) - if use_fsdp is True: + if self.options.enable_fsdp_zero_optimization() is True: mp = False - if apex_opt_level == "O2": + if self.options.nvidia_apex_opt_level() == "O2": mp = True fsdp_config = dict(mixed_precision=mp, process_group=distribution_grid.get_moe_group()) - if flatten_parameters is False: + if self.options.fsdp_flatten_parameters() is False: fsdp_config['flatten_parameters'] = False self.moe = fsdp_wrap(self.moe, **fsdp_config) @@ -120,15 +117,10 @@ class LanguageExpertMoEEncoderLayer(nn.Module): nexprts: the number of experts(default=64). balance_ratio: The scaling ratio for the loss_aux gate: the gating function (default=top2k). - fp16_mode : True if FP16 mode is enabled. Default is 'False' expertslist: List of experts of type nn.ModuleList. nlang_experts: number of language experts - use_mpi4py: Whether use mpi4py library or nccl package distribution_grid: DistributionGrid object providing interface to query torch.distributed process groups - merged_expert: whether the MoE experts is MergedFFNExpert - use_fsdp : Use FullyShardedDataParallel to shard the layer - flatten_parameters : Flatten sharded paratmers when use_fsdp is True. 
Default is 'True' - apex_opt_level : Default 'None' + options: See moe_config.py Examples:: >>> moe_layer = nn.LanguageExpertMoEEncoderLayer(d_model=512, nhead=8, nlang_experts=4) @@ -136,18 +128,18 @@ class LanguageExpertMoEEncoderLayer(nn.Module): >>> out = moe_layer(src) """ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", - nexperts=64, balance_ratio = 0.01, gate=None, fp16_mode=False, - expertslist=None, nlang_experts=4, use_mpi4py=False, distribution_grid=None, merged_expert=True, - use_fsdp=False, flatten_parameters=True, apex_opt_level=None): + nexperts=64, balance_ratio = 0.01, gate=None, + expertslist=None, nlang_experts=4, distribution_grid=None, + options={}): super(LanguageExpertMoEEncoderLayer, self).__init__() self.experts = nn.ModuleDict() for i in range(nlang_experts): le = TransformerMoEEncoderLayer(d_model, nhead, dim_feedforward, dropout, nexperts = nexperts, balance_ratio = balance_ratio, - gate = gate, fp16_mode = fp16_mode, - expertslist = expertslist, use_mpi4py=use_mpi4py, + gate = gate, + expertslist = expertslist, distribution_grid=distribution_grid, - merged_expert=merged_expert, use_fsdp=use_fsdp + options = options ) self.experts[f"seq2seq{i}"] = le @@ -168,14 +160,9 @@ class TransformerMoEDecoderLayer(nn.Module): nexprts: the number of experts(default=64). balance_ratio: The scaling ratio for the loss_aux gate: the gating function (default=None. If none then top2 gating is used). - fp16_mode : True if FP16 mode is enabled. Default is 'False' expertslist: List of experts of type nn.ModuleList. - merged_expert: whether the MoE experts is MergedFFNExpert - use_mpi4py: Whether use mpi4py library or nccl package distribution_grid: DistributionGrid object providing interface to query torch.distributed process groups - use_fsdp : Use FullyShardedDataParallel to shard the layer - flatten_parameters : Flatten sharded paratmers when use_fsdp is True. 
Default is 'True' - apex_opt_level : Default 'None' + options: See moe_config.py Examples:: >>> moe_layer = nn.TransformerMoEDecoderLayer(d_model=512, nhead=8) @@ -184,10 +171,11 @@ class TransformerMoEDecoderLayer(nn.Module): >>> out = moe_layer(src, memory) """ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", - nexperts=64, balance_ratio = 0.01, gate=None, fp16_mode=False, - expertslist=None, merged_expert = True, use_mpi4py=False, distribution_grid=None, - use_fsdp=False, flatten_parameters=True, apex_opt_level=None): + nexperts=64, balance_ratio = 0.01, gate=None, + expertslist=None, distribution_grid=None, + options = {}): super(TransformerMoEDecoderLayer, self).__init__() + self.options = moe_config(options) # attention self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) @@ -199,7 +187,7 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation else: self.n_local_experts = nexperts - if not merged_expert: + if not self.options.enable_merged_experts(): if expertslist is None: experts = nn.ModuleList() for i in range(self.n_local_experts): @@ -213,18 +201,19 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation # gate if not gate: #default is top2 - gate = TopKGate(d_model, nexperts, balance_ratio=balance_ratio, fp16_mode = fp16_mode, dgrid=distribution_grid) + gate = TopKGate(d_model, nexperts, balance_ratio=balance_ratio, dgrid=distribution_grid, options = options) # mixer of experts self.moe = MixtureOfExpertsFunc(gate, experts, is_encoder=False, - fp16_mode = fp16_mode, use_mpi4py = use_mpi4py, distribution_grid = distribution_grid) + distribution_grid = distribution_grid, + options = options) - if use_fsdp is True: + if self.options.enable_fsdp_zero_optimization() is True: mp = False - if apex_opt_level == "O2": + if self.options.nvidia_apex_opt_level() == "O2": mp = True fsdp_config = dict(mixed_precision=mp, process_group=distribution_grid.get_moe_group()) - if flatten_parameters is False: + if self.options.fsdp_flatten_parameters() is False: fsdp_config['flatten_parameters'] = False self.moe = fsdp_wrap(self.moe, **fsdp_config) @@ -276,15 +265,10 @@ class LanguageExpertMoEDecoderLayer(nn.Module): nexprts: the number of experts(default=64). balance_ratio: The scaling ratio for the loss_aux gate: the gating function (default=top2k). - fp16_mode : True if FP16 mode is enabled. Default is 'False' expertslist: List of experts of type nn.ModuleList nlang_experts: number of language experts - use_mpi4py: Whether use mpi4py library or nccl package distribution_grid: DistributionGrid object providing interface to query torch.distributed process groups - merged_expert: whether the MoE experts is MergedFFNExpert - use_fsdp : Use FullyShardedDataParallel to shard the layer - flatten_parameters : Flatten sharded paratmers when use_fsdp is True. 
Default is 'True' - apex_opt_level : Default 'None' + options: See moe_config.py Examples:: @@ -293,17 +277,18 @@ class LanguageExpertMoEDecoderLayer(nn.Module): >>> out = moe_layer(src) """ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", - nexperts=64, balance_ratio = 0.01, gate=None, fp16_mode=False, - expertslist=None, nlang_experts=4, use_mpi4py=False, distribution_grid=None, merged_expert=True, - use_fsdp=False, flatten_parameters=True, apex_opt_level=None): + nexperts=64, balance_ratio = 0.01, gate=None, + expertslist=None, nlang_experts=4, distribution_grid=None, + options = {}): super(LanguageExpertMoEDecoderLayer, self).__init__() self.experts = nn.ModuleDict() for i in range(nlang_experts): le = TransformerMoEDecoderLayer(d_model, nhead, dim_feedforward, dropout, nexperts = nexperts, balance_ratio = balance_ratio, - gate = gate, fp16_mode = fp16_mode, - expertslist = expertslist, use_mpi4py=use_mpi4py, - distribution_grid=distribution_grid, merged_expert=merged_expert, use_fsdp=use_fsdp) + gate = gate, + expertslist = expertslist, + distribution_grid=distribution_grid, + options = options) self.experts[f"seq2seq{i}"] = le def forward(self, tgt, memory, lang_id=None): diff --git a/ort_moe/loss_functions.py b/ort_moe/ort_moe/loss_functions.py similarity index 100% rename from ort_moe/loss_functions.py rename to ort_moe/ort_moe/loss_functions.py diff --git a/ort_moe/moe.py b/ort_moe/ort_moe/moe.py similarity index 77% rename from ort_moe/moe.py rename to ort_moe/ort_moe/moe.py index 8f285d8a..7be16616 100644 --- a/ort_moe/moe.py +++ b/ort_moe/ort_moe/moe.py @@ -13,17 +13,18 @@ from typing import Any +import math import torch from torch import Tensor import torch.nn as nn import torch.distributed as dist +import torch.utils.checkpoint as checkpoint from . import experts -import math +from .moe_config import moe_config from .collectives import AllGather, AllToAll, AllReduce -from .custom_ops import einsum +from .custom_ops import om_einsum -def MixtureOfExpertsFunc(gate, experts_, distribution_grid, is_encoder = True, - fp16_mode = False, use_mpi4py = True): +def MixtureOfExpertsFunc(gate, experts_, distribution_grid, is_encoder = True, options = {}): r"""A factory function to call different MoE classes Args: gate: the gating function (required). @@ -31,15 +32,12 @@ def MixtureOfExpertsFunc(gate, experts_, distribution_grid, is_encoder = True, distribution_grid: DistributionGrid object providing interface to query torch.distributed process groups. It is a required keyword to remove the confusion for the usage of dgrid, for single GPU, instantiate an empty dgrid() is_encoder: Whether this MOE is in encoder layer. If false, it is decoder layer - fp16_mode: Whether the input to experts should be in fp16. If this is true, the input to alltoall is cast to fp16. - # NOTE: If this is false, the input to experts may still be casted to fp16 based on AMP setting, but input to alltoall is not casted - use_mpi4py: Use CPU MPI library or GPU NCCL MPI library - # NOTE: if set use_mpi4py to false, it introduces extra dtoh copy (hence GPU sync point). DON'T turn it off unless mpi4py is not available. 
+ options: See moe_config.py """ if distribution_grid.get_expert_slicing_group() is not None: - return MixtureOfExpertsES(gate, experts_, distribution_grid, is_encoder, fp16_mode, use_mpi4py) + return MixtureOfExpertsES(gate, experts_, distribution_grid, is_encoder, options) else: - return MixtureOfExpertsEP(gate, experts_, distribution_grid, is_encoder, fp16_mode, use_mpi4py) + return MixtureOfExpertsEP(gate, experts_, distribution_grid, is_encoder, options) class MixtureOfExperts(nn.Module): r"""MixtureOfExperts module implements mixture of experts. @@ -49,10 +47,7 @@ class MixtureOfExperts(nn.Module): distribution_grid: DistributionGrid object providing interface to query torch.distributed process groups. It is a required keyword to remove the confusion for the usage of dgrid, for single GPU, instantiate an empty dgrid() is_encoder: Whether this MOE is in encoder layer. If false, it is decoder layer - fp16_mode: Whether the input to experts should be in fp16. If this is true, the input to alltoall is cast to fp16. - # NOTE: If this is false, the input to experts may still be casted to fp16 based on AMP setting, but input to alltoall is not casted - use_mpi4py: Use CPU MPI library or GPU NCCL MPI library - # NOTE: if set use_mpi4py to false, it introduces extra dtoh copy (hence GPU sync point). DON'T turn it off unless mpi4py is not available. + options: See moe_config.py """ #max_len dictionary for encoder and decoder. #Need two dicts because the max_len for encoder and decoder may be different @@ -77,32 +72,34 @@ def reset_moe_state(cls): cls.reset_moe_encoder_state() cls.reset_moe_decoder_state() - def __init__(self, gate, experts_, distribution_grid, is_encoder = True, fp16_mode = False, use_mpi4py = True): + def __init__(self, gate, experts_, distribution_grid, is_encoder = True, options = {}): super(MixtureOfExperts, self).__init__() assert distribution_grid != None + self._options = moe_config(options) self.is_mergedFFNExpert = isinstance(experts_, experts.MergedFFNExpert) self.gate = gate self.moe_experts = experts_ self.num_experts = self.moe_experts.local_num_experts if self.is_mergedFFNExpert else len(experts_) self.is_encoder = is_encoder - self.use_mpi4py = use_mpi4py self.expert_rank = distribution_grid.get_expert_rank() self.expert_group = distribution_grid.get_expert_group() self.expert_group_size = distribution_grid.get_expert_world_size() - if self.use_mpi4py: + if self._options.use_mpi_for_imbalanced_input(): self.mpi_expert_group = distribution_grid.get_mpi_group_for_expert_group() #tag the is_moe_param for the experts, later in the application people can extract expert specific parameters if needed for p in self.moe_experts.parameters(): p.is_moe_param = True + if self._options.enable_deepspeed_zero_optimization(): + # Allreduce and group_name attributes are for DS Zero 1,2,3 + p.allreduce = False + p.group_name = f"ep_size_{self.expert_group_size}" for p in self.gate.parameters(): p.is_gate_param = True - self.fp16_mode = fp16_mode - def get_max_len(self, tensor_len, max_len, device): r"""To obtain the maximum length of the tensor's specific dimension, store it in max_len dictionary. This is later used in alltoall to pad all tensors to the maximum length. 
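
As a quick orientation for reviewers: the constructor keywords removed in this patch (fp16_mode, merged_expert, use_fsdp, use_mpi4py, apex_opt_level, flatten_parameters) now travel in the single `options` dict consumed by `moe_config`. The sketch below is not part of the patch; it assumes a single process with an empty `DistributionGrid` (as the docstring suggests for single-GPU use), and the sizes and option values are illustrative only.

```python
# Hypothetical single-process usage sketch; sizes and option values are illustrative.
import torch
from ort_moe.grids import DistributionGrid
from ort_moe.layers import TransformerMoEEncoderLayer

dgrid = DistributionGrid()  # empty grid, as suggested for single-GPU runs

options = {
    "merged_experts": True,                          # replaces merged_expert=True
    "fp16_mode": False,                              # replaces fp16_mode=False
    "fsdp_zero_optimization": {"stage": 0},          # replaces use_fsdp / flatten_parameters
    "imbalanced_input_support": {"use_mpi": False},  # replaces use_mpi4py
}

moe_layer = TransformerMoEEncoderLayer(d_model=512, nhead=8, nexperts=4,
                                       distribution_grid=dgrid, options=options)
src = torch.rand(10, 32, 512)
out = moe_layer(src)  # same call pattern as the class docstring example
```
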
@@ -117,7 +114,7 @@ def get_max_len(self, tensor_len, max_len, device): max_len["need_update"] = False else: max_len_tensor = tensor_len - if self.use_mpi4py: + if self._options.use_mpi_for_imbalanced_input(): from mpi4py import MPI max_len_tensor = self.mpi_expert_group.allreduce(max_len_tensor, MPI.MAX) max_len["max_len"] = max_len_tensor @@ -133,7 +130,7 @@ def forward(self, input): MixtureOfExpertsEP, instead, call the factory function MixtureOfExpertFunc") class MixtureOfExpertsES(MixtureOfExperts): - def __init__(self, gate, experts_, distribution_grid, is_encoder = True, fp16_mode = False, use_mpi4py = True): + def __init__(self, gate, experts_, distribution_grid, is_encoder = True, options = {}): r"""MixtureOfExpertsES module implements mixture of experts with expert slicing Args: gate: the gating function (required). @@ -141,12 +138,9 @@ def __init__(self, gate, experts_, distribution_grid, is_encoder = True, fp16_mo distribution_grid: DistributionGrid object providing interface to query torch.distributed process groups. It is a required keyword to remove the confusion for the usage of dgrid, for single GPU, instantiate an empty dgrid() is_encoder: Whether this MOE is in encoder layer. If false, it is decoder layer - fp16_mode: Whether the input to experts should be in fp16. If this is true, the input to alltoall is cast to fp16. - # NOTE: If this is false, the input to experts may still be casted to fp16 based on AMP setting, but input to alltoall is not casted - use_mpi4py: Use CPU MPI library or GPU NCCL MPI library - # NOTE: if set use_mpi4py to false, it introduces extra dtoh copy (hence GPU sync point). DON'T turn it off unless mpi4py is not available. + options: See moe_config.py """ - MixtureOfExperts.__init__(self, gate, experts_, distribution_grid, is_encoder, fp16_mode, use_mpi4py) + MixtureOfExperts.__init__(self, gate, experts_, distribution_grid, is_encoder, options) def forward(self, input, **kwargs) -> Tensor: assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel" @@ -178,7 +172,7 @@ def forward(self, input, **kwargs) -> Tensor: dispatched_input = reshaped_input[dispatch_mask % reshaped_input.shape[0]] #[sum*(E*C), M] - if self.fp16_mode: + if self._options.fp16_mode(): dispatched_input = dispatched_input.to(torch.float16) expert_outputs = [] @@ -212,15 +206,14 @@ def forward(self, input, **kwargs) -> Tensor: rerouted_output = rerouted_output.reshape(self.gate.k, reshaped_input.shape[0], reshaped_input.shape[1]) #reshaped to [K, S, M] #in general, combined_output = sum_i(combined[:, i]*rerouted_output[i,:,:]) - combined_output = einsum("ks,ksm->sm", combine_weights, rerouted_output.to(combine_weights)) + combined_output = om_einsum("ks,ksm->sm", combine_weights, rerouted_output.to(combine_weights)) combined_output = combined_output.reshape(self.expert_group_size, -1, d_model) local_combined_output = torch.narrow(combined_output[self.expert_rank], dim = 0, start = 0, length = c_cpu) return local_combined_output.reshape(input.shape).to(input) class MixtureOfExpertsEP(MixtureOfExperts): - def __init__(self, gate, experts_, is_encoder = True, fp16_mode = False, use_mpi4py = True, - distribution_grid=None): + def __init__(self, gate, experts_, distribution_grid = None, is_encoder = True, options = {}): r"""MixtureOfExpertsEP module implements mixture of experts with expert parallelsim Args: gate: the gating function (required). 
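
The next hunk is where the new forward-path options take effect (expert checkpointing, the expert-weight-calculation optimization that skips padded capacity slots, and BASE-layer shuffling). The untested sketch below shows how the factory function wires a gate and an explicit expert list together with those options; dimensions, option values, and the empty grid are assumptions for illustration.

```python
# Illustrative wiring of MixtureOfExpertsFunc; dimensions and option values are examples only.
import torch.nn as nn
from ort_moe.grids import DistributionGrid
from ort_moe.topKgate import TopKGate
from ort_moe.experts import FFNExpert
from ort_moe.moe import MixtureOfExpertsFunc

d_model, d_ff, nexperts = 64, 256, 4
dgrid = DistributionGrid()  # no expert slicing group, so the EP variant is returned

options = {
    "merged_experts": False,                    # use an explicit nn.ModuleList of FFNExpert
    "checkpoint_experts": False,                # wrap each expert in torch checkpointing when True
    "enable_expert_weight_calculation_optimization": True,  # skip padded, unused capacity slots
    "enable_base_layer_shuffling": False,       # BASE-layer style shuffling (arXiv:2103.16716)
    "imbalanced_input_support": {"use_mpi": False},  # avoid the mpi4py dependency in this sketch
}

gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid, options=options)
experts_ = nn.ModuleList(FFNExpert(d_model, d_ff, dgrid) for _ in range(nexperts))
moe = MixtureOfExpertsFunc(gate, experts_, distribution_grid=dgrid, options=options)
```
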
@@ -228,12 +221,8 @@ def __init__(self, gate, experts_, is_encoder = True, fp16_mode = False, use_mpi distribution_grid: DistributionGrid object providing interface to query torch.distributed process groups. It is a required keyword to remove the confusion for the usage of dgrid, for single GPU, instantiate an empty dgrid() is_encoder: Whether this MOE is in encoder layer. If false, it is decoder layer - fp16_mode: Whether the input to experts should be in fp16. If this is true, the input to alltoall is cast to fp16. - # NOTE: If this is false, the input to experts may still be casted to fp16 based on AMP setting, but input to alltoall is not casted - use_mpi4py: Use CPU MPI library or GPU NCCL MPI library - # NOTE: if set use_mpi4py to false, it introduces extra dtoh copy (hence GPU sync point). DON'T turn it off unless mpi4py is not available. """ - MixtureOfExperts.__init__(self, gate, experts_, is_encoder, fp16_mode, use_mpi4py, distribution_grid) + MixtureOfExperts.__init__(self, gate, experts_, distribution_grid, is_encoder, options) def forward(self, input:Tensor, **kwargs:Any) -> Tensor: assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel" @@ -242,6 +231,22 @@ def forward(self, input:Tensor, **kwargs:Any) -> Tensor: d_model = input.shape[2] # Reshape into S tokens by dropping sequence dimension. reshaped_input = input.reshape(-1, d_model) + + if kwargs.get('shuffle_group', None) is not None and self._options.enable_base_layer_shuffling(): + shuffle_group = kwargs['shuffle_group'] + else: + if kwargs.get('shuffle_group', None) is not None: + print("WARNING: When Base Layer Shuffling is False, shuffle_group should not be set.") + shuffle_group = self.expert_group + + if self._options.enable_base_layer_shuffling(): + S = reshaped_input.shape[0] + perm = torch.randperm(S) + reshaped_input = reshaped_input[perm] + reshaped_input = reshaped_input.reshape(self.expert_group_size, -1, d_model) # e x s/e x m + reshaped_input = AllToAll.apply(shuffle_group, reshaped_input, reshaped_input.shape[-2]) # e x s/e x m + reshaped_input = reshaped_input.reshape(-1, d_model) # s x m + # Nonpadding masks of the input tensor with original shape [s, t]. # In top1gating, only nonpadding tokens are dispatched to experts. # lid is the layer id for warning message. Default value -1 will not trigger the warning. @@ -251,7 +256,7 @@ def forward(self, input:Tensor, **kwargs:Any) -> Tensor: # index_select() is used to replace advance indexing because dispatch_mask has duplicate indices. 
# Backward pass is slow when advance indexing contains duplicate indices - if self.fp16_mode and dispatched_input.dtype is not torch.float16: + if self._options.fp16_mode() and dispatched_input.dtype is not torch.float16: dispatched_input = dispatched_input.to(torch.float16) c_cpu = dispatched_input.shape[1] if self.expert_group_size > 1: @@ -261,21 +266,35 @@ def forward(self, input:Tensor, **kwargs:Any) -> Tensor: #add the allreduce to get the max_len self.get_max_len(c_cpu, max_len, dispatched_input.get_device()) assert max_len["max_len"] >= c_cpu - dispatched_input = AllToAll.apply(self.expert_group, dispatched_input, max_len['max_len']) + dispatched_input = AllToAll.apply(shuffle_group, dispatched_input, max_len['max_len']) # Re-shape after all-to-all: ecm -> gecm dispatched_input = dispatched_input.reshape(self.expert_group_size, self.num_experts, -1, d_model) if not self.is_mergedFFNExpert: chunks = dispatched_input.chunk(self.num_experts, dim=1) expert_outputs = [] - for chunk, expert in zip(chunks, self.moe_experts): - expert_outputs += [expert(chunk)] + for chunk, expert, e in zip(chunks, self.moe_experts, range(self.num_experts)): + if self._options.checkpoint_experts() is True: + expert_outputs += [checkpoint.checkpoint(expert, chunk)] + elif self._options.enable_expert_weight_calculation_optimization() is True: + # Only process input chunk for the selected experts. The input chunk is padded upto the capacity factor + z = torch.min(dispatch_mask[e], dim=0) + truncated_chunk = None + if z.values == -1: + truncated_chunk = chunk[...,:z.indices,:] + eo = expert(truncated_chunk) + pad_size = math.ceil(capacity_fp) - truncated_chunk.size(dim=2) + expert_outputs +=[torch.nn.functional.pad(eo, (0,0,pad_size,0), 'constant', 0.0)] + else: + expert_outputs += [expert(chunk)] + else: + expert_outputs += [expert(chunk)] expert_output = torch.cat(expert_outputs, dim=1) else: expert_output = self.moe_experts(dispatched_input) if self.expert_group_size > 1: - expert_output = AllToAll.apply(self.expert_group, expert_output, max_len['max_len']) + expert_output = AllToAll.apply(shuffle_group, expert_output, max_len['max_len']) expert_output = torch.narrow(expert_output, dim = 2, start=0, length = c_cpu) # Re-shape back: gecm -> ecm expert_output = expert_output.reshape(self.expert_group_size * self.num_experts, -1, d_model) @@ -286,6 +305,16 @@ def forward(self, input:Tensor, **kwargs:Any) -> Tensor: rerouted_output = rerouted_output.reshape(self.gate.k, reshaped_input.shape[0], reshaped_input.shape[1]) #reshaped to [K, S, M] #in general, combined_output = sum_i(combined[:, i]*rerouted_output[i,:,:]) - combined_output = einsum("ks,ksm->sm", combine_weights, rerouted_output.to(combine_weights)) + combined_output = om_einsum("ks,ksm->sm", combine_weights, rerouted_output.to(combine_weights)) + + if self._options.enable_base_layer_shuffling(): + combined_output = combined_output.reshape(self.expert_group_size, -1, d_model) # e x s/e x m + combined_output = AllToAll.apply(shuffle_group, combined_output, combined_output.shape[-2]) # e x s/e x m + combined_output = combined_output.reshape(-1, d_model) # s x m + inverse_shuffle_output = torch.empty_like(combined_output) + inverse_shuffle_output[perm] = combined_output + combined_output = inverse_shuffle_output.reshape(input.shape).to(input) + else: + combined_output = combined_output.reshape(input.shape).to(input) return combined_output.reshape(input.shape).to(input) diff --git a/ort_moe/ort_moe/moe_config.py b/ort_moe/ort_moe/moe_config.py new file 
mode 100644 index 00000000..c00b1384 --- /dev/null +++ b/ort_moe/ort_moe/moe_config.py @@ -0,0 +1,171 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" + ORT MoE Configuration Options + ----------------------------- + { + ## Enable Verbose output + "verbose" : False|True, + + ## Options to configure DeepSpeed ZeRO optimizer + "deepspeed_zero_optimization": { + "stage" : 0|1|2|3 + }, + + ## Options to configure PyTorch(FSDP) ZeRO optimizer + "fsdp_zero_optimization": { + "stage" : 0|1|2|3, + "flatten_parameters" : True|False + }, + + ## merged_experts -- Enhance performance by using batched GEMM + "merged_experts": True|False, + + ## Option to select NVIDIA Apex Optimization level + "nvidia_apex_opt_level" : "O0"|"01"|"O2"|"O3", + + ## Option to enable FP16 processing inside an MoE Expert + "fp16_mode" : True|False, + + ## Options to enable support for imbalanced inputs. This feature + ## supports distributed training scenario where each rank could + ## have distinct input size (for example sequence length) that + ## could interfere with communication collectives that expect same + ## input tensor size on each rank. The support is enabled by default. + "imbalanced_input_support" : { + "enabled" : True|False, + + ## Additional all-reduce is used to handle imbalanced inputs. + ## Using MPI for this all-reduce elminates extra DtoH copy so + ## this is enabled by default. + "use_mpi" : True|False + }, + + ## Option to control checkpointing of experts using torch.checkpointing + ## API. Default is False. + "checkpoint_experts" : False|True + + ## Option to enable dynamic capacity feature to dynamically adjust capacity + ## factor during Expert Parallel mode. By default this feature is disabled. + "enable_dynamic_capacity" : False|True + + ## Option to enable Basa Layer Shuffling. Paper link: https://arxiv.org/abs/2103.16716 + "enable_base_layer_shuffling" : False|True + + ## Option to enable Tutel cumsum optimization + "enable_tutel_cumsum" : False|True + + ## Option to enable expert weight calculation optimization + "enable_expert_weight_calculation_optimization" : False|True + } +""" +class moe_config: + def __init__(self, options): + self._options = {} + if options is not None: + self._options = options + + def enable_verbose(self): + r"""enable_verbose + Returns true if verbose mode is on + """ + return self._options.get("verbose", False) + + def enable_deepspeed_zero_optimization(self): + r"""enable_deepspeed_zero_optimization: + Returns true if DeepSpeed ZeRO stage 1, 2 or 3 is selected. + """ + ds = self._options.get("deepspeed_zero_optimization", {}) + stage = ds.get("stage", 0) + if stage > 0: + return True + return False + + def enable_fsdp_zero_optimization(self): + r"""enable_fsdp_zero_optimization: + Returns true if FSDP ZeRO stage 1, 2 or 3 is selected. + """ + ds = self._options.get("fsdp_zero_optimization", {}) + stage = ds.get("stage", 0) + if stage > 0: + return True + return False + + def fsdp_flatten_parameters(self): + r"""fsdp_flatten_parameters: + Returns true if flatten parameters optimization is enabled in the FSDP. 
+ """ + ds = self._options.get("fsdp_zero_optimization", {}) + stage = ds.get("flatten_parameters", True) + if stage > 0: + return True + return False + + def enable_merged_experts(self): + r"""enable_merged_experts: + Returns true if merged_experts optimization is enable. This optimization + used batched gemm to improve performance. + """ + return self._options.get("merged_experts", True) + + def nvidia_apex_opt_level(self): + r"""nvidia_apex_opt_level: + Return selected Nvidia Apex opt level + """ + return self._options.get("nvidia_apex_opt_level", None) + + def fp16_mode(self): + r"""fp16_mode + Return true if fp16_mode is enabled. In this mode Expert computations are + done in fp16 + """ + return self._options.get("fp16_mode", False) + + def support_imbalanced_input(self): + r"""support_imbalanced_input + Return true if support for imbalanced input is enabled + """ + ds = self._options.get("imbalanced_input_support", {}) + return ds.get("enabled", True) + + def use_mpi_for_imbalanced_input(self): + r"""use_mpi_for_imbalanced_input + Return true if use of MPI is enabled to support imbalanced inputs. + """ + ds = self._options.get("imbalanced_input_support", {}) + if ds.get("enabled", True) is True: + return ds.get("use_mpi", True) + return False + + def checkpoint_experts(self): + r"""checkpoint_experts + Return true if experts should be checkpointed using torch API. + """ + return self._options.get("checkpoint_experts", False) + + def enable_dynamic_capacity(self): + r"""enable_dynamic_capacity + Return true if capacity factor should be dynamically adjusted + """ + return self._options.get("enable_dynamic_capacity", False) + + def enable_base_layer_shuffling(self): + r"""enable_base_layer_shuffling + Return true if Base Layer Shuffling is enabled. 
+ """ + return self._options.get("enable_base_layer_shuffling", False) + + def enable_tutel_cumsum(self): + r"""enable_tutel_cumsum + Return true if Tutel cumsum kernel is enabled + """ + return self._options.get("enable_tutel_cumsum", False) + + def enable_expert_weight_calculation_optimization(self): + r"""enable_expert_weight_calculation_optimization + Returntrue if expert weight calculation optimization is enabled + """ + return self._options.get("enable_expert_weight_calculation_optimization", False) diff --git a/ort_moe/topKgate.py b/ort_moe/ort_moe/topKgate.py similarity index 80% rename from ort_moe/topKgate.py rename to ort_moe/ort_moe/topKgate.py index 6a874c9f..37a24e2b 100644 --- a/ort_moe/topKgate.py +++ b/ort_moe/ort_moe/topKgate.py @@ -23,7 +23,9 @@ from .loss_functions import loss_functions from .gate_logs import gate_logs -from .custom_ops import einsum +from .custom_ops import om_einsum, om_cumsum +from torch.cuda.amp import autocast +from .moe_config import moe_config as moe_config uniform_map: Dict[torch.device, Callable] = {} gumbel_map: Dict[torch.device, Callable] = {} @@ -65,13 +67,19 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: gumbel_map[device] = gumbel return gumbel(shape) -def top2gating(logits: torch.Tensor, capacity_factor: float, fp16_mode: bool=False, nonpadding: torch.Tensor=None, +def top2gating(logits: torch.Tensor, capacity_factor: float, nonpadding: torch.Tensor=None, logits_gumbel: float=0.0, token_drop_type: str='cut', second_place_loss_ratio: float=0.0, straight_through: bool=False, straight_through_temperature: float=1.0, balance_ratio={'load_balance': 0.01}, gate_log_req: dict={}, lid: int=-1, - tutel_cumsum_sub_one: callable=None)-> Tuple[Tensor, Tensor, Tensor, Tensor, float]: + options={})-> Tuple[Tensor, Tensor, Tensor, Tensor, float]: """Implements Top2Gating on logits.""" - if fp16_mode is True: + moe_options = None + if isinstance(options, moe_config): + moe_options = options + else: + moe_options = moe_config(options) + + if moe_options.fp16_mode() is True: logits = logits.to(torch.float32) gates = F.softmax(logits, dim=1) #dim: [bs, num_experts] if straight_through: @@ -95,10 +103,10 @@ def top2gating(logits: torch.Tensor, capacity_factor: float, fp16_mode: bool=Fal if straight_through: gates_st_nonpadding = gates_st if nonpadding is not None: - mask1 = einsum("s,se->se", nonpadding, mask1) - gates_nonpadding = einsum("s,se->se", nonpadding, gates_nonpadding) + mask1 = om_einsum("s,se->se", nonpadding, mask1) + gates_nonpadding = om_einsum("s,se->se", nonpadding, gates_nonpadding) if straight_through: - gates_st_nonpadding = einsum("s,se->se", nonpadding, gates_st_nonpadding) if straight_through_temperature != 1.0 else gates_nonpadding + gates_st_nonpadding = om_einsum("s,se->se", nonpadding, gates_st_nonpadding) if straight_through_temperature != 1.0 else gates_nonpadding gates_st_equals_gates = True if logits_gumbel > 0: @@ -123,10 +131,10 @@ def top2gating(logits: torch.Tensor, capacity_factor: float, fp16_mode: bool=Fal if straight_through: gates_st_without1_nonpadding = gates_st_without1 if nonpadding is not None: - mask2 = einsum("s,se->se", nonpadding, mask2) - gates_without1_nonpadding = einsum("s,se->se", nonpadding, gates_without1) + mask2 = om_einsum("s,se->se", nonpadding, mask2) + gates_without1_nonpadding = om_einsum("s,se->se", nonpadding, gates_without1) if straight_through: - gates_st_without1_nonpadding = einsum("s,se->se", nonpadding, gates_st_without1) if not gates_st_equals_gates else 
gates_without1_nonpadding + gates_st_without1_nonpadding = om_einsum("s,se->se", nonpadding, gates_st_without1) if not gates_st_equals_gates else gates_without1_nonpadding # Compute l_aux # the fraction of the router probability allocated for each expert @@ -160,12 +168,9 @@ def top2gating(logits: torch.Tensor, capacity_factor: float, fp16_mode: bool=Fal mask2 *= priority_mask # Compute locations in capacity buffer - if mask1.device.type == 'cpu' or tutel_cumsum_sub_one is None: - locations1 = torch.cumsum(mask1, dim=0) - 1 - locations2 = torch.cumsum(mask2, dim=0) - 1 - else: - locations1 = tutel_cumsum_sub_one(mask1, dim=0) - locations2 = tutel_cumsum_sub_one(mask2, dim=0) + locations1 = om_cumsum(mask1, dim=0, options=moe_options) + locations2 = om_cumsum(mask2, dim=0, options=moe_options) + # Update 2nd's location by accounting for locations of 1st locations2 += torch.sum(mask1, dim=0, keepdim=True) if token_drop_type == 'cut': @@ -218,7 +223,7 @@ def top2gating(logits: torch.Tensor, capacity_factor: float, fp16_mode: bool=Fal dispatch_mask[dispatch_indices2] = indices + num_tokens #indice + num_tokens* kth top, to make sure each element in the dispatch mask is unique dispatch_mask = dispatch_mask[0:-1].reshape(num_experts, -1) #discard the fake tokens - if fp16_mode is True: + if moe_options.fp16_mode() is True: gates12_s.to(torch.float16) return loss, gate_log, gates12_s, dispatch_mask, capacity_fp @@ -240,9 +245,6 @@ class TopKGate(torch.nn.Module): a scalar variable to control the cacacity of each expert in evaluation: capacity = num_tokens / number_of_experts * capacity factor k (int): TopK gating function. Currently only supports k = 1 or k = 2 - fp16_mode (bool): - a boolean variable to control whether fp16_mode is used in moe layer (e.g., by turning on AMP), - if so, we cast the inputs and weights in gating function to fp32 for model quality requirement switch_jitter (float): a small variable to controls the multiplicative jitter to the gate input: x *= uniform(1-epsilon, 1+epsilon) only applicable for top1gating @@ -260,8 +262,6 @@ class TopKGate(torch.nn.Module): whether to use Straight Through method to make the load_balance loss fully differentiable straight_through_temperature (float): temperature of softmax for straight_through - use_tutel_cumsum_sub_one (callable): - whether to use fast_cumsum_sub_one from tutel or not """ def __init__(self, model_dim: int, @@ -272,7 +272,6 @@ def __init__(self, capacity_factor: float=1.0, eval_capacity_factor: float=1.0, k: int=2, - fp16_mode: bool=False, switch_jitter: float=0.0, switch_dropout: float=0.0, logits_gumbel: float=0.0, @@ -281,9 +280,10 @@ def __init__(self, second_place_loss_ratio: float=0.0, straight_through: bool=False, straight_through_temperature: float=1.0, - use_tutel_cumsum_sub_one: bool=True, + options: dict={} ) -> None: super().__init__() + self.options = moe_config(options) self.is_expert_slicing = dgrid.get_expert_slicing_group() is not None self.dgrid = dgrid self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) @@ -292,7 +292,6 @@ def __init__(self, self.capacity_factor = capacity_factor self.eval_capacity_factor = eval_capacity_factor self.k = k - self.fp16_mode = fp16_mode self.switch_jitter = switch_jitter self.switch_dropout = switch_dropout self.logits_gumbel = logits_gumbel @@ -307,15 +306,9 @@ def __init__(self, self.second_place_loss_ratio = second_place_loss_ratio self.straight_through = straight_through self.straight_through_temperature = straight_through_temperature + assert k == 1 or (k > 
1 and not self.options.enable_dynamic_capacity()), "dynamic_capacity is only supported for k = 1" self.loss = None self.gate_log = None - if use_tutel_cumsum_sub_one: - # from https://github.com/microsoft/tutel (commit e51df1ca64be59eae3691bc1c64b20a201de1009) - # Please 'run pip install -r ./ requirements.txt' to install tutel - from tutel.jit_kernels.gating import fast_cumsum_sub_one - self.tutel_cumsum_sub_one = fast_cumsum_sub_one - else: - self.tutel_cumsum_sub_one = None def forward(self, input: torch.Tensor, nonpadding: torch.Tensor = None, lid: int=-1) -> Tuple[Tensor, Tensor, float]: # type: ignore """ @@ -330,49 +323,55 @@ def forward(self, input: torch.Tensor, nonpadding: torch.Tensor = None, lid: int """ assert self.k ==1 or self.k == 2, "k can only be 1 or 2" - if self.fp16_mode is True: - input = input.to(torch.float32) - self.wg = self.wg.to(torch.float32) - if self.training and self.k == 1: - if self.switch_jitter > 0: - input = multiplicative_jitter(input, device=input.device, epsilon=self.switch_jitter) - elif self.switch_dropout > 0: - input = F.dropout(input, p=self.switch_dropout, training=self.training) - logits = self.wg(input) #dim: [bxs, num_experts] - if self.k == 1: - self.loss, self.gate_log, gates1_s, dispatch_mask, retval = top1gating( - logits, - self.capacity_factor if self.training else self.eval_capacity_factor, - is_expert_slicing=self.is_expert_slicing, - fp16_mode=self.fp16_mode, - nonpadding=nonpadding, - logits_gumbel=self.logits_gumbel if self.training else 0, - token_drop_type=self.token_drop_type, - straight_through=self.straight_through, - straight_through_temperature=self.straight_through_temperature, - balance_ratio=self.balance_ratio, - gate_log_req=self.gate_log_req, - lid=lid, - tutel_cumsum_sub_one=self.tutel_cumsum_sub_one, - ) - return gates1_s, dispatch_mask, retval - else: - self.loss, self.gate_log, gates12_s, dispatch_mask, capacity_fp = top2gating( - logits, - self.capacity_factor if self.training else self.eval_capacity_factor, - fp16_mode=self.fp16_mode, - nonpadding=nonpadding, - logits_gumbel=self.logits_gumbel if self.training else 0, - token_drop_type=self.token_drop_type, - second_place_loss_ratio=self.second_place_loss_ratio, - straight_through=self.straight_through, - straight_through_temperature=self.straight_through_temperature, - balance_ratio=self.balance_ratio, - gate_log_req=self.gate_log_req, - lid=lid, - tutel_cumsum_sub_one=self.tutel_cumsum_sub_one, - ) - return gates12_s, dispatch_mask, capacity_fp + """ + In topokgate, we cast several tensors to float32 to ensure better quality. + However, if a module' forward (which contain topkgate) is wrapped by autocast(), + tensors may be casted to other types automatically for some computation. 
+ For the quality reason, we disable autocast() for the topkgate + """ + with autocast(enabled=False): + if self.options.fp16_mode() is True: + input = input.to(torch.float32) + self.wg = self.wg.to(torch.float32) + if self.training and self.k == 1: + if self.switch_jitter > 0: + input = multiplicative_jitter(input, device=input.device, epsilon=self.switch_jitter) + elif self.switch_dropout > 0: + input = F.dropout(input, p=self.switch_dropout, training=self.training) + logits = self.wg(input) #dim: [bxs, num_experts] + if self.k == 1: + self.loss, self.gate_log, gates1_s, dispatch_mask, retval = top1gating( + logits, + self.capacity_factor if self.training else self.eval_capacity_factor, + is_expert_slicing=self.is_expert_slicing, + nonpadding=nonpadding, + logits_gumbel=self.logits_gumbel if self.training else 0, + token_drop_type=self.token_drop_type, + straight_through=self.straight_through, + straight_through_temperature=self.straight_through_temperature, + balance_ratio=self.balance_ratio, + gate_log_req=self.gate_log_req, + lid=lid, + options=self.options, + dgrid = self.dgrid + ) + return gates1_s, dispatch_mask, retval + else: + self.loss, self.gate_log, gates12_s, dispatch_mask, capacity_fp = top2gating( + logits, + self.capacity_factor if self.training else self.eval_capacity_factor, + nonpadding=nonpadding, + logits_gumbel=self.logits_gumbel if self.training else 0, + token_drop_type=self.token_drop_type, + second_place_loss_ratio=self.second_place_loss_ratio, + straight_through=self.straight_through, + straight_through_temperature=self.straight_through_temperature, + balance_ratio=self.balance_ratio, + gate_log_req=self.gate_log_req, + lid=lid, + options=self.options + ) + return gates12_s, dispatch_mask, capacity_fp def set_gate_metrics(self, balance_ratio=None, gate_log_req=None): if balance_ratio is not None: @@ -387,11 +386,36 @@ def fast_one_hot(indices: torch.Tensor, num_classes : int): ret = ret.scatter(-1, indices.unsqueeze(-1), 1) return ret -def top1gating(logits: torch.Tensor, capacity_factor: float, is_expert_slicing=False, fp16_mode: bool=False, nonpadding: torch.Tensor=None, +def update_mask_for_relocated_experts(dgrid, mask, options): + ''' + If experts are relocated then update the mask. Only works for top1 for now. 
+ ''' + if dgrid is None: + return mask + + # mask dim is: [bs, num_experts] + for b in mask: + expert_index = torch.max(b, dim=0).indices + old_index = expert_index.item() + new_index = dgrid.get_get_relocation_id(old_index) + if new_index is not None: + # Set the mask at new index and reset the mask for the old index + if options.enable_verbose(): + print(f'Mask updated for Expert {old_index} relocated to {new_index}') + b[new_index] = 1 + b[old_index] = 0 + return mask + +def top1gating(logits: torch.Tensor, capacity_factor: float, is_expert_slicing=False, nonpadding: torch.Tensor=None, logits_gumbel: float=0.0, token_drop_type: str='cut', straight_through: bool=False, straight_through_temperature: float=1.0, balance_ratio={'load_balance': 0.01}, gate_log_req: dict={}, lid: int=-1, - tutel_cumsum_sub_one: callable=None)-> Tuple[Tensor, Tensor, Tensor, Tensor, float]: - if fp16_mode is True: + options={}, dgrid = None)-> Tuple[Tensor, Tensor, Tensor, Tensor, float]: + moe_options = None + if isinstance(options, moe_config): + moe_options = options + else: + moe_options = moe_config(options) + if moe_options.fp16_mode() is True: logits = logits.to(torch.float32) if logits_gumbel > 0: @@ -419,16 +443,20 @@ def top1gating(logits: torch.Tensor, capacity_factor: float, is_expert_slicing=F #create mask for 1st's expert per token indices_s = torch.argmax(logits_w_noise if logits_gumbel > 0 else gates, dim = 1) #dim: [bs], the index of the expert with highest softmax value mask1 = fast_one_hot(indices_s, num_classes = num_experts) #dim: [bs, num_experts]. 1 for the expert with highest softmax value + + # If experts are relocated then update the mask + mask1 = update_mask_for_relocated_experts(dgrid, mask1, moe_options) + if lid >= 0 and (torch.sum(mask1.float(), dim=0).int() == 0).any(): print(f"WARNING: top1gating: expert got too few examples in layer {lid}: {torch.sum(mask1.float(), dim=0).int().tolist()}") # mask using nonpadding (https://github.com/tensorflow/mesh/blob/a54f5cf75ef44d8a97190b3e5aaec176c138b3c0/mesh_tensorflow/transformer/moe.py#L1224) gates_nonpadding = gates if nonpadding is not None: - mask1 = einsum("s,se->se", nonpadding, mask1) - gates_nonpadding = einsum("s,se->se", nonpadding, gates_nonpadding) + mask1 = om_einsum("s,se->se", nonpadding, mask1) + gates_nonpadding = om_einsum("s,se->se", nonpadding, gates_nonpadding) if straight_through: - gates_st = einsum("s,se->se", nonpadding, gates_st) if not gates_st_equals_gates else gates_nonpadding + gates_st = om_einsum("s,se->se", nonpadding, gates_st) if not gates_st_equals_gates else gates_nonpadding #TODO: Need to add a unit test if is_expert_slicing: indices_s = torch.where(nonpadding > 0, indices_s, num_experts) # Assign token_i to "expert num_experts" (fake expert) when nonpadding[i] == 0 @@ -440,6 +468,11 @@ def top1gating(logits: torch.Tensor, capacity_factor: float, is_expert_slicing=F raw_mask1 = mask1.clone().detach() + if moe_options.enable_dynamic_capacity() and not is_expert_slicing: + expert_usage = torch.max(torch.sum(mask1, dim=0)) + capacity = min((expert_usage // 32+1) * 32, capacity) # padding for | 32 + capacity_fp = float(capacity) + if not is_expert_slicing and token_drop_type in ['random', 'routing_weight']: if token_drop_type == 'random': # randomly select masked tokens to fit in capacity buffer @@ -462,7 +495,7 @@ def top1gating(logits: torch.Tensor, capacity_factor: float, is_expert_slicing=F #TODO: 1. Need more tests for this functionality. Leave it as-is now to unblock CLIP training # 2. 
Need to add a unit test - if ((not is_expert_slicing) and token_drop_type != 'cut') or nonpadding is not None: # Could be remove in the next version + if nonpadding is not None: # Could be remove in the next version discard_tmp = num_tokens - expert_count.sum() count_discard = torch.tensor([discard_tmp], device=logits.device) expert_count = torch.cat((expert_count, count_discard)) @@ -482,10 +515,7 @@ def top1gating(logits: torch.Tensor, capacity_factor: float, is_expert_slicing=F indices_in_expert = torch.min(indices_repeat, dim=1).values else: #Compute locations in capacity buffer - if mask1.device.type == 'cpu' or tutel_cumsum_sub_one is None: - locations1 = torch.cumsum(mask1, dim=0) - 1 - else: - locations1 = tutel_cumsum_sub_one(mask1, dim=0) + locations1 = om_cumsum(mask1, dim=0, options=moe_options) if not is_expert_slicing and token_drop_type == 'cut': #Remove locations outside capacity from mask @@ -509,7 +539,7 @@ def top1gating(logits: torch.Tensor, capacity_factor: float, is_expert_slicing=F #Normalize gate probabilities/ep mask1_float = mask1.float() gates_nonpadding = gates_nonpadding.float() - gates1_s = einsum("se,se->s", gates_nonpadding, mask1_float).reshape(1, -1) #[topK, S] + gates1_s = om_einsum("se,se->s", gates_nonpadding, mask1_float).reshape(1, -1) #[topK, S] loss, gate_log = compute_gate_loss(balance_ratio, gate_log_req, logits=logits, gates=gates_nonpadding, gates_max=gates1_s, @@ -517,7 +547,7 @@ def top1gating(logits: torch.Tensor, capacity_factor: float, is_expert_slicing=F router_prob_fraction=me, token_dispatch_fraction=ce, nonpadding=nonpadding, num_experts=num_experts, num_nonpadding=num_nonpadding) - if fp16_mode is True: + if moe_options.fp16_mode() is True: gates1_s = gates1_s.to(torch.float16) if not is_expert_slicing: # Loss: loss_aux diff --git a/ort_moe/utils.py b/ort_moe/ort_moe/utils.py similarity index 98% rename from ort_moe/utils.py rename to ort_moe/ort_moe/utils.py index 6ea810bd..dfdaef8f 100644 --- a/ort_moe/utils.py +++ b/ort_moe/ort_moe/utils.py @@ -5,6 +5,7 @@ import copy import random +import numpy as np import torch import torch.distributed as dist from torch._C import default_generator @@ -114,6 +115,8 @@ def get_state_dict_partitions_for_saving(model, dgrid, total_num_experts): # global rank 0 saves skeleton (non-expert) parameters if dist.get_rank() == 0: partitions["skeleton"] = get_non_expert_parameters_state_dict(model) + if bool(dgrid.get_expert_relocation_map()): + partitions["skeleton"]["expert_relocation_map"] = dgrid.get_expert_relocation_map() # every node with expert parallel replica rank 0 saves its experts if dgrid.get_expert_parallel_replica_rank() == 0: @@ -464,12 +467,13 @@ class TemporaryRngState: the PyTorch RNG state for CPU and GPU (if cuda is initialized). It does not currently reset the numpy RNG state. 
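Example (illustrative usage sketch; assumes torch is imported and, if add_rank_to_seed is used, that torch.distributed is initialized):

    with TemporaryRngState(seed=1234):
        noise = torch.randn(4)   # drawn from the temporarily seeded generators
    # the previous RNG states are restored when the block exits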
''' - def __init__(self, add_rank_to_seed=False): - self.seed = random.randrange(2**32) + def __init__(self, add_rank_to_seed=False, seed=None): + self.seed = seed if seed is not None else random.Random().randrange(2**32) if add_rank_to_seed: assert dist.is_initialized() self.seed += dist.get_rank() self.python_rng_state = random.getstate() + self.numpy_rng_state = np.random.get_state() self.torch_rng_state = torch.get_rng_state() if torch.cuda.is_initialized(): self.torch_rng_state_cuda = torch.cuda.get_rng_state() @@ -483,9 +487,11 @@ def __enter__(self): default_generator.manual_seed(self.seed + 1) if torch.cuda.is_initialized(): torch.cuda.manual_seed(self.seed + 2) # only set seed of default cuda device + np.random.seed(self.seed + 3) def __exit__(self, exc_type, exc_value, exc_traceback): random.setstate(self.python_rng_state) + np.random.set_state(self.numpy_rng_state) torch.set_rng_state(self.torch_rng_state) if torch.cuda.is_initialized(): torch.cuda.set_rng_state(self.torch_rng_state_cuda) diff --git a/ort_moe/setup.py b/ort_moe/setup.py new file mode 100644 index 00000000..4fa35201 --- /dev/null +++ b/ort_moe/setup.py @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import setuptools +import os +import sys +import datetime + +nightly_build = False +package_name = 'ort_moe' + +if '--nightly_build' in sys.argv: + package_name = 'ort_moe-nightly' + nightly_build = True + sys.argv.remove('--nightly_build') + +version_number = '' +with open('VERSION_NUMBER') as f: + version_number = f.readline().strip() + +if nightly_build: + #https://docs.microsoft.com/en-us/azure/devops/pipelines/build/variables + build_suffix = os.environ.get('BUILD_BUILDNUMBER') + if build_suffix is None: + #The following line is only for local testing + build_suffix = str(datetime.datetime.now().strftime("%Y%m%d%H%M")) + else: + build_suffix = build_suffix.replace('.','') + version_number = version_number + ".dev" + build_suffix +else: + build_suffix = str(datetime.datetime.now().strftime("%Y%m%d%H%M")) + version_number = version_number + ".dev" + build_suffix + + +if __name__ == '__main__': + setuptools.setup( + name=package_name, + version=version_number, + packages=['ort_moe'], + ) + diff --git a/ort_moe/tests/README.md b/ort_moe/tests/README.md new file mode 100644 index 00000000..f547fbba --- /dev/null +++ b/ort_moe/tests/README.md @@ -0,0 +1,23 @@ + +Ideally, prefer to run the test on the cluster +Run both of the following: +cd ../experiments/cluster/ + +python baseline_unittests_experiment.py #This runs UT of different model setting +python moe_functionaltests_experiment.py #This runs the UTs for MOE module and gating functions + + +Less ideally To run all the tests locally: +./run_all.sh + +This requires the following packages to be installed: +RUN pip install pytest +RUN pip install mpi4py +RUN pip install pytest-mpi + + +IF you want to get the coverage information of the UTs, you need the coverage package to be installed: +RUN pip install coverage + +After the install, if you got WARNING: "The scripts coverage, coverage-3.7 and coverage3 are installed in /home//.local/bin which is not on PATH". 
Consider adding this directory to PATH by: +RUN export PATH=$PATH:/home//.local/bin diff --git a/ort_moe/tests/__init__.py b/ort_moe/tests/__init__.py index 039f118d..846cb824 100644 --- a/ort_moe/tests/__init__.py +++ b/ort_moe/tests/__init__.py @@ -6,3 +6,5 @@ import pathlib import sys sys.path.insert(0, os.path.join(pathlib.Path(__file__).parent.absolute())) +sys.path.insert(0, os.path.join(pathlib.Path(__file__).parent.parent.absolute(), 'moe_module')) +sys.path.insert(0, os.path.join(pathlib.Path(__file__).parent.parent.absolute())) diff --git a/ort_moe/tests/nccl_output_processor.py b/ort_moe/tests/nccl_output_processor.py new file mode 100644 index 00000000..c9582602 --- /dev/null +++ b/ort_moe/tests/nccl_output_processor.py @@ -0,0 +1,31 @@ +import os,re,sys + +if len(sys.argv) != 3: + print("usage: python nccl_output_processor.py nccl_output sccl_output") +file1 = open(sys.argv[1],"r") +file2 = open(sys.argv[2],"r") +nccl_res = [] +for l in file1.readlines(): + g = re.match("\s+(\d+)\s+\d+\s+float\s+([0-9]*\.?[0-9]+)\s+[0-9]*\.?[0-9]+\s+[0-9]*\.?[0-9]+\s+(\d+e\+\d+)", l) + if g is not None: + nccl_res.append((float(g.group(1)), float(g.group(2)), float(g.group(3)))) +# the sccl output +sccl_res = [] +for l in file2.readlines(): + g = re.match("\s+(\d+)\s+\d+\s+float\s+([0-9]*\.?[0-9]+)\s+[0-9]*\.?[0-9]+\s+[0-9]*\.?[0-9]+\s+(\d+e\+\d+)", l) + if g is not None: + sccl_res.append((float(g.group(1)), float(g.group(2)), float(g.group(3)))) +counter = 0 +for a,b in zip(nccl_res, sccl_res): + if a[0] != b[0]: + print("Sizes didn't match in sccl/nccl comparison") + exit(-1) + # Make sure SCCL is not more than 10% slower than NCCL. Always skip the first size as it is unstable. + if b[1] > a[1]*1.05 and counter > 0: + print(f"Performance of sccl slowed down for size {a[0]}: nccl {a[1]} vs sccl {b[1]}") + exit(-1) + if a[2] > 0 or b[2] > 0: + print(f"Correctness did not pass for size {a[0]}: nccl {a[2]}, sccl {b[2]}") + exit(-1) + counter += 1 +print("All checks passed!") diff --git a/ort_moe/tests/pytest.ini b/ort_moe/tests/pytest.ini new file mode 100644 index 00000000..f2bb5354 --- /dev/null +++ b/ort_moe/tests/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + with_ort: marks tests that can be run with ORTModule. + with_ort1: Fails when runs together with "with_ort" due to segmentation fault. Succeeds if runs as a separate CI pipeline stage. \ No newline at end of file diff --git a/ort_moe/tests/run_all.sh b/ort_moe/tests/run_all.sh new file mode 100755 index 00000000..8fb86d33 --- /dev/null +++ b/ort_moe/tests/run_all.sh @@ -0,0 +1,18 @@ +#! /bin/bash +set -e +# Any subsequent(*) commands which fail will cause the shell script to exit immediately + +echo "Running all tests" +echo "And print test coverage of the code files under folder of ../moe_module/ . If you don't need the coverage imformation, please replace \"coverage run --parallel-mode --source=../moe_module/\" with \"python\" to disable the coverage information collecting." + +mpirun -n 4 --allow-run-as-root coverage run --parallel-mode --source=../moe_module/ -m pytest --with-mpi test_top2gating.py +mpirun -n 4 --allow-run-as-root coverage run --parallel-mode --source=../moe_module/ -m pytest --with-mpi test_moe.py +mpirun -n 4 --allow-run-as-root coverage run --parallel-mode --source=../moe_module/ -m pytest --with-mpi test_grid.py + +echo "Combine the coverage tool output and print the report." 
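+# 'coverage combine' merges the per-process data files written by --parallel-mode above;
+# test_coverage.py then checks the generated coverage_log against its expected per-file thresholds.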
+coverage combine +coverage report -m > coverage_log +python test_coverage.py + +python test_uni_image.py + diff --git a/ort_moe/tests/test_compression.py b/ort_moe/tests/test_compression.py new file mode 100644 index 00000000..79f6775b --- /dev/null +++ b/ort_moe/tests/test_compression.py @@ -0,0 +1,49 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from mpi4py import MPI +import torch.distributed as dist +import torch +import pytest +import os + +from ort_moe.collectives import compressed_all_to_all + +assert torch.cuda.is_available() + +BACKEND = dist.Backend.NCCL +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29501" # torch 1.5 compatibility +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +if not dist.is_initialized(): + dist.init_process_group(backend=BACKEND, rank=rank, world_size=size) + +def compare_tensor_lists(l1, l2): + for a, b in zip(l1, l2): + return torch.equal(a, b) + +def test_fake(device=rank): + return + +def disable_test_alltoall_compression(device = rank): + # generate input + a2a_input = torch.empty(4*2*8, dtype=torch.float16, device=torch.device(rank)) + a2a_input = torch.reshape(a2a_input,[4,2,8]) + a2a_input[0] = a2a_input[0].fill_(rank*2) + a2a_input[1] = a2a_input[1].fill_(rank*3) + a2a_input[2] = a2a_input[2].fill_(rank*4) + a2a_input[3] = a2a_input[3].fill_(rank*5) + + # original all2all + orig_a2a_output = torch.empty(a2a_input.size(), dtype=a2a_input.dtype, device=a2a_input.device) + torch.distributed.all_to_all_single(orig_a2a_output, a2a_input) + + # manual all2all + mc_a2a_output = torch.empty(a2a_input.size(), dtype=a2a_input.dtype, device=a2a_input.device) + compressed_all_to_all(mc_a2a_output, a2a_input) + + assert compare_tensor_lists([orig_a2a_output], [mc_a2a_output]) diff --git a/ort_moe/tests/test_coverage.py b/ort_moe/tests/test_coverage.py new file mode 100644 index 00000000..10018aaa --- /dev/null +++ b/ort_moe/tests/test_coverage.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os + +expected_coverage = dict() +expected_coverage["collectives.py"] = 46 #The reported missing are only backward autograd functions, they are actually called but not recognized by the coverage. +# Compression tests are disabled currently. +expected_coverage["experts.py"] = 96 #Cannot have a reliable way to test the dropout +expected_coverage["gate_logs.py"] = 100 +expected_coverage["grids.py"] = 93 #The reported missing are for mpi4py and unimportant unexpected errors. +expected_coverage["loss_functions.py"] = 100 +#TODO: Recover the coverage. Tempoparily drop the test coverage from 96-93, since the "nonpadding" is not covered. It is not used in CLIP-H +expected_coverage["moe.py"] = 93 #Tested the uncovered lines are actually covered, the mpi4py part (line 87-89) cannot be tested due to CI does not work with dist in this case... +expected_coverage["topKgate.py"] = 88 #Work around, after modification for the expert_slicing and adding loss tests, it should be set to 94. 
+expected_coverage["utils.py"] = 58 + +succeeded = True +f = open("coverage_log", 'r') +for l in f: + l_list = l.split() + file_name = l_list[0].split('/')[-1] + if file_name in expected_coverage: + coverage = l_list[3].split("%")[0] + if int(coverage) < expected_coverage[file_name]: + print(f"{file_name} expected coverage is {expected_coverage[file_name]}, " + f"actual coverage is {coverage}") + succeeded = False +f.close() +assert succeeded diff --git a/ort_moe/tests/test_grid.py b/ort_moe/tests/test_grid.py index 4fbf5992..7f415361 100644 --- a/ort_moe/tests/test_grid.py +++ b/ort_moe/tests/test_grid.py @@ -133,6 +133,41 @@ def test_expert_mapping_3(device = rank): elif rank == 3: assert grank == 46 and replica_id == 1 +def test_expert_mapping_4(device = rank): + if dist.get_world_size() != 4: + return + + dgrid = DistributionGrid(expert_parallel_group_size = 4) + + grank = dgrid.map_expert_id_local_to_global(64, 14) + if rank == 0: + assert grank == 14 + elif rank == 1: + assert grank == 30 + elif rank == 2: + assert grank == 46 + elif rank == 3: + assert grank == 62 + + dgrid.exchange_expert_location(30, 30) + dgrid.exchange_expert_location(30, 46) + + nrank, nid = dgrid.map_expert_id_global_to_local(64, 30) + assert nrank == 2 and nid == 14 + + grank2 = dgrid.map_expert_id_local_to_global(64, 14) + if rank == 0: + assert grank2 == 14 + elif rank == 1: + assert grank2 == 46 + elif rank == 2: + assert grank2 == 30 + elif rank == 3: + assert grank2 == 62 + + dgrid.set_expert_relocation_map(dgrid.get_expert_relocation_map()) + dgrid.remove_expert_relocation(30) + def test_rank_schedule(device = rank): if dist.get_world_size() != 4: return @@ -152,7 +187,7 @@ def test_rank_schedule(device = rank): else: assert dgrid.get_expert_parallel_replica_rank() == 1 - options = {"rank_schedule" : "row_major" } + options = {"rank_schedule": "row_major"} dgridr = DistributionGrid(expert_parallel_group_size = 2, expert_parallel_replica_group_size = 2, options = options) # EP 0 -> [0,1] @@ -168,10 +203,10 @@ def test_rank_schedule(device = rank): else: assert dgridr.get_expert_parallel_replica_rank() == 1 - options ={ "rank_schedule" : "column_major" } + options = {"rank_schedule": "column_major"} dgridc = DistributionGrid(expert_parallel_group_size = 2, expert_parallel_replica_group_size = 2, options = options) - # EP 0 -> [0,2 + # EP 0 -> [0,2] # EP 1 -> [1,3] # ER 0 -> [0,1] # ER 1 -> [2,3] @@ -183,3 +218,19 @@ def test_rank_schedule(device = rank): assert dgridc.get_expert_parallel_replica_rank() == 0 else: assert dgridc.get_expert_parallel_replica_rank() == 1 + + options = {"rank_schedule": "row_major", "is_replica_in_same_node": True} + dgridr = DistributionGrid(expert_parallel_group_size = 2, expert_parallel_replica_group_size = 2, + options = options) + # EP 0 -> [0,2] + # EP 1 -> [1,3] + # ER 0 -> [0,1] + # ER 1 -> [2,3] + if rank < 2: + assert dgridr.get_expert_replica_src_rank() == 0 + else: + assert dgridr.get_expert_replica_src_rank() == 2 + if (rank % 2) == 0: + assert dgridr.get_expert_parallel_replica_rank() == 0 + else: + assert dgridr.get_expert_parallel_replica_rank() == 1 diff --git a/ort_moe/tests/test_moe.py b/ort_moe/tests/test_moe.py index 951d503d..c2eb93f4 100644 --- a/ort_moe/tests/test_moe.py +++ b/ort_moe/tests/test_moe.py @@ -15,7 +15,7 @@ from torch import nn -from ort_moe.custom_ops import einsum +from ort_moe.custom_ops import om_einsum from ort_moe.experts import FFNExpert, MergedFFNExpert from ort_moe.moe import MixtureOfExpertsFunc, MixtureOfExperts, AllToAll, 
MixtureOfExpertsES from ort_moe.utils import broadcast_parameters, moe_module_all_reduce, apex_amp_scale_check_overflow_override, is_moe_parameter @@ -32,10 +32,8 @@ from ort_moe.grids import DistributionGrid from ort_moe.loss_functions import loss_functions from ort_moe.gate_logs import gate_logs -from . import topKgate_old, moe_old from apex import amp as apex_amp -use_tutel = True # if USE_ORT env is set then run tests with ORTModule use_ort = os.getenv("USE_ORT", None) if use_ort: @@ -48,7 +46,6 @@ enable_custom_autograd_support() debug_options = DebugOptions(save_onnx=True, onnx_prefix='moe', log_level=LogLevel.INFO) ortmodule.ONNX_OPSET_VERSION=13 - use_tutel = False assert torch.cuda.is_available() @@ -99,7 +96,7 @@ def test_MixtureOfExperts(): num_local_experts = 4 num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts - gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top1 + gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid) # Top1 experts = torch.nn.ModuleList() for i in range(num_local_experts): experts.append(FFNExpert(model_dim, ff_dim, dgrid=dgrid)) @@ -249,7 +246,7 @@ def test_MixtureOfExperts_single_forward(device=rank): num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts input = torch.rand(4, 16, model_dim).to(device) dgrid = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) + gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid) experts = torch.nn.ModuleList() for i in range(num_local_experts): expert = torch.nn.Linear(model_dim, model_dim, bias=False) @@ -278,7 +275,7 @@ def test_MixtureOfExperts_multi_forward(device=rank): num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts input = torch.rand(4 * num_local_experts, 16, model_dim).to(device) dgrid = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) + gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid) experts = torch.nn.ModuleList() for i in range(num_local_experts): expert = torch.nn.Linear(model_dim, model_dim, bias=False) @@ -316,7 +313,7 @@ def test_expert_moe_encoder(): dim_feedforward = 256 num_local_experts = 4 nexperts = dist.get_world_size(dist.group.WORLD) * num_local_experts - gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # top-1 + gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid) # top-1 encoder = TransformerMoEEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, nexperts=nexperts, gate=gate, distribution_grid = dgrid) if use_ort: @@ -342,7 +339,7 @@ def test_expert_moe_decoder(): num_local_experts = 4 nexperts = dist.get_world_size(dist.group.WORLD) * num_local_experts dgrid = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # top-1 + gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid) # top-1 decoder = TransformerMoEDecoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, nexperts=nexperts, gate=gate, distribution_grid = dgrid) if use_ort: @@ -427,6 +424,39 @@ def test_forward_routing_multi(device=rank): assert torch.allclose(input[:, i] * (expert + 1), output[:, i]) +@pytest.mark.mpi +def test_forward_routing_shuffle(device=rank): + dgrid = DistributionGrid(expert_parallel_group_size 
= dist.get_world_size()) + model_dim = 6 + num_local_experts = 4 + num_rank = dist.get_world_size(dist.group.WORLD) + num_experts = num_rank * num_local_experts + input = torch.ones(4 * num_local_experts, 32, model_dim).to(device) + gate = RoundRobinGate(model_dim, num_experts, dgrid) + experts = torch.nn.ModuleList() + for i in range(num_local_experts): + expert = torch.nn.Linear(model_dim, model_dim, bias=False) + # Use scaling matrix (each rank has a different scale) + scale = dist.get_rank() * num_local_experts + i + 1 + expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale) + experts.append(expert) + options = {"enable_base_layer_shuffling" : True} + moe = MixtureOfExpertsFunc(gate, experts, distribution_grid=dgrid, options=options).to(device) + if use_ort: + moe = ORTModule(moe) + rand_fixed_idx = num_rank // 2 - 1 + rank_lists = [list(range(0, rand_fixed_idx)), list(range(rand_fixed_idx, num_rank))] + for r in rank_lists: + tmp = dist.new_group(r) + if rank in r: + pg = tmp + output = moe(input, shuffle_group=pg) + if use_ort: + loss = output[0].sum() + loss.backward() + assert output.shape == input.shape + + @pytest.mark.mpi @pytest.mark.with_ort def test_backward(device=rank): @@ -435,7 +465,7 @@ def test_backward(device=rank): model_dim = 8 num_experts = dist.get_world_size(dist.group.WORLD) input = torch.randn(4, 16, model_dim).to(device) - gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) + gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid) experts = torch.nn.ModuleList() expert = torch.nn.Linear(model_dim, model_dim, bias=False) experts.append(expert) @@ -521,7 +551,7 @@ def test_moe_reset_state2(device=rank): dgrid = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) model_dim = 8 num_experts = dist.get_world_size(dist.group.WORLD) - gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) + gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid) experts = torch.nn.ModuleList() expert = torch.nn.Linear(model_dim, model_dim, bias=False) experts.append(expert) @@ -586,7 +616,7 @@ def test_expert_moez_encoder(): dim_feedforward = 256 num_local_experts = 4 nexperts = dist.get_world_size(dist.group.WORLD) * num_local_experts - gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # top-1 + gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid) # top-1 shared_encoder = TransformerMoEEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, nexperts=nexperts, gate=gate, distribution_grid=dgrid) for p in shared_encoder.named_parameters(): @@ -630,7 +660,7 @@ def test_expert_moez_decoder(): dim_feedforward = 256 num_local_experts = 4 nexperts = dist.get_world_size(dist.group.WORLD) * num_local_experts - gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # top-1 + gate = TopKGate(d_model, nexperts, k=1, dgrid=dgrid) # top-1 shared_decoder = TransformerMoEDecoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, nexperts=nexperts, gate=gate, distribution_grid=dgrid) if use_ort: @@ -669,7 +699,7 @@ def test_moe_allreduce(device=rank): num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts d_model = 2 d_ff = 16 - gating_fn = TopKGate(d_model, num_experts, k=1, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top1 + gating_fn = TopKGate(d_model, num_experts, k=1, dgrid=dgrid) # Top1 merged_experts = MergedFFNExpert(d_model, d_ff, num_local_experts, dgrid=dgrid) model = 
MixtureOfExpertsFunc(gating_fn, merged_experts, dgrid).to(device) if use_ort: @@ -708,7 +738,7 @@ def test_dp_group_allreduce(device=rank): d_model = 2 d_ff = 16 - gating_fn = TopKGate(d_model, num_experts, k=1, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top1 + gating_fn = TopKGate(d_model, num_experts, k=1, dgrid=dgrid) # Top1 expert = FFNExpert(d_model, d_ff, dgrid=dgrid) experts = torch.nn.ModuleList() experts.append(expert) @@ -767,7 +797,7 @@ def test_dp_group_with_all2all(device = rank): d_model = 8 input = torch.randn(4, 16, d_model).to(device) - gating_fn = TopKGate(d_model, num_experts, k=2, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top1 + gating_fn = TopKGate(d_model, num_experts, k=2, dgrid=dgrid) # Top1 experts = torch.nn.ModuleList() for _ in range(num_local_experts): expert = torch.nn.Linear(d_model, d_model, bias = False) @@ -837,7 +867,7 @@ def test_ep_group_all2all_forward(device = rank): d_model = 8 input = torch.randn(4, 16, d_model).to(device) - gating_fn = TopKGate(d_model, num_experts, k=2, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top1 + gating_fn = TopKGate(d_model, num_experts, k=2, dgrid=dgrid) # Top1 experts = torch.nn.ModuleList() for _ in range(num_local_experts): expert = torch.nn.Linear(d_model, d_model, bias = False) @@ -909,7 +939,7 @@ def test_ep_group_all2all_backward(device=rank): model_dim = 8 input = (torch.ones(4, 16, model_dim)*5).to(device) - gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) + gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid) # set gate weights always same, so we obtain deterministic mapping gate.wg.weight = torch.nn.Parameter(torch.tensor([[ 0.9311, 0.1706, 0.3681, 1.7191, 2.0357, 0.2269, -0.0920, -0.2983], [ 1.2228, 0.3268, -0.0645, -0.7305, -1.1829, -0.0968, 1.2634, 1.5229]])) @@ -961,7 +991,7 @@ def test_max_len_all_reduce(device = rank): d_model = 8 input = torch.randn(4, 4, d_model).to(device) - gating_fn = TopKGate(d_model, ep_ways*num_local_experts, k=1, capacity_factor=1.25, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top1 + gating_fn = TopKGate(d_model, ep_ways*num_local_experts, k=1, capacity_factor=1.25, dgrid=dgrid) # Top1 experts = torch.nn.ModuleList() for _ in range(num_local_experts): expert = torch.nn.Linear(d_model, d_model, bias = False) @@ -1000,154 +1030,6 @@ def test_max_len_all_reduce(device = rank): dp_ways *= 2 - -#run 20 tests, each one with randomly generated model config -#in each test, run 10 iterations of fwd+bwd pass -def test_moe_memory_fix(device = rank): - import random - import topKgate_old - import moe_old - for i in range(20): - if dist.get_rank() == 0: - print("test ", i) - d_model = 768 - batch = random.randint(10, 15) - seq_length = random.randint(100, 200) - local_n_experts = 4 - glboal_n_experts = local_n_experts * dist.get_world_size() - #manually set the seeds so the parameters for old and new models are the same - #NOTE: To make above statment true, the parameter initalization order needs to be - #identical for old and new models. 
- - dgrid = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - - use_merged_ffn = i % 2 == 0 - torch.manual_seed(dist.get_rank()) - if use_merged_ffn: - experts_new = MergedFFNExpert(d_model, 3072, local_n_experts, dgrid=dgrid).to(dist.get_rank()) - else: - experts_new = torch.nn.ModuleList() - for i in range(local_n_experts): - e = FFNExpert(d_model, 3072, dgrid=dgrid) - experts_new.append(e).to(dist.get_rank()) - gate_new = TopKGate(d_model, glboal_n_experts, k=1, dgrid=dgrid).to(dist.get_rank()) - m_new = MixtureOfExpertsFunc(gate_new, experts_new, distribution_grid=dgrid) - - torch.manual_seed(dist.get_rank()) - if use_merged_ffn: - experts_old = MergedFFNExpert(d_model, 3072, local_n_experts, dgrid=dgrid).to(dist.get_rank()) - else: - experts_old = torch.nn.ModuleList() - for j in range(local_n_experts): - e = FFNExpert(d_model, 3072, dgrid=dgrid).to(dist.get_rank()) - experts_old.append(e) - gate_old = topKgate_old.TopKGate(d_model, glboal_n_experts, k=1).to(dist.get_rank()) - m_old = moe_old.MixtureOfExperts(gate_old, experts_old) - - for step in range(10): - input = torch.rand(batch, seq_length, d_model).to(dist.get_rank()) - output = m_old(input) - output_new = m_new(input) - - loss = torch.nn.MSELoss() - output = loss(output, input) - output_new = loss(output_new, input) - output.backward() - output_new.backward() - moe_module_all_reduce(m_old, dgrid) - moe_module_all_reduce(m_new, dgrid) - #comparing the results - assert torch.allclose(m_old.l_aux[0], m_new.gate.loss) - assert torch.allclose(output, output_new, atol=1e-5) - assert torch.allclose(m_old.gate.wg.weight.grad, m_new.gate.wg.weight.grad, atol=1e-5) - if use_merged_ffn: - assert torch.allclose(m_old.experts.weight1, m_new.moe_experts.weight1) - assert torch.allclose(m_old.experts.weight2, m_new.moe_experts.weight2) - else: - for i in range(0, local_n_experts): - assert torch.allclose(m_old.experts[i].linear1.weight, m_new.moe_experts[i].linear1.weight) - assert torch.allclose(m_old.experts[i].linear2.weight, m_new.moe_experts[i].linear2.weight) - -def test_moe_loss(device=rank): - # make sure the refactored loss computation code gets the same loss values and gate log values with the old version. 
- # use two-layer net to also validate utility function get_moe_loss() - d_model = 64 - d_ff = 256 - n_layers = 2 - batch = random.randint(10, 15) - seq_length = random.randint(100, 200) - input = torch.rand(batch, seq_length, d_model).to(device) - local_n_experts = 4 - glboal_n_experts = local_n_experts * dist.get_world_size() - dgrid = DistributionGrid(expert_parallel_group_size=dist.get_world_size()) - - balance_ratio=[0.1, 0.1, 0.1, 0.1] - torch.manual_seed(device) - nonpadding = torch.ones(batch, seq_length).to(int).to(device) # nonpadding affects loss in new version, so set all 1 - torch.manual_seed(device) - class OldNet(nn.Module): - def __init__(self): - super().__init__() - self.layers = nn.ModuleList() - for _ in range(n_layers): - experts = nn.ModuleList() - for _ in range(local_n_experts): - e = FFNExpert(d_model, d_ff, dgrid) - experts.append(e) - gate = topKgate_old.TopKGate(d_model, glboal_n_experts, k=1, switch_jitter=0.1, switch_dropout=0.1, random_token_drop=True) - self.layers.append(moe_old.MixtureOfExperts(gate, experts, balance_ratio=balance_ratio)) - def forward(self, x): - for layer in self.layers: - x = layer(x, nonpadding=nonpadding) - return x - m_old = OldNet().to(device) - m_old(input) - loss_old = None - n_layers_old = 0 - for p in m_old.named_modules(): - if isinstance(p[1], moe_old.MixtureOfExperts): - if loss_old is None: - loss_old = p[1].l_aux - else: - loss_old += p[1].l_aux - n_layers_old += 1 - - balance_ratio={'load_balance': 0.1, 'sparsity_l1': 0.1, 'mean_importance': 0.1, 'z_loss': 0.1} - gate_log_req={'gate_entropy': True, 'gate_probability': True, 'gate_routed': True, 'expert_fraction': True, 'expert_routed_fraction': True} - torch.manual_seed(device) - class NewNet(nn.Module): - def __init__(self): - super().__init__() - self.layers = nn.ModuleList() - for _ in range(n_layers): - experts = nn.ModuleList() - for _ in range(local_n_experts): - e = FFNExpert(d_model, d_ff, dgrid=dgrid) - experts.append(e) - gate = TopKGate(d_model, glboal_n_experts, k=1, balance_ratio=balance_ratio, gate_log_req=gate_log_req, dgrid=dgrid, switch_jitter=0.1, switch_dropout=0.1, token_drop_type='random') - gate.set_gate_metrics(balance_ratio, gate_log_req) - self.layers.append(MixtureOfExpertsFunc(gate, experts, distribution_grid=dgrid)) - def forward(self, x): - for layer in self.layers: - x = layer(x, nonpadding=nonpadding) - return x - m_new = NewNet().to(device) - m_new(input) - loss_new, gate_log, n_layers_new = get_moe_loss(m_new) - - assert n_layers_new == n_layers_old - assert torch.isclose(loss_new, torch.sum(loss_old[:4])) - for i, l in zip(range(len(loss_functions)), loss_functions.keys()): - if balance_ratio.get(l, 0) > 0: - assert torch.isclose(gate_log[l], loss_old[i]) - for i, l in zip(range(3), list(gate_logs.keys())[:3]): - if gate_log_req.get(l, False) and l != 'gate_probability': # gate probability changed from pre-capacity to after-capacity - assert torch.isclose(gate_log[l], loss_old[4+i]) - if gate_log_req.get('expert_fraction', False): - assert torch.allclose(gate_log.get('expert_fraction', torch.zeros(glboal_n_experts)), loss_old[7 : 7+glboal_n_experts]) - if gate_log_req.get('expert_routed_fraction', False): - assert torch.allclose(gate_log.get('expert_routed_fraction', torch.zeros(glboal_n_experts)), loss_old[7+glboal_n_experts : 7+glboal_n_experts*2]) - @pytest.mark.mpi def test_mp_gating(device = rank): torch.manual_seed(7) @@ -1159,7 +1041,7 @@ def test_mp_gating(device = rank): dgrid = DistributionGrid(expert_slicing_group_size = 2, 
expert_parallel_replica_group_size= dist.get_world_size()//2) - gate = TopKGate(10, 4, dgrid, k=1, use_tutel_cumsum_sub_one=use_tutel).to(device) + gate = TopKGate(10, 4, dgrid, k=1).to(device) if use_ort: gate = ORTModule(gate) gate.wg = wg @@ -1171,7 +1053,7 @@ def test_mp_gating(device = rank): assert torch.equal(outputs[1], dispatch_mask_ref) assert(torch.equal(outputs[2], expert_cumsum_ref)) -def test_mp_moe_forward(device = rank): +def do_test_mp_moe_forward(dynamic_capacity, device = rank): dgrid = DistributionGrid(expert_slicing_group_size = 2, expert_parallel_replica_group_size=dist.get_world_size()//2) torch.manual_seed(7) @@ -1180,7 +1062,9 @@ def test_mp_moe_forward(device = rank): d_ff = d_model num_expert = 4 d_token = 6 - gate = TopKGate(d_model, num_expert, dgrid, k=1).to(device) + options = {} + options["enable_dynamic_capacity"] = dynamic_capacity + gate = TopKGate(d_model, num_expert, dgrid, k=1, options=options).to(device) experts = torch.nn.ModuleList() @@ -1190,7 +1074,7 @@ def test_mp_moe_forward(device = rank): expert.linear2.weight = torch.nn.Parameter(torch.eye(d_model) * (device%2+1)) experts.append(expert) - moe = MixtureOfExpertsFunc(gate, experts, distribution_grid=dgrid).to(device) + moe = MixtureOfExpertsFunc(gate, experts, distribution_grid=dgrid, options=options).to(device) if use_ort: moe = ORTModule(moe) @@ -1201,6 +1085,8 @@ def test_mp_moe_forward(device = rank): loss.backward() assert dgrid.get_expert_slicing_world_size() == 2 + if dynamic_capacity: + return #compute the reference gate_s by rerun the gate function with the same input reshaped_input_ref = torch.cat((input, input), 0).reshape(-1, d_model) @@ -1209,6 +1095,10 @@ def test_mp_moe_forward(device = rank): output_ref = torch.einsum("ks, ksm->sm", combine_weights_ref[:, device%2*d_token: (device%2+1)*d_token], input * 5.0) assert torch.allclose(output_ref, output, atol = 1e-03) +def test_mp_moe_forward_pass(): + do_test_mp_moe_forward(dynamic_capacity=True) + do_test_mp_moe_forward(dynamic_capacity=False) + #Compare the MP forward and backward pass by: (1) ref model: running the entire model on one GPU, # (2) testing_model: running the same model on multiple GPUs, the forward and backward output should be the same (allclose) def test_mp_moe_forward_backward(device = rank): @@ -1282,7 +1172,7 @@ def test_mp_moe_forward_backward(device = rank): #NOTE: We need to put all apex tests at the bottom, otherwise all tests after the apex tests are all casted to fp16. 
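The apex/fp16 tests that follow now construct the gate and the MoE layer with the new options dict instead of the removed fp16_mode/use_tutel_cumsum_sub_one keyword arguments. A minimal sketch of that pattern, assuming the imports and the MPI/NCCL process-group setup already established at the top of this test file (the sizes below are placeholders):

```python
# Placeholder sizes for illustration only.
model_dim, ff_dim, num_local_experts = 256, 64, 2

options = {"fp16_mode": True}  # consumed via moe_config inside the gate and MoE layer
dgrid = DistributionGrid(expert_parallel_group_size=dist.get_world_size())
num_experts = dist.get_world_size() * num_local_experts

gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid, options=options)
merged_experts = MergedFFNExpert(model_dim, ff_dim, num_local_experts, dgrid=dgrid)
model = MixtureOfExpertsFunc(gating_fn, merged_experts, distribution_grid=dgrid,
                             options=options).to(rank)  # rank: local device index used by these tests
```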
#Test the manually casting alltoall input to fp16 is the same as not casting, when apex O1 is applied -@pytest.mark.with_ort +#@pytest.mark.with_ort # opened an exporter issue for investigation def test_moe_fp16(device = rank): model_dim = 256 ff_dim = 64 @@ -1291,7 +1181,8 @@ def test_moe_fp16(device = rank): num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts dgrid_ref = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - gating_fn = TopKGate(model_dim, num_experts, k=1, fp16_mode=True, switch_jitter=0.0, dgrid=dgrid_ref, use_tutel_cumsum_sub_one=use_tutel) # Top1 + options = { "fp16_mode" : True} + gating_fn = TopKGate(model_dim, num_experts, k=1, switch_jitter=0.0, dgrid=dgrid_ref, options = options) # Top1 merged_experts = MergedFFNExpert(model_dim, ff_dim, num_local_experts, dgrid=dgrid_ref) model_ref = MixtureOfExpertsFunc(gating_fn, merged_experts, distribution_grid=dgrid_ref).to(device) @@ -1300,7 +1191,8 @@ def test_moe_fp16(device = rank): optimizer_ref = torch.optim.SGD(model_ref.parameters(), lr=1e-3) dgrid = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - model = MixtureOfExpertsFunc(gating_fn, merged_experts, fp16_mode = True, distribution_grid = dgrid).to(device) + options = { "fp16_mode" : True} + model = MixtureOfExpertsFunc(gating_fn, merged_experts, distribution_grid = dgrid, options = options).to(device) if use_ort: model = ORTModule(model) optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) @@ -1308,9 +1200,6 @@ def test_moe_fp16(device = rank): input = torch.rand(4, 16, model_dim).to(device) model_ref, optimizer_ref = apex_amp.initialize(model_ref, optimizer_ref, opt_level="O1") output_ref = model_ref(input) - if use_ort: - loss = output_ref[0].sum() - loss.backward() model, optimizer = apex_amp.initialize(model, optimizer, opt_level="O1") output = model(input) @@ -1320,7 +1209,7 @@ def test_moe_fp16(device = rank): assert torch.equal(output_ref, output) -@pytest.mark.with_ort +#@pytest.mark.with_ort # opened an exporter issue for investigation def test_moe_loss_scale(device = rank): model_dim = 256 ff_dim = 64 @@ -1328,8 +1217,9 @@ def test_moe_loss_scale(device = rank): dgrid = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) merged_experts = MergedFFNExpert(model_dim, ff_dim, num_local_experts, dgrid=dgrid) num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts - gating_fn = TopKGate(model_dim, num_experts, k=1, fp16_mode=True, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top1 - model = MixtureOfExpertsFunc(gating_fn, merged_experts, fp16_mode = True, distribution_grid = dgrid).to(device) + options = { "fp16_mode" : True} + gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid, options = options) # Top1 + model = MixtureOfExpertsFunc(gating_fn, merged_experts, distribution_grid = dgrid, options = options).to(device) if use_ort: model = ORTModule(model) optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) @@ -1376,7 +1266,7 @@ def test_parameter_synchronization(device = rank): d_model = 8 input = torch.randn(4, 16, d_model).to(device) - gating_fn = TopKGate(d_model, num_experts, k=2, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top2 + gating_fn = TopKGate(d_model, num_experts, k=2, dgrid=dgrid) # Top2 experts = torch.nn.ModuleList() for _ in range(num_local_experts): expert = torch.nn.Linear(d_model, d_model, bias = False) @@ -1431,7 +1321,7 @@ def test_pipeline_parallelism_single_node(device=rank): # ep model ep_dgrid = 
DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - ep_gating_fn = TopKGate(d_model, num_experts, k=1, dgrid=ep_dgrid, use_tutel_cumsum_sub_one=use_tutel) + ep_gating_fn = TopKGate(d_model, num_experts, k=1, dgrid=ep_dgrid) ep_expert = FFNExpert(d_model, d_ff, dgrid=ep_dgrid) ep_experts = torch.nn.ModuleList() ep_experts.append(ep_expert) @@ -1446,7 +1336,7 @@ def test_pipeline_parallelism_single_node(device=rank): # pp model pp_dgrid = DistributionGrid(num_of_nodes_in_pipeline = 1, num_of_pipeline_stage = dist.get_world_size()) - pp_gating_fn = TopKGate(d_model, num_experts, k=1, dgrid=pp_dgrid, use_tutel_cumsum_sub_one=use_tutel) + pp_gating_fn = TopKGate(d_model, num_experts, k=1, dgrid=pp_dgrid) pp_expert = FFNExpert(d_model, d_ff, dgrid=pp_dgrid) pp_experts = torch.nn.ModuleList() pp_experts.append(pp_expert) @@ -1481,7 +1371,7 @@ def test_assertion(): num_local_experts = 4 num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts - gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid, use_tutel_cumsum_sub_one=use_tutel) # Top1 + gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid) # Top1 experts = torch.nn.ModuleList() for i in range(num_local_experts): experts.append(FFNExpert(model_dim, ff_dim, dgrid=dgrid)).to("cuda") @@ -1519,12 +1409,16 @@ def test_fsdp_ep(): num_local_experts = 4 nexperts = dist.get_world_size(dist.group.WORLD) * num_local_experts + options = { "fsdp_zero_optimization" : {"stage": 1, "flatten_parameters" : False}, + "imbalanced_input_support" : {"enabled" : False}} + orig_options = {"imbalanced_input_support" : {"enabled" : False}} + # Original model orig_dg = DistributionGrid() - orig_gate = TopKGate(d_model, nexperts, k=1, dgrid=orig_dg, use_tutel_cumsum_sub_one=use_tutel) + orig_gate = TopKGate(d_model, nexperts, k=1, dgrid=orig_dg) orig_enc = TransformerMoEEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, nexperts=nexperts, gate=orig_gate, distribution_grid = orig_dg, - use_mpi4py = False) + options = orig_options) if use_ort: orig_enc = ORTModule(orig_enc) @@ -1536,10 +1430,10 @@ def test_fsdp_ep(): # Standard EP distribution ep_dg = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - ep_gate = TopKGate(d_model, nexperts, k=1, dgrid=ep_dg, use_tutel_cumsum_sub_one=use_tutel) + ep_gate = TopKGate(d_model, nexperts, k=1, dgrid=ep_dg) ep_enc = TransformerMoEEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, nexperts=nexperts, gate=ep_gate, distribution_grid = ep_dg, - use_mpi4py = False) + options = orig_options) if use_ort: ep_enc = ORTModule(ep_enc) @@ -1551,10 +1445,10 @@ def test_fsdp_ep(): # FSDP distribution fsdp_dg = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) - fsdp_gate = TopKGate(d_model, nexperts, k=1, dgrid=fsdp_dg, use_tutel_cumsum_sub_one=use_tutel) + fsdp_gate = TopKGate(d_model, nexperts, k=1, dgrid=fsdp_dg) fsdp_enc = TransformerMoEEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, nexperts=nexperts, gate=fsdp_gate, distribution_grid = fsdp_dg, - use_mpi4py = False, use_fsdp = True, flatten_parameters = False) + options = options) if use_ort: fsdp_enc = ORTModule(fsdp_enc) fsdp_params = dict(flatten_parameters=False) @@ -1578,6 +1472,7 @@ def test_fsdp_ep(): assert fsdp_moe_params == expert_parameters assert fsdp_params == ((non_expert_parameters / size) + gate_parameters + expert_parameters) + @pytest.mark.with_ort def test_einsum(): @@ -1586,12 +1481,209 @@ def test_einsum(): a = torch.rand(M) b = torch.rand(M, 
N) rule = 's,se->se' - assert torch.allclose(einsum(rule, a, b), torch.einsum(rule, a, b)) + assert torch.allclose(om_einsum(rule, a, b), torch.einsum(rule, a, b)) a = torch.rand(M, N) rule = 'se,sc->sec' - assert torch.allclose(einsum(rule, a, b), torch.einsum(rule, a, b)) + assert torch.allclose(om_einsum(rule, a, b), torch.einsum(rule, a, b)) rule = 'se,sc->sc' - assert torch.allclose(einsum(rule, a, b), torch.einsum(rule, a, b)) + assert torch.allclose(om_einsum(rule, a, b), torch.einsum(rule, a, b)) a = torch.rand(M, N, D) rule = 'sec,sm->ecm' - assert torch.allclose(einsum(rule, a, b), torch.einsum(rule, a, b)) + assert torch.allclose(om_einsum(rule, a, b), torch.einsum(rule, a, b)) + +def test_enable_zero_optimization_z0(): + options_z0 = { "deepspeed_zero_optimization" : {"stage": 0}} + + dgrid = DistributionGrid() + model_dim = 8 + ff_dim = 12 + num_experts = 4 + gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid) + experts = torch.nn.ModuleList() + experts.append(FFNExpert(model_dim, ff_dim, dgrid=dgrid)) + + m_z0 = MixtureOfExpertsFunc(gating_fn, experts, distribution_grid=dgrid, options = options_z0) + for param in m_z0.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is False + assert hasattr(param, "group_name") is False + + enc = TransformerMoEEncoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z0) + for param in enc.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is False + assert hasattr(param, "group_name") is False + + dec = TransformerMoEDecoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z0) + for param in dec.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is False + assert hasattr(param, "group_name") is False + + lenc = LanguageExpertMoEEncoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z0) + for param in lenc.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is False + assert hasattr(param, "group_name") is False + + ldec = LanguageExpertMoEDecoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z0) + for param in ldec.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is False + assert hasattr(param, "group_name") is False + +def test_enable_zero_optimization_z1(): + options_z1 = { "deepspeed_zero_optimization" : {"stage": 1}} + + dgrid = DistributionGrid() + model_dim = 8 + ff_dim = 12 + num_experts = 4 + gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid) + experts = torch.nn.ModuleList() + experts.append(FFNExpert(model_dim, ff_dim, dgrid=dgrid)) + + m_z1 = MixtureOfExpertsFunc(gating_fn, experts, distribution_grid=dgrid, options = options_z1) + for param in m_z1.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + enc = TransformerMoEEncoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z1) + for param in enc.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + dec = TransformerMoEDecoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z1) + for param in dec.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, 
"group_name") is True + + lenc = LanguageExpertMoEEncoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z1) + for param in lenc.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + ldec = LanguageExpertMoEDecoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z1) + for param in ldec.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + +def test_enable_zero_optimization_z2(): + options_z2 = { "deepspeed_zero_optimization" : {"stage": 2}} + + dgrid = DistributionGrid() + model_dim = 8 + ff_dim = 12 + num_experts = 4 + gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid) + experts = torch.nn.ModuleList() + experts.append(FFNExpert(model_dim, ff_dim, dgrid=dgrid)) + + m_z2 = MixtureOfExpertsFunc(gating_fn, experts, distribution_grid=dgrid, options = options_z2) + for param in m_z2.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + enc = TransformerMoEEncoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z2) + for param in enc.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + dec = TransformerMoEDecoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z2) + for param in dec.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + lenc = LanguageExpertMoEEncoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z2) + for param in lenc.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + ldec = LanguageExpertMoEDecoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z2) + for param in ldec.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + +def test_enable_zero_optimization_z3(): + options_z3 = { "deepspeed_zero_optimization" : {"stage": 3}} + + dgrid = DistributionGrid() + model_dim = 8 + ff_dim = 12 + num_experts = 4 + gating_fn = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid) + experts = torch.nn.ModuleList() + experts.append(FFNExpert(model_dim, ff_dim, dgrid=dgrid)) + + m_z3 = MixtureOfExpertsFunc(gating_fn, experts, distribution_grid=dgrid, options = options_z3) + for param in m_z3.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + enc = TransformerMoEEncoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z3) + for param in enc.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + dec = TransformerMoEDecoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z3) + for param in dec.parameters(): + if is_moe_parameter(param): + assert 
hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + lenc = LanguageExpertMoEEncoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z3) + for param in lenc.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + + ldec = LanguageExpertMoEDecoderLayer(256, 4, nexperts=8, distribution_grid = dgrid, options = options_z3) + for param in ldec.parameters(): + if is_moe_parameter(param): + assert hasattr(param, "allreduce") is True + assert param.allreduce == False + assert hasattr(param, "group_name") is True + +def test_enable_expert_weight_calculation_optimization(device=rank): + model_dim = 8 + num_local_experts = 4 + num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts + input = torch.rand(8, 16, model_dim).to(device) + dgrid = DistributionGrid(expert_parallel_group_size = dist.get_world_size()) + gate = TopKGate(model_dim, num_experts, k=2, dgrid=dgrid) + experts = torch.nn.ModuleList() + for i in range(num_local_experts): + expert = torch.nn.Linear(model_dim, model_dim, bias=False) + # use identify matrix + expert.weight = torch.nn.Parameter(torch.eye(model_dim)) + experts.append(expert) + options = {"enable_expert_weight_calculation_optimization" : True} + moe = MixtureOfExpertsFunc(gate, experts, distribution_grid = dgrid, options = options).to(device) + output = moe(input) + assert output.shape == input.shape diff --git a/ort_moe/tests/test_nccl.py b/ort_moe/tests/test_nccl.py new file mode 100644 index 00000000..db8568f5 --- /dev/null +++ b/ort_moe/tests/test_nccl.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os + +os.environ["LD_LIBRARY_PATH"] = "/usr/src/nccl-2.8.4-1/build/lib/" + +# run original nccl 2.8.4 baseline +os.system("/usr/src/nccl-tests-baseline/build/alltoall_perf -b 128 -e 1GB -f 2 -g 1 -c 1 -n 200 -w 10 -z 0") diff --git a/ort_moe/tests/test_sccl_without_import.py b/ort_moe/tests/test_sccl_without_import.py new file mode 100644 index 00000000..def60a78 --- /dev/null +++ b/ort_moe/tests/test_sccl_without_import.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import os + +os.environ["LD_LIBRARY_PATH"] = "/usr/src/nccl-master-0.3.1/build/lib/" + +# run sccl without import sccl -- should be identical to nccl 2.8.4 +os.system("/usr/src/nccl-tests/build/alltoall_perf -b 128 -e 1GB -f 2 -g 1 -c 1 -n 200 -w 10 -z 0") diff --git a/ort_moe/tests/test_top2gating.py b/ort_moe/tests/test_top2gating.py index 208b14d8..b1d1646b 100644 --- a/ort_moe/tests/test_top2gating.py +++ b/ort_moe/tests/test_top2gating.py @@ -15,22 +15,29 @@ import pytest import torch import math +import os from mpi4py import MPI +import torch.distributed as dist from ort_moe.topKgate import TopKGate, top2gating, top1gating, fast_one_hot, balance_ratio_to_dict from ort_moe.loss_functions import loss_functions from ort_moe.gate_logs import gate_logs from ort_moe.grids import DistributionGrid -from tutel.jit_kernels.gating import fast_cumsum_sub_one - -import topKgate_old +from ort_moe.moe_config import moe_config #Uncomment this if there is no cuda #skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") dgrid = DistributionGrid() +BACKEND = dist.Backend.NCCL +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29502" # torch 1.5 compatibility comm = MPI.COMM_WORLD rank = comm.Get_rank() +size = comm.Get_size() + +if not dist.is_initialized(): + dist.init_process_group(backend=BACKEND, rank=rank, world_size=size) def test_create(): gate = TopKGate(4, 8, dgrid=dgrid, k = 1) @@ -50,7 +57,9 @@ def do_test_forward(device='cpu', topk = 1, fp16_mode = False): model_dim = 4 input = torch.randn(num_tokens, model_dim).to(device) - gate = TopKGate(model_dim, num_experts, dgrid=dgrid, k = topk, fp16_mode = fp16_mode).to(device) + options = {} + options["fp16_mode"] = fp16_mode + gate = TopKGate(model_dim, num_experts, dgrid=dgrid, k = topk, options = options).to(device) capacity_fp = max(min(num_tokens, (topk * num_tokens / num_experts)), 4) capacity = math.ceil(capacity_fp) @@ -118,7 +127,8 @@ def do_test_nonpadding_top1(): num_nonpadding = 5 logits = torch.randn(num_tokens, num_experts) nonpadding = torch.zeros(num_tokens).to(torch.int64).scatter_(0, torch.arange(0, num_nonpadding), 1) - _, _, _, dispatch_mask, _ = top1gating(logits, capacity_factor=1, nonpadding=nonpadding, straight_through=True) + br={'z_loss': 0.1} + _, _, _, dispatch_mask, _ = top1gating(logits, capacity_factor=1, balance_ratio=br, nonpadding=nonpadding, straight_through=True) top1s = torch.argmax(logits, dim=1) capacity = num_tokens // num_experts ce = [0] * num_tokens @@ -153,85 +163,63 @@ def test_nonpadding(): do_test_nonpadding_top1() do_test_nonpadding_top2() -def do_test_loss(k, balance_ratio=None, gate_log_req=None, device=None, tutel_cumsum_sub_one=None): +def do_test_loss(k, balance_ratio=None, gate_log_req=None, device=None, options=None): # verify that refactored loss computation can get the same result with old version code num_tokens = 8 num_experts = 4 logits = torch.randn(num_tokens, num_experts, device=device) if k == 2: - loss_old, _, _, _, = topKgate_old.top2gating(logits) if balance_ratio is None and gate_log_req is None: - loss, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, tutel_cumsum_sub_one=tutel_cumsum_sub_one) + loss, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, options=options) elif gate_log_req is None: - loss, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, 
balance_ratio=balance_ratio_to_dict(balance_ratio), straight_through=True, straight_through_temperature=1.0, tutel_cumsum_sub_one=tutel_cumsum_sub_one) - loss1, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, balance_ratio=balance_ratio_to_dict(balance_ratio), straight_through=True, straight_through_temperature=1.0-1e-9, tutel_cumsum_sub_one=tutel_cumsum_sub_one) - loss2, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=0, balance_ratio=balance_ratio_to_dict(balance_ratio), straight_through=True, straight_through_temperature=1.0, tutel_cumsum_sub_one=tutel_cumsum_sub_one) - loss3, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=0, balance_ratio=balance_ratio_to_dict(balance_ratio), straight_through=True, straight_through_temperature=1.0-1e-9, tutel_cumsum_sub_one=tutel_cumsum_sub_one) + loss, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, balance_ratio=balance_ratio_to_dict(balance_ratio), straight_through=True, straight_through_temperature=1.0, options=options) + loss1, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, balance_ratio=balance_ratio_to_dict(balance_ratio), straight_through=True, straight_through_temperature=1.0-1e-9, options=options) + loss2, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=0, balance_ratio=balance_ratio_to_dict(balance_ratio), straight_through=True, straight_through_temperature=1.0, options=options) + loss3, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=0, balance_ratio=balance_ratio_to_dict(balance_ratio), straight_through=True, straight_through_temperature=1.0-1e-9, options=options) assert loss == loss1 == loss2 == loss3 elif balance_ratio is None: - loss, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, gate_log_req=gate_log_req, tutel_cumsum_sub_one=tutel_cumsum_sub_one) + loss, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, gate_log_req=gate_log_req, options=options) else: - loss, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, balance_ratio=balance_ratio_to_dict(balance_ratio), gate_log_req=gate_log_req, tutel_cumsum_sub_one=tutel_cumsum_sub_one) + loss, _, _, _, _ = top2gating(logits, capacity_factor=1, logits_gumbel=1, balance_ratio=balance_ratio_to_dict(balance_ratio), gate_log_req=gate_log_req, options=options) if balance_ratio is None: balance_ratio = 0.01 balance_ratio = balance_ratio_to_dict(balance_ratio) - loss_old = loss_old[0] * balance_ratio['load_balance'] - assert torch.isclose(loss, loss_old) else: if balance_ratio is not None: balance_ratio = balance_ratio_to_dict(balance_ratio) - loss_old, _, _, _, = topKgate_old.top1gating(logits, capacity_factor=1) - if balance_ratio is None and gate_log_req is None: - loss, gate_log, _, _, _ = top1gating(logits, capacity_factor=1, tutel_cumsum_sub_one=tutel_cumsum_sub_one) + loss, gate_log, _, _, _ = top1gating(logits, capacity_factor=1, options=options) elif gate_log_req is None: - loss, gate_log, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, straight_through=True, straight_through_temperature=1.0, logits_gumbel=0, tutel_cumsum_sub_one=tutel_cumsum_sub_one) - loss1, _, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, straight_through=True, straight_through_temperature=1.0-1e-9, logits_gumbel=0, tutel_cumsum_sub_one=tutel_cumsum_sub_one) - loss2, _, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, straight_through=True, 
straight_through_temperature=1.0, logits_gumbel=1e-9, tutel_cumsum_sub_one=tutel_cumsum_sub_one) - loss3, _, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, straight_through=True, straight_through_temperature=1.0-1e-9, logits_gumbel=1e-9, tutel_cumsum_sub_one=tutel_cumsum_sub_one) + loss, gate_log, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, straight_through=True, straight_through_temperature=1.0, logits_gumbel=0, options=options) + loss1, _, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, straight_through=True, straight_through_temperature=1.0-1e-9, logits_gumbel=0, options=options) + loss2, _, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, straight_through=True, straight_through_temperature=1.0, logits_gumbel=1e-9, options=options) + loss3, _, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, straight_through=True, straight_through_temperature=1.0-1e-9, logits_gumbel=1e-9, options=options) assert loss == loss1 == loss2 == loss3 elif balance_ratio is None: - loss, gate_log, _, _, _ = top1gating(logits, capacity_factor=1, gate_log_req=gate_log_req, tutel_cumsum_sub_one=tutel_cumsum_sub_one) + loss, gate_log, _, _, _ = top1gating(logits, capacity_factor=1, gate_log_req=gate_log_req, options=options) else: - loss, gate_log, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, gate_log_req=gate_log_req, tutel_cumsum_sub_one=tutel_cumsum_sub_one) - - if balance_ratio is None: - balance_ratio = {'load_balance': 0.01} - if gate_log_req is None: - gate_log_req = {} - l_old = sum([loss_old[i] * balance_ratio.get(l, 0) for i, l in zip(range(len(loss_functions)), loss_functions.keys())]) - - assert torch.isclose(loss, l_old) - for i, l in zip(range(len(loss_functions)), loss_functions.keys()): - if balance_ratio.get(l, 0) > 0: - assert torch.isclose(gate_log[l], loss_old[i]*balance_ratio[l]) - for i, l in zip(range(3), list(gate_logs.keys())[:3]): - if gate_log_req.get(l, False): - # skip gate_probability because it is computed over routed tokens in the new version - if l != 'gate_probability': - assert torch.isclose(gate_log[l], loss_old[4+i]) - if gate_log_req.get('expert_fraction', False): - assert torch.allclose(gate_log.get('expert_fraction', torch.zeros(num_experts)), loss_old[7 : 7+num_experts]) - if gate_log_req.get('expert_routed_fraction', False): - assert torch.allclose(gate_log.get('expert_routed_fraction', torch.zeros(num_experts)), loss_old[7+num_experts : 7+num_experts*2]) + loss, gate_log, _, _, _ = top1gating(logits, capacity_factor=1, balance_ratio=balance_ratio, gate_log_req=gate_log_req, options=options) def test_loss(): balance_ratio={'load_balance': 0.1, 'sparsity_l1': 0.1, 'mean_importance': 0.1, 'z_loss': 0.1, 'ideal_load_balance': 1e-9} gate_log_req={'gate_entropy': True, 'gate_probability': True, 'gate_routed': True, 'expert_fraction': True, 'expert_routed_fraction': True} + options = {} + options["enable_tutel_cumsum"] = True + do_test_loss(2) - do_test_loss(2, device=rank, tutel_cumsum_sub_one=fast_cumsum_sub_one) + do_test_loss(2, device=rank, options=moe_config(options)) do_test_loss(2, balance_ratio=0.01) do_test_loss(2, gate_log_req=gate_log_req) do_test_loss(2, balance_ratio=0.01, gate_log_req=gate_log_req) do_test_loss(1) - do_test_loss(1, device=rank, tutel_cumsum_sub_one=fast_cumsum_sub_one) + do_test_loss(1, device=rank, options=moe_config(options)) do_test_loss(1, balance_ratio=balance_ratio) 
do_test_loss(1, gate_log_req=gate_log_req) do_test_loss(1, balance_ratio=balance_ratio, gate_log_req=gate_log_req) @@ -299,6 +287,23 @@ def test_token_drop(): do_test_token_drop_top2(token_drop_type='routing_weight') def test_tutel_cumsum(): matrix = torch.randint(0, 100, (10000, 100), device=rank) + from tutel.jit_kernels.gating import fast_cumsum_sub_one cumsum_tutel = fast_cumsum_sub_one(matrix, dim=0) + 1 cumsum_torch = torch.cumsum(matrix, dim=0) assert cumsum_tutel.equal(cumsum_torch), "Result of tutel's cumsum is not correct" + +# Test dynamic capacity in top1gating when the max_token_usage based capacity is in effect. +def test_top1gating_dynamic_capacity(): + dgrid = DistributionGrid(expert_parallel_group_size = 4) + rank_list = list(range(0, dist.get_world_size())) + ranks_count = 4 + logits = torch.eye(128, dtype=torch.float32) + logits = logits.repeat(3, 1) + options = {} + options["enable_dynamic_capacity"] = True + _, _, _, _, capacity_fp = top1gating(logits, capacity_factor=100.0, options=options) + assert capacity_fp == 32.0 + logits = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) + logits = logits.repeat(3, 1) + _, _, _, _, capacity_fp = top1gating(logits, capacity_factor=5.0, options=options) + assert capacity_fp == 4.0 diff --git a/ort_moe/tests/test_uni_image.py b/ort_moe/tests/test_uni_image.py new file mode 100644 index 00000000..dde4c011 --- /dev/null +++ b/ort_moe/tests/test_uni_image.py @@ -0,0 +1,32 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +######################## +# ORTModule related tests +######################## + +#a simple test to make sure ortmodule result is the same as pytorch result +from torch_ort import ORTModule +import torch +class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4,4) + + def forward(self, x): + y = self.linear(x) + return y + +model = M() +model_ort = ORTModule(model) + +input = torch.rand(4, 4).requires_grad_() +output = model(input) +output_ort = model_ort(input) +assert torch.allclose(output, output_ort) + +######################## +# SCCL/NCCL Related test +######################## +#TODO
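
The updated tests above all follow one configuration pattern: the old per-feature keyword arguments (`fp16_mode=...`, `use_fsdp=...`, `flatten_parameters=...`, `use_mpi4py=...`, `use_tutel_cumsum_sub_one=...`) are replaced by a single `options` dict passed to `TopKGate`, `MixtureOfExpertsFunc`, and the Transformer MoE layers. The snippet below is a minimal illustrative sketch, not part of the patch; it assumes the same MPI/NCCL process-group setup the tests establish before constructing modules, and the import paths for `FFNExpert` and `MixtureOfExpertsFunc` are assumptions based on the package layout rather than imports shown in this section.

```python
# Minimal sketch (not part of the patch) of the options-dict pattern used by the
# tests above. Module paths for FFNExpert/MixtureOfExpertsFunc and the distributed
# setup are assumptions; the actual tests initialize torch.distributed over MPI first.
import torch

from ort_moe.grids import DistributionGrid
from ort_moe.topKgate import TopKGate
from ort_moe.experts import FFNExpert          # assumed module path
from ort_moe.moe import MixtureOfExpertsFunc   # assumed module path

model_dim, ff_dim, num_experts = 8, 12, 4
dgrid = DistributionGrid()  # default grid, as in the ZeRO-stage tests above

# One dict replaces the removed per-feature keyword arguments.
options = {
    "fp16_mode": False,                           # True mirrors test_moe_fp16
    "deepspeed_zero_optimization": {"stage": 1},  # stages >= 1 tag MoE params
}

gate = TopKGate(model_dim, num_experts, k=1, dgrid=dgrid, options=options)
experts = torch.nn.ModuleList([FFNExpert(model_dim, ff_dim, dgrid=dgrid)])
moe = MixtureOfExpertsFunc(gate, experts, distribution_grid=dgrid, options=options)

# With stage >= 1, expert parameters are expected to carry allreduce == False and a
# group_name attribute; with stage 0 they carry neither
# (see test_enable_zero_optimization_z0/z1 above).
```

The same dict is accepted by `TransformerMoEEncoderLayer` and `TransformerMoEDecoderLayer` via `options=`, which is how the FSDP test above expresses what the removed `use_mpi4py`, `use_fsdp`, and `flatten_parameters` keywords used to configure.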