
Commit

update README.md for v0.3.2 (#234)
ghostplant authored May 8, 2024
1 parent 13c7a72 commit d4c20c3
Showing 3 changed files with 24 additions and 2 deletions.
22 changes: 20 additions & 2 deletions README.md
@@ -1,6 +1,6 @@
# Tutel

- Tutel MoE: An Optimized Mixture-of-Experts Implementation.
+ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parallel solution to propose ["No-penalty Parallelism/Sparsity/Capacity/.. Switching"](https://mlsys.org/media/mlsys-2023/Slides/2477.pdf) for modern training and inference workloads with dynamic behaviors.

- Supported Framework: PyTorch (recommended: >= 1.10)
- Supported GPUs: CUDA (fp64/fp32/fp16/bfp16), ROCm (fp64/fp32/fp16)
@@ -9,6 +9,24 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation.

### What's New:

- Tutel v0.3.2: Add a tensorcore option for extra benchmarks / extend the example for custom experts / allow NCCL timeout settings (a sketch of typical timeout wiring follows the commands below):
```sh
>> Example for using tensorcore:

python3 -m tutel.examples.helloworld --dtype=float32
python3 -m tutel.examples.helloworld --dtype=float32 --use_tensorcore

python3 -m tutel.examples.helloworld --dtype=float16
python3 -m tutel.examples.helloworld --dtype=float16 --use_tensorcore

>> Example for custom experts:
python3 -m tutel.examples.helloworld_custom_expert --batch_size=16

>> Example for NCCL timeout settings:
TUTEL_GLOBAL_TIMEOUT_SEC=60 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --use_tensorcore

```
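
  The diff further down shows what `--use_tensorcore` itself does; how `TUTEL_GLOBAL_TIMEOUT_SEC` is consumed is not part of this commit. Below is a minimal sketch of the usual pattern for honoring such an env var with `torch.distributed` (only the variable name comes from the example above; the rest is an assumption, not Tutel's actual code):
  ```py
  import datetime
  import os

  import torch.distributed as dist

  # Sketch only: read the timeout env var (defaulting to 1800 s, which
  # mirrors torch.distributed's own default) and pass it to process-group
  # initialization so stalled NCCL collectives fail instead of hanging.
  timeout_sec = int(os.environ.get('TUTEL_GLOBAL_TIMEOUT_SEC', '1800'))
  dist.init_process_group(backend='nccl',
                          timeout=datetime.timedelta(seconds=timeout_sec))
  ```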

- Tutel v0.3.1: Add NCCL all_to_all_v and all_gather_v for arbitrary-length message transfers:
```sh
>> Example:
...
```
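
  The Tutel helper calls themselves are collapsed in this diff view. For orientation, plain PyTorch expresses a variable-length all-to-all like this (a sketch using only stock `torch.distributed`, not Tutel's `all_to_all_v` API):
  ```py
  import torch
  import torch.distributed as dist

  # Sketch: variable-length all-to-all on a 1-D tensor. Each rank sends
  # input_splits[i] elements to rank i and must know how many elements to
  # expect back from each peer; Tutel's all_to_all_v wraps this bookkeeping.
  def all_to_all_v(tensor, input_splits, output_splits):
      output = tensor.new_empty(sum(output_splits))
      dist.all_to_all_single(output, tensor,
                             output_split_sizes=output_splits,
                             input_split_sizes=input_splits)
      return output
  ```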
@@ -84,7 +102,7 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation.
$ python3 -m tutel.examples.helloworld_ddp --batch_size=16 # Test Tutel-optimized MoE + PyTorch DDP distribution (requires: PyTorch >= 1.8.0)
$ python3 -m tutel.examples.helloworld_ddp_tutel --batch_size=16 # Test Tutel-optimized MoE + Tutel DDP distribution (ZeRO on optimizers)
$ python3 -m tutel.examples.helloworld_amp --batch_size=16 # Test Tutel-optimized MoE with AMP data type + manual distribution
- $ python3 -m tutel.examples.helloworld_demo --batch_size=16 # Test Tutel-optimized MoE + custom defined expert layer
+ $ python3 -m tutel.examples.helloworld_custom_expert --batch_size=16 # Test Tutel-optimized MoE + custom defined expert layer
$ python3 -m tutel.examples.helloworld_from_scratch # Test custom MoE implementation from scratch
$ python3 -m tutel.examples.moe_mnist # Test MoE layer in end-to-end MNIST training
$ python3 -m tutel.examples.moe_cifar10 # Test MoE layer in end-to-end CIFAR10 training
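For orientation, the helloworld scripts above all build a `tutel_moe.moe_layer`; below is a minimal construction in the style shown elsewhere in Tutel's README (parameter values here are illustrative, and `helloworld_custom_expert` presumably swaps the built-in `experts` dict for a user-defined module):

```py
import torch
from tutel import moe as tutel_moe

# A minimal MoE layer in the style of Tutel's README (values illustrative):
# top-2 gating over 2 local FFN experts with hidden size 2048.
moe_layer = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 2},
    model_dim=512,
    experts={'type': 'ffn', 'count_per_node': 2, 'hidden_size_per_expert': 2048},
)
y = moe_layer(torch.randn(8, 1024, 512))  # input shape: [batch, seq_len, model_dim]
```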
4 changes: 4 additions & 0 deletions tutel/examples/helloworld.py
@@ -36,9 +36,13 @@
parser.add_argument('--eval', default=False, action='store_true')
parser.add_argument('--capacity_factor', type=float, default=1.0) # 0.0 for dMoE (dropless-MoE), negative for no-padded capacity.
parser.add_argument('--megablocks_size', type=int, default=0)
+ parser.add_argument('--use_tensorcore', default=False, action='store_true')

args = parser.parse_args()

+ if args.use_tensorcore:
+     torch.backends.cuda.matmul.allow_tf32 = True

parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
dist_rank, dist_world_size, dist_print = parallel_env.global_rank, parallel_env.global_size, parallel_env.dist_print
args.local_rank = parallel_env.local_device.index
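As the hunk above shows, `--use_tensorcore` simply flips PyTorch's TF32 switch; the standalone equivalent is:

```py
import torch

# Same effect as passing --use_tensorcore to tutel.examples.helloworld:
# allow fp32 matmuls to run on tensor cores in TF32 precision
# (effective on NVIDIA Ampere or newer GPUs).
torch.backends.cuda.matmul.allow_tf32 = True
```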
File renamed without changes.
