diff --git a/README.md b/README.md
index cc2c72cb..8ef586a1 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,41 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation.
 - Supported CPU: fp64/fp32
 
+### What's New:
+
+- Tutel v0.3: Add the Megablocks solution to improve decoder inference on a single GPU with num_local_experts >= 2:
+```py
+   >> Example (capacity_factor=0 for dropless-MoE):
+   # Using BatchMatmul:
+   python3 -m tutel.examples.helloworld --megablocks_size=0 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0
+   # Using Megablocks with block_size = 1:
+   python3 -m tutel.examples.helloworld --megablocks_size=1 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0
+   # Using Megablocks with block_size = 2:
+   python3 -m tutel.examples.helloworld --megablocks_size=2 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0
+
+   >> How to:
+   self._moe_layer.forward(x, .., megablocks_size=1) # Switch megablocks_size dynamically (0 to disable)
+```
+
+- Tutel v0.2: Allow most configurations to be switched dynamically at no extra cost:
+```py
+   >> Example:
+   python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16
+
+   >> How to:
+   self._moe_layer.forward(x, .., a2a_ffn_overlap_degree=2) # Switch the overlap granularity dynamically (1 for no overlapping)
+   self._moe_layer.forward(x, .., adaptive_r=1) # Switch the parallelism style dynamically (0 for DP, 1 for DP + EP, W / E for MP + EP, otherwise DP + MP + EP)
+   self._moe_layer.forward(x, .., capacity_factor=1) # Switch the capacity volume dynamically (positive for padding, negative for no padding, 0 for dropless)
+   self._moe_layer.forward(x, .., top_k=1) # Switch the top_k sparsity dynamically
+```
+
+- Tutel v0.1: Optimize the einsum complexity of data dispatch encoding and decoding, and add a 2DH option to handle All-to-All at scale:
+```py
+   >> Example (enabling 2DH is suggested only at scale):
+   python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16 --use_2dh=1
+```
+
 How to setup Tutel MoE for Pytorch and [run examples](tutel/examples), or [enable fairseq with MoE](tutel/examples/fairseq_moe):
 ```
 * Recommended Pytorch (minimize version == 1.8.0):
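For reference, a minimal single-process sketch of the per-call switches listed in the "What's New" section (not part of this patch): the constructor and forward keywords follow `tutel.examples.helloworld` and `moe_layer.forward` from this diff, and the Megablocks call assumes a CUDA device with at least 2 local experts.

```py
import torch
import torch.nn.functional as F
from tutel import system, moe as tutel_moe

device = 'cuda' if torch.cuda.is_available() else 'cpu'
parallel_env = system.init_data_model_parallel(backend='nccl' if device == 'cuda' else 'gloo')

moe = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 2, 'capacity_factor': 0.0},   # 0 => dropless-MoE
    experts={'type': 'ffn', 'count_per_node': 2, 'hidden_size_per_expert': 1024, 'activation_fn': lambda x: F.relu(x)},
    model_dim=512,
).to(device)

x = torch.randn(4, 16, 512, device=device)       # (batch, tokens, model_dim)
y = moe(x)                                       # defaults taken from gate_type
y = moe(x, top_k=1, capacity_factor=1.0)         # v0.2: per-call switches
if device == 'cuda':
    with torch.no_grad():                        # v0.3: Megablocks path is inference-only
        y = moe(x, megablocks_size=2)
```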
@@ -48,15 +83,11 @@ How to setup Tutel MoE for Pytorch and [run examples](tutel/examples), or [enabl
         $ python3 ./tutel/examples/helloworld.py --batch_size=16
         ..
 
-* Switch Test using single-node 8 GPUs:
-
-        $ python3 -m torch.distributed.launch --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16
-
 * Run Tutel MoE in Distributed Mode:
 
         (Method A - Torch launcher for `Multi-Node x Multi-GPU`:)
-        $ ssh python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr= -m tutel.examples.helloworld --batch_size=16
-        $ ssh python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr= -m tutel.examples.helloworld --batch_size=16
+        $ ssh python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr= -m tutel.examples.helloworld --batch_size=16
+        $ ssh python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr= -m tutel.examples.helloworld --batch_size=16
 
         (Method B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:)
         # << Single Node >>
diff --git a/tutel/custom/custom_kernel.cpp b/tutel/custom/custom_kernel.cpp
index 3b897d15..418dacb7 100644
--- a/tutel/custom/custom_kernel.cpp
+++ b/tutel/custom/custom_kernel.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include
 #else
 #undef USE_NCCL
 #endif
@@ -777,7 +778,25 @@ extern "C" __global__ void cumsum_fn(int* input0 /* (num_samples, batch_num) */,
   return y;
 }
 
+torch::Tensor warp_sparse_bmm_infer(const torch::Tensor &x, const torch::Tensor &w, const torch::Tensor &sparse_groups_device, bool w_transpose, int64_t sparse_size) {
+  auto sparse_groups = sparse_groups_device.cpu().to(torch::kInt32);
+  auto group_ptr = ((int*)sparse_groups.data_ptr());
+
+  auto y = torch::empty({x.size(0), x.size(1), w_transpose ? w.size(1) : w.size(2)}, torch::TensorOptions().dtype(x.dtype()).device(x.device()));
+
+  // auto hCublas = at::cuda::getCurrentCUDABlasHandle();  -- Wait Pytorch to add builtin support for cublasSgemmBatched()
+  for (int i = 0; i < sparse_groups.size(0); ++i) {
+    int group_size = group_ptr[i];
+    if (group_size > 0) {
+      auto y_sub = y.select(0, i).narrow(0, 0, int(group_size * sparse_size));
+      torch::matmul_out(y_sub, x.select(0, i).narrow(0, 0, int(group_size * sparse_size)), w_transpose ? w.select(0, i).t() : w.select(0, i));
+    }
+  }
+  return y;
+}
+
 TORCH_LIBRARY(tutel_ops, m) {
   m.def("cumsum", warp_cumsum);
+  m.def("sparse_bmm_infer", warp_sparse_bmm_infer);
 }
 #endif
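A rough Python reference of what `tutel_ops.sparse_bmm_infer` above computes (an assumption based on the C++ body, not part of the patch): for each expert, only the first `sparse_groups[e] * sparse_size` rows are multiplied. The real kernel leaves the unused rows of `y` uninitialized (`torch::empty`), while this sketch zero-fills them.

```py
import torch

def sparse_bmm_infer_reference(x, w, sparse_groups, w_transpose, sparse_size):
    # x: (num_experts, capacity, in_dim); w holds one weight matrix per expert.
    out_dim = w.size(1) if w_transpose else w.size(2)
    y = torch.zeros(x.size(0), x.size(1), out_dim, dtype=x.dtype, device=x.device)
    for e in range(x.size(0)):
        rows = int(sparse_groups[e]) * sparse_size   # only the populated rows are computed
        if rows > 0:
            w_e = w[e].t() if w_transpose else w[e]
            y[e, :rows] = x[e, :rows] @ w_e
    return y
```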
diff --git a/tutel/examples/helloworld.py b/tutel/examples/helloworld.py
index 38effd67..63005842 100755
--- a/tutel/examples/helloworld.py
+++ b/tutel/examples/helloworld.py
@@ -34,6 +34,9 @@
 parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
 parser.add_argument('--use_2dh', default=False, action='store_true')
 parser.add_argument('--eval', default=False, action='store_true')
+parser.add_argument('--capacity_factor', type=float, default=1.0)  # 0.0 for dMoE (dropless-MoE), negative for non-padded capacity.
+parser.add_argument('--megablocks_size', type=int, default=1)
+
 args = parser.parse_args()
 
 parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
@@ -66,7 +69,7 @@ def __init__(self):
         super().__init__()
 
         self._moe_layer = tutel_moe.moe_layer(
-            gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate},
+            gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate, 'capacity_factor': args.capacity_factor},
             experts = {'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
             model_dim = model_dim,
             scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
@@ -82,7 +85,10 @@ def __init__(self):
         dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count))
 
     def forward(self, input):
-        result = self._moe_layer(input)
+        if args.megablocks_size > 0:
+            result = self._moe_layer(input, megablocks_size=args.megablocks_size)
+        else:
+            result = self._moe_layer(input)
         result = F.log_softmax(torch.sum(result, dim=2), dim=1)
         return result
 
diff --git a/tutel/experts/ffn.py b/tutel/experts/ffn.py
index b58edbf1..8fae48d5 100644
--- a/tutel/experts/ffn.py
+++ b/tutel/experts/ffn.py
@@ -57,6 +57,19 @@ def forward(self, x, ctx):
             batched_fc1_bias = self.batched_fc1_bias.unsqueeze(1)
             batched_fc2_bias = self.batched_fc2_bias.unsqueeze(1)
 
+        # Implementation of https://arxiv.org/pdf/2211.15841.pdf in Tutel v0.3.x,
+        # which benefits decoder inference on a single GPU if num_local_experts >= 2
+        if ctx.megablocks_size > 0:
+            sparse_size = ctx.megablocks_size
+            sparse_groups = torch.div(ctx.dispatch_count + (sparse_size - 1), sparse_size, rounding_mode='floor')
+            sparse_groups = torch.minimum(sparse_groups, torch.tensor(x.size(1) // sparse_size, dtype=torch.int32, device=x.device))
+            y = torch.ops.tutel_ops.sparse_bmm_infer(x, batched_fc1_w, sparse_groups, True, sparse_size)
+            y = torch.add(y, batched_fc1_bias)
+            y = self.activation_fn(y)
+            y = torch.ops.tutel_ops.sparse_bmm_infer(y, batched_fc2_w, sparse_groups, False, sparse_size)
+            y = torch.add(y, batched_fc2_bias)
+            return y
+
         if ctx.adaptive_degree == 0:
             batched_fc1_w = net.zero_gather(batched_fc1_w, group=ctx.group).view(ctx.num_global_experts, -1, batched_fc1_w.size(2))
             batched_fc2_w = net.zero_gather(batched_fc2_w, group=ctx.group).view(ctx.num_global_experts, -1, batched_fc2_w.size(2))
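A small numeric illustration (hypothetical values, not part of the patch) of how `sparse_groups` is derived in the FFN branch above: the per-expert token counts are ceil-divided by the Megablocks block size and then clamped so they never exceed the padded capacity.

```py
import torch

dispatch_count  = torch.tensor([5, 0, 3, 12], dtype=torch.int32)  # tokens routed to each expert
megablocks_size = 2                                               # block size
capacity        = 8                                               # x.size(1) in ffn.py

sparse_groups = torch.div(dispatch_count + (megablocks_size - 1), megablocks_size, rounding_mode='floor')
sparse_groups = torch.minimum(sparse_groups, torch.tensor(capacity // megablocks_size, dtype=torch.int32))
print(sparse_groups)  # tensor([3, 0, 2, 4], dtype=torch.int32)
```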
diff --git a/tutel/impls/fast_dispatch.py b/tutel/impls/fast_dispatch.py
index c637a820..3fca138d 100644
--- a/tutel/impls/fast_dispatch.py
+++ b/tutel/impls/fast_dispatch.py
@@ -170,6 +170,9 @@ def extract_critical(scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=
         if normalize_gate:
             denom_s = torch.clamp(sum(gates_s), min=torch.finfo(gates_s[0].dtype).eps)
             gates_s = [x / denom_s for x in gates_s]
+    else:
+        locations2 = locations1
+    locations2 = locations2[-1] + 1
 
     indices_s = [x.to(torch.int32) for x in indices_s]
 
@@ -183,8 +186,8 @@ def extract_critical(scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=
     if capacity_factor > 0:
         capacity = top_k * int(capacity_factor * samples_per_expert)
     else:
-        capacity = torch.max(torch.cat(locations_s, dim=0))
-        capacity = int(simple_all_reduce(capacity, group=group, op=torch.distributed.ReduceOp.MAX)) + 1
+        capacity = locations2.max()
+        capacity = int(simple_all_reduce(capacity, group=group, op=torch.distributed.ReduceOp.MAX))
         if capacity_factor < 0:
             capacity = min(capacity, top_k * int(-capacity_factor * samples_per_expert))
 
@@ -195,16 +198,19 @@ def extract_critical(scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=
         if get_world_rank(group) == 0:
             logging.info(f"Capacity = {capacity}, real-time capacity-factor for top-{top_k_original} = {capacity / (top_k * samples_per_expert)}")
 
-    return (num_global_experts, indices_s, locations_s, gates_s, capacity), l_loss
+    return (num_global_experts, indices_s, locations_s, gates_s, capacity, locations2), l_loss
+
+def get_dispatch_count(critial_data):
+    return critial_data[-1]
 
 def fast_encode(data, critial_data, is_postscore=True):
     num_global_experts = critial_data[0]
     dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype)
-    dispatcher.update(*critial_data[1:], is_postscore=is_postscore)
+    dispatcher.update(*critial_data[1:-1], is_postscore=is_postscore)
     return dispatcher.encode(data).view(num_global_experts, -1, data.size(-1))
 
 def fast_decode(data, critial_data, is_postscore=True):
     num_global_experts = critial_data[0]
     dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype)
-    dispatcher.update(*critial_data[1:], is_postscore=is_postscore)
+    dispatcher.update(*critial_data[1:-1], is_postscore=is_postscore)
     return dispatcher.decode(data).view(-1, data.size(-1))
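A worked example with hypothetical counts (not part of the patch) of the dropless capacity above: `locations2` now carries `last location + 1`, i.e. the per-expert dispatch count that `get_dispatch_count` exposes, so the trailing `+ 1` disappears from the call site. The MAX all-reduce across the group is omitted here.

```py
import torch

per_expert_tokens = torch.tensor([5, 0, 3, 12])     # what locations2 ends up holding
capacity = int(per_expert_tokens.max())             # dropless path (capacity_factor == 0): 12

top_k, samples_per_expert, capacity_factor = 1, 8, -1.0
if capacity_factor < 0:                             # negative factor caps the non-padded capacity
    capacity = min(capacity, top_k * int(-capacity_factor * samples_per_expert))
print(capacity)                                     # 8
```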
diff --git a/tutel/impls/moe_layer.py b/tutel/impls/moe_layer.py
index c5270f9c..e0815c40 100644
--- a/tutel/impls/moe_layer.py
+++ b/tutel/impls/moe_layer.py
@@ -18,7 +18,7 @@
 import torch.nn.functional as F
 
 from ..impls import communicate as C
-from ..impls.fast_dispatch import fast_encode, fast_decode, extract_critical
+from ..impls.fast_dispatch import fast_encode, fast_decode, extract_critical, get_dispatch_count
 from ..impls.overlap import a2a_ffn_overlap_forward
 from . import losses
 
@@ -216,7 +216,7 @@ def expert_local(self, x, reserve_shape):
         self.protected_shape = y.shape
         return y.reshape(y.size(0), y.size(1), -1)
 
-    def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None, a2a_ffn_overlap_degree=None, reserve_dims=1, inequivalent_tokens=False, adaptive_r=None):
+    def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None, a2a_ffn_overlap_degree=None, reserve_dims=1, inequivalent_tokens=False, adaptive_r=None, megablocks_size=0):
         if self.skip_moe:
             result_output = input
             result_output.l_aux = None
@@ -234,6 +234,12 @@ def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None,
             self.a2a_ffn_overlap_degree = a2a_ffn_overlap_degree
         a2a_ffn_overlap_degree = self.a2a_ffn_overlap_degree
 
+        top_k = top_k or gctx.top_k
+
+        if megablocks_size > 0:
+            if self.num_local_experts <= 1 or torch.is_grad_enabled() or self.world_size > 1:
+                megablocks_size = 0
+
         def routing():
             logits = gctx(x)
 
@@ -249,14 +255,17 @@ def routing():
                 _loss_fn = lambda gates, topk_ids: losses.load_importance_loss(
                     F.softmax(logits, dim=1), logits_w_noise.gather(index=topk_ids, dim=1),
                     self.num_global_experts, gctx.gate_noise)
+
+            mega_up = max(megablocks_size, 1)
+
             return logits.dtype, extract_critical(scores,
-                top_k = gctx.top_k if top_k is None else top_k,
+                top_k = top_k,
                 loss_fn = _loss_fn,
-                capacity_factor = gctx.capacity_factor if capacity_factor is None else capacity_factor,
+                capacity_factor = capacity_factor or gctx.capacity_factor,
                 batch_prioritized_routing = self.batch_prioritized_routing,
                 normalize_gate = self.normalize_gate,
                 group = self.group,
-                alignment = self.sharded_count * a2a_ffn_overlap_degree,
+                alignment = (self.sharded_count * a2a_ffn_overlap_degree + mega_up - 1) // mega_up * mega_up,
                 inequivalent_tokens = inequivalent_tokens,
             )
 
@@ -267,6 +276,8 @@ def routing():
         else:
             logits_dtype, (crit, l_aux) = routing()
 
+        self.megablocks_size = megablocks_size
+        self.dispatch_count = get_dispatch_count(crit)
         y = fast_encode(x.to(logits_dtype), crit, self.is_postscore).to(x.dtype)
 
         if adaptive_r is not None:
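A worked example (illustrative values, not part of the patch) of the alignment rounding in `routing()` above: the existing `sharded_count * a2a_ffn_overlap_degree` alignment is rounded up to a multiple of the Megablocks block size, so dispatched buffers stay block-aligned.

```py
sharded_count, a2a_ffn_overlap_degree = 1, 2
for megablocks_size in (0, 1, 2, 4):
    mega_up = max(megablocks_size, 1)
    alignment = (sharded_count * a2a_ffn_overlap_degree + mega_up - 1) // mega_up * mega_up
    print(megablocks_size, alignment)   # megablocks_size -> alignment: 0->2, 1->2, 2->2, 4->4
```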