From ef8bcc8ea6f35b059d5e1ae2813c9bc115a27de9 Mon Sep 17 00:00:00 2001 From: ghostplant Date: Sun, 14 Nov 2021 10:01:10 +0000 Subject: [PATCH] add deepspeed examples (#39) --- README.md | 24 ++-- tutel/examples/helloworld.py | 2 +- tutel/examples/helloworld_ddp.py | 2 +- tutel/examples/helloworld_deepspeed.py | 149 +++++++++++++++++++++++++ tutel/examples/helloworld_megatron.py | 8 +- tutel/impls/moe_layer.py | 2 +- 6 files changed, 169 insertions(+), 18 deletions(-) create mode 100755 tutel/examples/helloworld_deepspeed.py diff --git a/README.md b/README.md index c4562bbc..f922981a 100644 --- a/README.md +++ b/README.md @@ -48,28 +48,28 @@ Full Examples & Usage: ``` * Single-GPU Test: - $ python3 -m tutel.examples.helloworld + $ python3 -m tutel.examples.helloworld --batch_size=32 # To Test Tutel-optimized MoE + manual distribution + $ python3 -m tutel.examples.helloworld_ddp --batch_size=32 # To Test Tutel-optimized MoE + Pytorch DDP distribution (requires: Pytorch >= 1.8.0) + $ python3 -m tutel.examples.helloworld_megatron --batch_size=32 # To Test Tutel using Megatron Gating (Tensor Parallel on Experts) + manual distribution + $ python3 -m tutel.examples.helloworld_deepspeed --batch_size=32 # To Test Deepspeed MoE + manual distribution - (If full source code exists:) - $ python3 ./tutel/examples/helloworld.py + (If full source code exists, the following also works:) + $ python3 ./tutel/examples/helloworld.py --batch_size=32 + .. * Running MoE Hello World Model by torch.distributed.all_reduce: - $ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld + $ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld --batch_size=32 + $ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld_ddp --batch_size=32 + .. (For New Pytorch:) $ python3 -m torch.distributed.run --nproc_per_node=2 -m tutel.examples.helloworld - -* Running MoE Hello World Model by torch.nn.parallel.DistributedDataParallel (requires torch >= 1.8.0): - - $ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld_ddp - - (For New Pytorch:) - $ python3 -m torch.distributed.run --nproc_per_node=2 -m tutel.examples.helloworld_ddp + .. * Usage of MOELayer Args: - gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..} + gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'} model_dim : the number of channels for MOE's input tensor experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)` diff --git a/tutel/examples/helloworld.py b/tutel/examples/helloworld.py index 90b540e9..947f2981 100755 --- a/tutel/examples/helloworld.py +++ b/tutel/examples/helloworld.py @@ -25,7 +25,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--local_rank', type=int, default=-1) -parser.add_argument('--batch_size', type=int, default=8) +parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--num_tokens', type=int, default=1024) parser.add_argument('--model_dim', type=int, default=2048) parser.add_argument('--hidden_size', type=int, default=2048) diff --git a/tutel/examples/helloworld_ddp.py b/tutel/examples/helloworld_ddp.py index 826fcba0..46333260 100755 --- a/tutel/examples/helloworld_ddp.py +++ b/tutel/examples/helloworld_ddp.py @@ -27,7 +27,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--local_rank', type=int, default=-1) -parser.add_argument('--batch_size', type=int, default=8) +parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--num_tokens', type=int, default=1024) parser.add_argument('--model_dim', type=int, default=2048) parser.add_argument('--hidden_size', type=int, default=2048) diff --git a/tutel/examples/helloworld_deepspeed.py b/tutel/examples/helloworld_deepspeed.py new file mode 100755 index 00000000..847bc6f5 --- /dev/null +++ b/tutel/examples/helloworld_deepspeed.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import time +import torch +import torch.optim as optim +import torch.nn.functional as F +import torch.distributed as dist +from torch import nn +import argparse +import deepspeed + +import logging +logging.basicConfig(level=logging.INFO) + +parser = argparse.ArgumentParser() + +parser.add_argument('--local_rank', type=int, default=-1) +parser.add_argument('--batch_size', type=int, default=16) +parser.add_argument('--num_tokens', type=int, default=1024) +parser.add_argument('--model_dim', type=int, default=2048) +parser.add_argument('--hidden_size', type=int, default=2048) +parser.add_argument('--num_local_experts', type=int, default=2) +parser.add_argument('--dtype', type=str, default='float32') +parser.add_argument('--fp32_gate', default=False, action='store_true') +parser.add_argument('--top', type=int, default=2) +args = parser.parse_args() + +if args.local_rank < 0: + args.local_rank = int(os.environ.get('LOCAL_RANK', 0)) + +torch.cuda.set_device(args.local_rank) + +try: + if dist.is_available(): + dist.init_process_group('nccl') + dist_rank = dist.get_rank() + dist_world_size = dist.get_world_size() + + def dist_print(*args): + if dist_rank == 0: + print(*args) +except: + dist_rank = 0 + dist_world_size = 1 + dist_print = print + +batch_size = args.batch_size +num_tokens = args.num_tokens +model_dim = args.model_dim +hidden_size = args.hidden_size +num_local_experts = args.num_local_experts +top_value = args.top +local_rank = args.local_rank + + +device = torch.device('cuda', args.local_rank) + +if args.dtype == 'float32': + torch.set_default_dtype(torch.float32) +elif args.dtype == 'float16': + torch.set_default_dtype(torch.float16) +elif args.dtype == 'bfloat16': + torch.set_default_dtype(torch.bfloat16) +else: + raise Exception('Unrecognized data type specified: %s' % args.dtype) + +deepspeed.init_distributed() +deepspeed.utils.groups.initialize(ep_size=dist_world_size) + +class ExpertModel(torch.nn.Module): + def __init__(self, model_dim, hidden_size, activation_fn): + super().__init__() + self.fc1 = torch.nn.Linear(model_dim, hidden_size, bias=True) + self.fc2 = torch.nn.Linear(hidden_size, model_dim, bias=True) + self.activation_fn = activation_fn + def forward(self, x): + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + +class ExampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + + self._moe_layer = deepspeed.moe.layer.MoE( + hidden_size = hidden_size, + expert = ExpertModel(model_dim, hidden_size, lambda x: F.relu(x)), + num_experts = num_local_experts * dist_world_size, + k = top_value + ).to(device) + + for name, param in self._moe_layer.named_parameters(): + if '.experts.' in name: + setattr(param, 'skip_allreduce', True) + + # Distinguish different parameter types: gate, local_experts + local_count = sum([torch.numel(param) for name, param in self._moe_layer.named_parameters() if '.experts.' in name]) + shared_count = sum([torch.numel(param) for name, param in self._moe_layer.named_parameters() if '.gate.' in name]) + dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count)) + + def forward(self, input): + result, _, _ = self._moe_layer(input) + result = F.log_softmax(torch.sum(result, dim=2), dim=1) + return result + +model = ExampleModel() +dist_print(model) + +optimizer = torch.optim.SGD(model.parameters(), lr=1e-5) + +x = torch.randn([batch_size, num_tokens, model_dim], device=device, requires_grad=True) +y = torch.LongTensor(batch_size).random_(1).to(device) + +tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, device) +dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, device = `%s`' % tuples) + +average_time, num_steps = 0, 100 + +params_for_all_reduce = [p for p in model.parameters() if not hasattr(p, 'skip_allreduce') and getattr(p, 'requires_grad', False)] + +for i in range(num_steps): + + torch.cuda.synchronize() + t_start = time.time() + optimizer.zero_grad() + + output = model(x) + loss = F.nll_loss(output, y) + loss.backward() + if dist_world_size > 1: + for p in params_for_all_reduce: + p.grad /= dist_world_size + dist.all_reduce(p.grad) + optimizer.step() + + torch.cuda.synchronize() + t_stop = time.time() + dist_print('STEP-%s: DONE, loss = %s, step_time = %s sec.' % (i, float(loss.data), t_stop - t_start)) + + if i + 10 >= num_steps: + average_time += t_stop - t_start + +average_time /= 10 +dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time) diff --git a/tutel/examples/helloworld_megatron.py b/tutel/examples/helloworld_megatron.py index 86f1d5d4..2b471678 100755 --- a/tutel/examples/helloworld_megatron.py +++ b/tutel/examples/helloworld_megatron.py @@ -25,10 +25,11 @@ parser = argparse.ArgumentParser() parser.add_argument('--local_rank', type=int, default=-1) -parser.add_argument('--batch_size', type=int, default=8) +parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--num_tokens', type=int, default=1024) parser.add_argument('--model_dim', type=int, default=2048) parser.add_argument('--hidden_size', type=int, default=2048) +parser.add_argument('--num_local_experts', type=int, default=2) parser.add_argument('--dtype', type=str, default='float32') parser.add_argument('--l_aux_wt', type=float, default=0.0) args = parser.parse_args() @@ -56,6 +57,7 @@ def dist_print(*args): num_tokens = args.num_tokens model_dim = args.model_dim hidden_size = args.hidden_size +num_local_experts = args.num_local_experts local_rank = args.local_rank @@ -77,7 +79,7 @@ def __init__(self): self._moe_layer = tutel_moe.moe_layer( gate_type = {'type': 'megatron'}, - experts = {'type': 'ffn', 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)}, + experts = {'type': 'ffn', 'hidden_size_per_expert': hidden_size * num_local_experts, 'activation_fn': lambda x: F.relu(x)}, model_dim = model_dim, scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True), seeds = (1, dist_rank + 1, 1), @@ -101,7 +103,7 @@ def forward(self, input): x = torch.randn([batch_size, num_tokens, model_dim], device=device, requires_grad=True) y = torch.LongTensor(batch_size).random_(1).to(device) -tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, 1, device) +tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, device) dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, gate = megatron, device = `%s`' % tuples) average_time, num_steps = 0, 100 diff --git a/tutel/impls/moe_layer.py b/tutel/impls/moe_layer.py index 8fa63b0d..d6ac8f99 100644 --- a/tutel/impls/moe_layer.py +++ b/tutel/impls/moe_layer.py @@ -166,7 +166,7 @@ class MOELayer(torch.nn.Module): """Tutel optimized MOELayer Args: - gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..} + gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'} model_dim : the number of channels for MOE's input tensor experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`