
Commit

Sync changes in internal moe to public repo (#186)
* update moe code with internal Microsoft code. Main changes: extend Distributed Grid to enable expert relocation; simplify the interface by using an options dict for MixtureOfExperts and Layers; optimize expert weight calculation by avoiding computation on the padding used for unused expert capacity

* sync tests from internal Microsoft repo. Added new tests for nccl_output_processor, nccl and sccl, compression and coverage. Added a readme on running tests

* restructure the directory to install with setup.py and keep the same layout as the internal repo

---------

Co-authored-by: Sheetal Kadam <shekadam@microsoft.com>
sheetalarkadam authored Nov 22, 2023
1 parent 053e971 commit 4c05265
Showing 29 changed files with 1,326 additions and 538 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -84,8 +84,7 @@ Build MoE

```bash
cd ort_moe
pip install build # Install PyPA build
python -m build
python setup.py install
```

## Install for Inference
@@ -290,4 +289,4 @@ Please refer to our [contributing guide](CONTRIBUTING.md) for more information o

## License

This project has an MIT license, as found in the [LICENSE](LICENSE) file.
15 changes: 15 additions & 0 deletions ort_moe/README.md
@@ -0,0 +1,15 @@
# Introduction
A Mixture of Experts (MoE) implementation in PyTorch.
This repo contains the following components:
# moe_module
moe_module contains the PyTorch implementation of MoE, Expert, Gate and Transformer layers. This module is used by 1P workloads to inject MoE layers into their models. We aim to implement moe_module so that it is amenable to a variety of model-distribution techniques for large-scale (100B+ parameter) training. A usage sketch follows this file listing.
# Proxy Models
We have implemented proxies of models such as GShard and Switch Transformer using moe_module in moe_models.py. These models serve two purposes: first, they are simple standalone proxy models for approximate performance analysis and characterization of scaling efforts; second, they let us evaluate the flexibility of the moe_module interface in a variety of situations. We encourage contributions of new proxy models.
# Trainer
We have extended the NLP trainer from the PyTorch tutorial in baseline_nlp.py to run the proxy models and collect performance data. We use WikiText as the dataset, which is obviously not representative of real-world 1P workload scenarios. The trainer allows us to experiment with a variety of PyTorch packages/techniques such as DeepSpeed, Apex and torch DistributedDataParallel. We welcome contributions to incorporate pipeline parallelism and Megatron-style training.
# ITP Scripts
Scripts are available in the experiments/itp folder to easily launch jobs on the ITP cluster. Two ready-made experiments are available, Switch-CA and Switch-CB; they are variants of the Switch Transformer model scaled to 100B parameters.

# Updates
0.1.7: Adopt the new name - ort_moe

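To ground the moe_module description above, here is a minimal sketch of the injection pattern. Only the FFNExpert signature is taken from this commit's diff; the `MixtureOfExperts` name appears in the commit message, but its constructor arguments and the `ort_moe.*` import paths are assumptions for illustration, not the documented API.

```python
import torch
import torch.nn as nn

# Assumed import path after the restructure (package installed via setup.py).
from ort_moe.experts import FFNExpert
# from ort_moe.moe import MixtureOfExperts   # name from the commit message; path assumed

d_model, d_ff, num_experts = 512, 2048, 8

# A pool of experts built from moe_module's FFNExpert (dgrid=None -> no expert slicing).
experts = nn.ModuleList(FFNExpert(d_model, d_ff, dgrid=None) for _ in range(num_experts))

# Hypothetical injection point inside a transformer block: the dense FFN
# sub-layer would be replaced by something like
#   self.ffn = MixtureOfExperts(gate, experts, options={...})
# and called on (batch, seq, d_model) activations exactly like the dense FFN.
x = torch.randn(4, 16, d_model)
y = experts[0](x)        # exercising a single expert directly for the sketch
print(y.shape)           # torch.Size([4, 16, 512])
```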
1 change: 1 addition & 0 deletions ort_moe/VERSION_NUMBER
@@ -0,0 +1 @@
0.2.0
File renamed without changes.
64 changes: 63 additions & 1 deletion ort_moe/collectives.py → ort_moe/ort_moe/collectives.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import os
import torch
import torch.distributed as dist

@@ -44,6 +45,21 @@ def backward(ctx, *grad_output):
output = (grad_output_chunks[ctx.rank]).contiguous()
return None, output, None

def compressed_all_to_all(output, input, group=None):
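    # Compress this rank's input chunks with dietgpu, exchange the compressed
    # chunks with every peer rank via blocking send/recv, then decompress all
    # chunks into the per-rank views of the output buffer.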

    world_size = dist.get_world_size(group)
    rank = torch.distributed.get_rank(group)

    ts_in = torch.tensor_split(input, world_size)
    compressed_a2a_input, _ = dg_compress(ts_in)
    ts_out = torch.tensor_split(output, world_size)

    for i in range(world_size):
        if i != rank:
            torch.distributed.send(compressed_a2a_input[i], i)
            torch.distributed.recv(compressed_a2a_input[i], i)
    dg_decompress(compressed_a2a_input, ts_out)

# Based on https://github.com/pytorch/pytorch/pull/40762
class AllToAll(torch.autograd.Function):
@staticmethod
@@ -65,7 +81,10 @@ def forward(ctx, group, input, max_len_cpu): # type: ignore
input = torch.nn.functional.pad(input, pad=(0,0,0,max_len_cpu-input.shape[-2]), mode = 'constant', value=0.0)
input = input.contiguous()
output = torch.empty_like(input)
dist.all_to_all_single(output, input, group=group)
if os.environ.get('ORT_MOE_COMPRESS_ALLToALL') is not None:
compressed_all_to_all(output, input, group=group)
else:
dist.all_to_all_single(output, input, group=group)
return output

@staticmethod
@@ -102,3 +121,46 @@ def backward(ctx, *grad_output):
dist.all_reduce(grad_output[0], group = ctx.group)
ret_val = grad_output[0].contiguous()
return ret_val, None

_DIETGPU_SCRATCH_PAD = None
_DIETGPU_LIBRARY_LOADED = False
def dg_load_library():
    global _DIETGPU_LIBRARY_LOADED
    if _DIETGPU_LIBRARY_LOADED is True:
        return

    dg_lib = os.environ.get('DIETGPU_LIB_PATH')
    torch.ops.load_library(dg_lib)
    _DIETGPU_LIBRARY_LOADED = True
    return

def get_tensor_list_device(the_list):
    if isinstance(the_list, torch.Tensor):
        return the_list.device

    for l in the_list:
        if isinstance(l, torch.Tensor):
            return l.device

    return None

def dg_get_scratch_pad(device):
    global _DIETGPU_SCRATCH_PAD
    if _DIETGPU_SCRATCH_PAD is None:
        _DIETGPU_SCRATCH_PAD = torch.empty([64*1024*1024], dtype=torch.uint8, device=device)
    return _DIETGPU_SCRATCH_PAD

def dg_compress(input_list):
    dg_load_library()

    output, output_size, _ = torch.ops.dietgpu.compress_data(True, input_list, dg_get_scratch_pad(get_tensor_list_device(input_list)))
    compressed_output_list = []
    for size, t in zip(output_size, [*output]):
        truncated_t = t.narrow(0, 0, size.item()).clone()
        compressed_output_list.append(truncated_t)

    return compressed_output_list, output_size

def dg_decompress(input_list, output_list):
    dg_load_library()
    torch.ops.dietgpu.decompress_data(True, input_list, output_list, dg_get_scratch_pad(get_tensor_list_device(input_list)))
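For orientation, here is a sketch of how the compressed path above would be exercised end to end. The two environment variable names come straight from this diff; the library path, import path, tensor shapes and launch details are illustrative assumptions, and the code presumes one GPU per rank with a working dietgpu torch-op build.

```python
import os
import torch
import torch.distributed as dist

# Both variable names appear in this diff; the values are placeholders.
os.environ["DIETGPU_LIB_PATH"] = "/path/to/libdietgpu.so"   # assumed location of the dietgpu op library
os.environ["ORT_MOE_COMPRESS_ALLToALL"] = "1"               # any value switches AllToAll to the compressed path

def demo(rank: int, world_size: int):
    # Assumes the usual torchrun/env:// initialization with one GPU per rank.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    from ort_moe.collectives import compressed_all_to_all   # assumed import path

    x = torch.randn(world_size * 4, 16, device="cuda")      # rows split evenly across ranks
    out = torch.empty_like(x)
    # Each rank compresses its chunks, swaps them with peers, and decompresses into `out`.
    compressed_all_to_all(out, x)
    dist.destroy_process_group()
```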
24 changes: 23 additions & 1 deletion ort_moe/custom_ops.py → ort_moe/ort_moe/custom_ops.py
@@ -4,12 +4,13 @@
# --------------------------------------------------------------------------

import torch
from .moe_config import moe_config

# The switch to decide whether to use torch.einsum (when this flag is true) or the rewritten version.
# The switch can be bubbled up in the future.
USE_EINSUM = True

def einsum(rule, a, b):
def om_einsum(rule, a, b):
"""
The rewrite of torch.einsum for some specific cases.
The rewrites are on par with or more performant than torch.einsum on the benchmarks we tested.
@@ -44,3 +45,24 @@ def einsum(rule, a, b):
return torch.bmm(a, b.transpose(1, 2)).reshape(s, m)
else:
return torch.einsum(rule, a, b)

def om_cumsum(mask, dim, options = None):
    """
    The rewrite of torch.cumsum to use the tutel cumsum kernel if desired.
    Args:
        mask (torch.Tensor): the input tensor of the cumsum op
        dim (int): the dimension of the cumsum op
        options (dict or moe_config): options that decide whether to use the tutel cumsum kernel
    """
    if mask.device.type == 'cpu' or options is None:
        return torch.cumsum(mask, dim) - 1

    moe_options = None
    if isinstance(options, moe_config): moe_options = options
    else: moe_options = moe_config(options)

    if moe_options.enable_tutel_cumsum():
        from tutel.jit_kernels.gating import fast_cumsum_sub_one
        return fast_cumsum_sub_one(mask, dim)

    return torch.cumsum(mask, dim) - 1
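As a small worked example of om_cumsum's default path (CPU or options=None), which per the code above is simply torch.cumsum(mask, dim) - 1: in MoE gating this is the standard way to turn a one-hot token-to-expert mask into each token's slot index within its expert. The snippet restates that default path locally so it runs without importing ort_moe.

```python
import torch

def om_cumsum_default(mask, dim):
    # What om_cumsum above reduces to on CPU or when options is None.
    return torch.cumsum(mask, dim) - 1

# One-hot assignment of 5 tokens (rows) to 3 experts (columns).
mask = torch.tensor([[1, 0, 0],
                     [0, 1, 0],
                     [1, 0, 0],
                     [0, 0, 1],
                     [1, 0, 0]])

locations = om_cumsum_default(mask, dim=0)
# Slot index of each token within its assigned expert's capacity buffer.
print((locations * mask).sum(dim=1))   # tensor([0, 0, 1, 0, 2])
```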
12 changes: 7 additions & 5 deletions ort_moe/experts.py → ort_moe/ort_moe/experts.py
@@ -22,18 +22,20 @@ class FFNExpert(nn.Module):
"""
def __init__(self, d_model, dim_feedforward, dgrid, activation_fn = nn.functional.relu, expert_dropout = 0.0):
super().__init__()
self.mp_size = dgrid.get_expert_slicing_world_size()
self.mp_size = 1
if dgrid is not None:
self.mp_size = dgrid.get_expert_slicing_world_size()
self.linear1 = nn.Linear(d_model, dim_feedforward//self.mp_size, bias=False)
self.linear2 = nn.Linear(dim_feedforward//self.mp_size, d_model, bias=False)
self.activation_fn = activation_fn
self.expert_dropout_rate = expert_dropout

def forward(self, x: torch.tensor):
x = self.linear1(x)
x = self.linear1(x.float())
x = self.activation_fn(x)
if self.expert_dropout_rate > 0:
x = F.dropout(x, p=self.expert_dropout_rate, training=self.training)
x = self.linear2(x)
x = self.linear2(x.float())
return x

class MergedFFNExpert(nn.Module):
@@ -73,11 +75,11 @@ def forward(self, x: torch.tensor):
x = x.transpose(0, 1) #gecm --> egcm
input_shape = x.shape
reshaped_x = x.reshape(input_shape[0], -1, input_shape[-1]) #egcm --> e,gxc,m
out1 = torch.bmm(reshaped_x, self.weight1) #e, gxc, f
out1 = torch.bmm(reshaped_x.float(), self.weight1) #e, gxc, f
out1 = self.activation_fn(out1)
if self.expert_dropout_rate > 0:
out1 = F.dropout(out1, p=self.expert_dropout_rate, training=self.training)
out2 = torch.bmm(out1, self.weight2) #e, gxc, m
out2 = torch.bmm(out1.float(), self.weight2) #e, gxc, m
out2 = out2.reshape(input_shape)
out2 = out2.transpose(0, 1) #egcm --> gecm
return out2
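A quick usage example for the updated FFNExpert, grounded in the diff above: dgrid may now be None (mp_size falls back to 1, i.e. no expert slicing), and the forward pass casts activations to float32 before each linear layer, so half-precision inputs still work. The import path is an assumption based on the new package layout.

```python
import torch

from ort_moe.experts import FFNExpert   # assumed import path after the restructure

# dgrid=None -> mp_size == 1, so the expert is not sliced across ranks.
expert = FFNExpert(d_model=8, dim_feedforward=32, dgrid=None)

x = torch.randn(2, 4, 8, dtype=torch.float16)   # any leading dims; last dim must be d_model
y = expert(x)                                    # forward casts to float32 internally
print(y.shape, y.dtype)                          # torch.Size([2, 4, 8]) torch.float32
```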
File renamed without changes.