diff --git a/training/cifar/README.md b/training/cifar/README.md index 7c58f3b98..878b28157 100644 --- a/training/cifar/README.md +++ b/training/cifar/README.md @@ -1,21 +1,22 @@ Thanks Gopi Kumar for contributing this example, demonstrating how to apply DeepSpeed to CIFAR-10 model. -cifar10_tutorial.py +`cifar10_tutorial.py` Baseline CIFAR-10 model. -cifar10_deepspeed.py +`cifar10_deepspeed.py` DeepSpeed applied CIFAR-10 model. -ds_config.json - DeepSpeed configuration file. - -run_ds.sh +`run_ds.sh` Script for running DeepSpeed applied model. -run_ds_moe.sh +`run_ds_moe.sh` Script for running DeepSpeed model with Mixture of Experts (MoE) integration. -* To run baseline CIFAR-10 model - "python cifar10_tutorial.py" -* To run DeepSpeed CIFAR-10 model - "bash run_ds.sh" -* To run DeepSpeed CIFAR-10 model with Mixture of Experts (MoE) - "bash run_ds_moe.sh" -* To run with different data type (default='fp16') and zero stages (default=0) - "bash run_ds.sh --dtype={fp16|bf16} --stage={0|1|2|3}" +`run_ds_prmoe.sh` + Script for running DeepSpeed model with Pyramid Residual MoE (PR-MoE) integration. + +* To run baseline CIFAR-10 model - `python cifar10_tutorial.py` +* To run DeepSpeed CIFAR-10 model - `bash run_ds.sh` +* To run DeepSpeed CIFAR-10 model with Mixture of Experts (MoE) - `bash run_ds_moe.sh` +* To run DeepSpeed CIFAR-10 model with Pyramid Residual MoE (PR-MoE) - `bash run_ds_prmoe.sh` +* To run with different data type (default=`fp16`) and zero stages (default=`0`) - `bash run_ds.sh --dtype={fp16|bf16} --stage={0|1|2|3}` diff --git a/training/cifar/cifar10_deepspeed.py b/training/cifar/cifar10_deepspeed.py index da82e60db..521a75cdf 100755 --- a/training/cifar/cifar10_deepspeed.py +++ b/training/cifar/cifar10_deepspeed.py @@ -1,112 +1,105 @@ +import argparse + +import deepspeed import torch +import torch.nn as nn +import torch.nn.functional as F import torchvision import torchvision.transforms as transforms -import argparse -import deepspeed from deepspeed.accelerator import get_accelerator +from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer def add_argument(): + parser = argparse.ArgumentParser(description="CIFAR") - parser = argparse.ArgumentParser(description='CIFAR') - - #data - # cuda - parser.add_argument('--with_cuda', - default=False, - action='store_true', - help='use CPU in case there\'s no GPU support') - parser.add_argument('--use_ema', - default=False, - action='store_true', - help='whether use exponential moving average') - - # train - parser.add_argument('-b', - '--batch_size', - default=32, - type=int, - help='mini-batch size (default: 32)') - parser.add_argument('-e', - '--epochs', - default=30, - type=int, - help='number of total epochs (default: 30)') - parser.add_argument('--local_rank', - type=int, - default=-1, - help='local rank passed from distributed launcher') - - parser.add_argument('--log-interval', - type=int, - default=2000, - help="output logging information at a given interval") - - parser.add_argument('--moe', - default=False, - action='store_true', - help='use deepspeed mixture of experts (moe)') - - parser.add_argument('--ep-world-size', - default=1, - type=int, - help='(moe) expert parallel world size') - parser.add_argument('--num-experts', - type=int, - nargs='+', - default=[ - 1, - ], - help='number of experts list, MoE related.') + # For train. 
     parser.add_argument(
-        '--mlp-type',
-        type=str,
-        default='standard',
-        help=
-        'Only applicable when num-experts > 1, accepts [standard, residual]')
-    parser.add_argument('--top-k',
-                        default=1,
-                        type=int,
-                        help='(moe) gating top 1 and 2 supported')
+        "-e",
+        "--epochs",
+        default=30,
+        type=int,
+        help="number of total epochs (default: 30)",
+    )
     parser.add_argument(
-        '--min-capacity',
-        default=0,
+        "--local_rank",
         type=int,
-        help=
-        '(moe) minimum capacity of an expert regardless of the capacity_factor'
+        default=-1,
+        help="local rank passed from distributed launcher",
     )
     parser.add_argument(
-        '--noisy-gate-policy',
-        default=None,
+        "--log-interval",
+        type=int,
+        default=2000,
+        help="output logging information at a given interval",
+    )
+
+    # For mixed precision training.
+    parser.add_argument(
+        "--dtype",
+        default="fp16",
         type=str,
-        help=
-        '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter'
+        choices=["bf16", "fp16", "fp32"],
+        help="Datatype used for training",
+    )
+
+    # For ZeRO Optimization.
+    parser.add_argument(
+        "--stage",
+        default=0,
+        type=int,
+        choices=[0, 1, 2, 3],
+        help="ZeRO optimization stage",
     )
+
+    # For MoE (Mixture of Experts).
     parser.add_argument(
-        '--moe-param-group',
+        "--moe",
         default=False,
-        action='store_true',
-        help=
-        '(moe) create separate moe param groups, required when using ZeRO w. MoE'
+        action="store_true",
+        help="use deepspeed mixture of experts (moe)",
+    )
+    parser.add_argument(
+        "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
+    )
+    parser.add_argument(
+        "--num-experts",
+        type=int,
+        nargs="+",
+        default=[
+            1,
+        ],
+        help="number of experts list, MoE related.",
     )
     parser.add_argument(
-        '--dtype',
-        default='fp16',
+        "--mlp-type",
         type=str,
-        choices=['bf16', 'fp16', 'fp32'],
-        help=
-        'Datatype used for training'
+        default="standard",
+        help="Only applicable when num-experts > 1, accepts [standard, residual]",
+    )
+    parser.add_argument(
+        "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
    )
     parser.add_argument(
-        '--stage',
+        "--min-capacity",
         default=0,
         type=int,
-        choices=[0, 1, 2, 3],
-        help=
-        'Datatype used for training'
+        help="(moe) minimum capacity of an expert regardless of the capacity_factor",
+    )
+    parser.add_argument(
+        "--noisy-gate-policy",
+        default=None,
+        type=str,
+        help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
+    )
+    parser.add_argument(
+        "--moe-param-group",
+        default=False,
+        action="store_true",
+        help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
     )
-    # Include DeepSpeed configuration arguments
+    # Include DeepSpeed configuration arguments.
     parser = deepspeed.add_config_arguments(parser)
 
     args = parser.parse_args()
@@ -114,110 +107,87 @@ def add_argument():
     return args
 
 
-deepspeed.init_distributed()
-
-########################################################################
-# The output of torchvision datasets are PILImage images of range [0, 1].
-# We transform them to Tensors of normalized range [-1, 1].
-# .. note::
-#     If running on Windows and you get a BrokenPipeError, try setting
-#     the num_worker of torch.utils.data.DataLoader() to 0.
- -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) -]) - -if torch.distributed.get_rank() != 0: - # might be downloading cifar data, let rank 0 download first - torch.distributed.barrier() - -trainset = torchvision.datasets.CIFAR10(root='./data', - train=True, - download=True, - transform=transform) - -if torch.distributed.get_rank() == 0: - # cifar data is downloaded, indicate other ranks can proceed - torch.distributed.barrier() - -trainloader = torch.utils.data.DataLoader(trainset, - batch_size=16, - shuffle=True, - num_workers=2) - -testset = torchvision.datasets.CIFAR10(root='./data', - train=False, - download=True, - transform=transform) -testloader = torch.utils.data.DataLoader(testset, - batch_size=4, - shuffle=False, - num_workers=2) - -classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', - 'ship', 'truck') - -######################################################################## -# Let us show some of the training images, for fun. - -import matplotlib.pyplot as plt -import numpy as np - -# functions to show an image - - -def imshow(img): - img = img / 2 + 0.5 # unnormalize - npimg = img.numpy() - plt.imshow(np.transpose(npimg, (1, 2, 0))) - plt.show() - - -# get some random training images -dataiter = iter(trainloader) -images, labels = next(dataiter) - -# show images -imshow(torchvision.utils.make_grid(images)) -# print labels -print(' '.join('%5s' % classes[labels[j]] for j in range(4))) - -######################################################################## -# 2. Define a Convolutional Neural Network -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Copy the neural network from the Neural Networks section before and modify it to -# take 3-channel images (instead of 1-channel images as it was defined). 
+def create_moe_param_groups(model): + """Create separate parameter groups for each expert.""" + parameters = {"params": [p for p in model.parameters()], "name": "parameters"} + return split_params_into_different_moe_groups_for_optimizer(parameters) -import torch.nn as nn -import torch.nn.functional as F -args = add_argument() +def get_ds_config(args): + """Get the DeepSpeed configuration dictionary.""" + ds_config = { + "train_batch_size": 16, + "steps_per_print": 2000, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [0.8, 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + }, + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 0.001, + "warmup_num_steps": 1000, + }, + }, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "bf16": {"enabled": args.dtype == "bf16"}, + "fp16": { + "enabled": args.dtype == "fp16", + "fp16_master_weights_and_grads": False, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 15, + }, + "wall_clock_breakdown": False, + "zero_optimization": { + "stage": args.stage, + "allgather_partitions": True, + "reduce_scatter": True, + "allgather_bucket_size": 50000000, + "reduce_bucket_size": 50000000, + "overlap_comm": True, + "contiguous_gradients": True, + "cpu_offload": False, + }, + } + return ds_config class Net(nn.Module): - def __init__(self): + def __init__(self, args): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) - if args.moe: + self.moe = args.moe + if self.moe: fc3 = nn.Linear(84, 84) self.moe_layer_list = [] for n_e in args.num_experts: - # create moe layers based on the number of experts + # Create moe layers based on the number of experts. self.moe_layer_list.append( deepspeed.moe.layer.MoE( hidden_size=84, expert=fc3, num_experts=n_e, ep_size=args.ep_world_size, - use_residual=args.mlp_type == 'residual', + use_residual=args.mlp_type == "residual", k=args.top_k, min_capacity=args.min_capacity, - noisy_gate_policy=args.noisy_gate_policy)) + noisy_gate_policy=args.noisy_gate_policy, + ) + ) self.moe_layer_list = nn.ModuleList(self.moe_layer_list) self.fc4 = nn.Linear(84, 10) else: @@ -229,7 +199,7 @@ def forward(self, x): x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) - if args.moe: + if self.moe: for layer in self.moe_layer_list: x, _, _ = layer(x) x = self.fc4(x) @@ -238,214 +208,192 @@ def forward(self, x): return x -net = Net() +def test(model_engine, testset, local_device, target_dtype, test_batch_size=4): + """Test the network on the test data. + + Args: + model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine. + testset (torch.utils.data.Dataset): the test dataset. + local_device (str): the local device name. + target_dtype (torch.dtype): the target datatype for the test data. + test_batch_size (int): the test batch size. + + """ + # The 10 classes for CIFAR10. + classes = ( + "plane", + "car", + "bird", + "cat", + "deer", + "dog", + "frog", + "horse", + "ship", + "truck", + ) + # Define the test dataloader. + testloader = torch.utils.data.DataLoader( + testset, batch_size=test_batch_size, shuffle=False, num_workers=0 + ) -def create_moe_param_groups(model): - from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer + # For total accuracy. + correct, total = 0, 0 + # For accuracy per class. 
+ class_correct = list(0.0 for i in range(10)) + class_total = list(0.0 for i in range(10)) + + # Start testing. + model_engine.eval() + with torch.no_grad(): + for data in testloader: + images, labels = data + if target_dtype != None: + images = images.to(target_dtype) + outputs = model_engine(images.to(local_device)) + _, predicted = torch.max(outputs.data, 1) + # Count the total accuracy. + total += labels.size(0) + correct += (predicted == labels.to(local_device)).sum().item() + + # Count the accuracy per class. + batch_correct = (predicted == labels.to(local_device)).squeeze() + for i in range(test_batch_size): + label = labels[i] + class_correct[label] += batch_correct[i].item() + class_total[label] += 1 + + if model_engine.local_rank == 0: + print( + f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %" + ) + + # For all classes, print the accuracy. + for i in range(10): + print( + f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %" + ) + + +def main(args): + # Initialize DeepSpeed distributed backend. + deepspeed.init_distributed() + + ######################################################################## + # Step1. Data Preparation. + # + # The output of torchvision datasets are PILImage images of range [0, 1]. + # We transform them to Tensors of normalized range [-1, 1]. + # + # Note: + # If running on Windows and you get a BrokenPipeError, try setting + # the num_worker of torch.utils.data.DataLoader() to 0. + ######################################################################## + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) - parameters = { - 'params': [p for p in model.parameters()], - 'name': 'parameters' - } + if torch.distributed.get_rank() != 0: + # Might be downloading cifar data, let rank 0 download first. + torch.distributed.barrier() - return split_params_into_different_moe_groups_for_optimizer(parameters) + # Load or download cifar data. + trainset = torchvision.datasets.CIFAR10( + root="./data", train=True, download=True, transform=transform + ) + testset = torchvision.datasets.CIFAR10( + root="./data", train=False, download=True, transform=transform + ) + if torch.distributed.get_rank() == 0: + # Cifar data is downloaded, indicate other ranks can proceed. + torch.distributed.barrier() + + ######################################################################## + # Step 2. Define the network with DeepSpeed. + # + # First, we define a Convolution Neural Network. + # Then, we define the DeepSpeed configuration dictionary and use it to + # initialize the DeepSpeed engine. + ######################################################################## + net = Net(args) + + # Get list of parameters that require gradients. + parameters = filter(lambda p: p.requires_grad, net.parameters()) + + # If using MoE, create separate param groups for each expert. + if args.moe_param_group: + parameters = create_moe_param_groups(net) + + # Initialize DeepSpeed to use the following features. + # 1) Distributed model. + # 2) Distributed data loader. + # 3) DeepSpeed optimizer. 
+ ds_config = get_ds_config(args) + model_engine, optimizer, trainloader, __ = deepspeed.initialize( + args=args, + model=net, + model_parameters=parameters, + training_data=trainset, + config=ds_config, + ) -parameters = filter(lambda p: p.requires_grad, net.parameters()) -if args.moe_param_group: - parameters = create_moe_param_groups(net) - -# Initialize DeepSpeed to use the following features -# 1) Distributed model -# 2) Distributed data loader -# 3) DeepSpeed optimizer -ds_config = { - "train_batch_size": 16, - "steps_per_print": 2000, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001, - "betas": [ - 0.8, - 0.999 - ], - "eps": 1e-8, - "weight_decay": 3e-7 - } - }, - "scheduler": { - "type": "WarmupLR", - "params": { - "warmup_min_lr": 0, - "warmup_max_lr": 0.001, - "warmup_num_steps": 1000 - } - }, - "gradient_clipping": 1.0, - "prescale_gradients": False, - "bf16": { - "enabled": args.dtype == "bf16" - }, - "fp16": { - "enabled": args.dtype == "fp16", - "fp16_master_weights_and_grads": False, - "loss_scale": 0, - "loss_scale_window": 500, - "hysteresis": 2, - "min_loss_scale": 1, - "initial_scale_power": 15 - }, - "wall_clock_breakdown": False, - "zero_optimization": { - "stage": args.stage, - "allgather_partitions": True, - "reduce_scatter": True, - "allgather_bucket_size": 50000000, - "reduce_bucket_size": 50000000, - "overlap_comm": True, - "contiguous_gradients": True, - "cpu_offload": False - } -} - -model_engine, optimizer, trainloader, __ = deepspeed.initialize( - args=args, model=net, model_parameters=parameters, training_data=trainset, config=ds_config) - -local_device = get_accelerator().device_name(model_engine.local_rank) -local_rank = model_engine.local_rank - -# For float32, target_dtype will be None so no datatype conversion needed -target_dtype = None -if model_engine.bfloat16_enabled(): - target_dtype=torch.bfloat16 -elif model_engine.fp16_enabled(): - target_dtype=torch.half - -#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -#net.to(device) -######################################################################## -# 3. Define a Loss function and optimizer -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Let's use a Classification Cross-Entropy loss and SGD with momentum. - -import torch.optim as optim - -criterion = nn.CrossEntropyLoss() -#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - -######################################################################## -# 4. Train the network -# ^^^^^^^^^^^^^^^^^^^^ -# -# This is when things start to get interesting. -# We simply have to loop over our data iterator, and feed the inputs to the -# network and optimize. - -for epoch in range(args.epochs): # loop over the dataset multiple times - - running_loss = 0.0 - for i, data in enumerate(trainloader): - # get the inputs; data is a list of [inputs, labels] - inputs, labels = data[0].to(local_device), data[1].to(local_device) - if target_dtype != None: - inputs = inputs.to(target_dtype) - outputs = model_engine(inputs) - loss = criterion(outputs, labels) - - model_engine.backward(loss) - model_engine.step() - - # print statistics - running_loss += loss.item() - if local_rank == 0 and i % args.log_interval == ( - args.log_interval - - 1): # print every log_interval mini-batches - print('[%d, %5d] loss: %.3f' % - (epoch + 1, i + 1, running_loss / args.log_interval)) - running_loss = 0.0 - -print('Finished Training') - -######################################################################## -# 5. 
Test the network on the test data -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# We have trained the network for 2 passes over the training dataset. -# But we need to check if the network has learnt anything at all. -# -# We will check this by predicting the class label that the neural network -# outputs, and checking it against the ground-truth. If the prediction is -# correct, we add the sample to the list of correct predictions. -# -# Okay, first step. Let us display an image from the test set to get familiar. - -dataiter = iter(testloader) -images, labels = next(dataiter) - -# print images -imshow(torchvision.utils.make_grid(images)) -print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4))) - -######################################################################## -# Okay, now let us see what the neural network thinks these examples above are: -if target_dtype != None: - images = images.to(target_dtype) -outputs = net(images.to(local_device)) - -######################################################################## -# The outputs are energies for the 10 classes. -# The higher the energy for a class, the more the network -# thinks that the image is of the particular class. -# So, let's get the index of the highest energy: -_, predicted = torch.max(outputs, 1) - -print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4))) - -######################################################################## -# The results seem pretty good. -# -# Let us look at how the network performs on the whole dataset. - -correct = 0 -total = 0 -with torch.no_grad(): - for data in testloader: - images, labels = data - if target_dtype != None: - images = images.to(target_dtype) - outputs = net(images.to(local_device)) - _, predicted = torch.max(outputs.data, 1) - total += labels.size(0) - correct += (predicted == labels.to(local_device)).sum().item() - -print('Accuracy of the network on the 10000 test images: %d %%' % - (100 * correct / total)) - -######################################################################## -# That looks way better than chance, which is 10% accuracy (randomly picking -# a class out of 10 classes). -# Seems like the network learnt something. -# -# Hmmm, what are the classes that performed well, and the classes that did -# not perform well: - -class_correct = list(0. for i in range(10)) -class_total = list(0. for i in range(10)) -with torch.no_grad(): - for data in testloader: - images, labels = data - if target_dtype != None: - images = images.to(target_dtype) - outputs = net(images.to(local_device)) - _, predicted = torch.max(outputs, 1) - c = (predicted == labels.to(local_device)).squeeze() - for i in range(4): - label = labels[i] - class_correct[label] += c[i].item() - class_total[label] += 1 - -for i in range(10): - print('Accuracy of %5s : %2d %%' % - (classes[i], 100 * class_correct[i] / class_total[i])) + # Get the local device name (str) and local rank (int). + local_device = get_accelerator().device_name(model_engine.local_rank) + local_rank = model_engine.local_rank + + # For float32, target_dtype will be None so no datatype conversion needed. + target_dtype = None + if model_engine.bfloat16_enabled(): + target_dtype = torch.bfloat16 + elif model_engine.fp16_enabled(): + target_dtype = torch.half + + # Define the Classification Cross-Entropy loss function. + criterion = nn.CrossEntropyLoss() + + ######################################################################## + # Step 3. Train the network. 
+ # + # This is when things start to get interesting. + # We simply have to loop over our data iterator, and feed the inputs to the + # network and optimize. (DeepSpeed handles the distributed details for us!) + ######################################################################## + + for epoch in range(args.epochs): # loop over the dataset multiple times + running_loss = 0.0 + for i, data in enumerate(trainloader): + # Get the inputs. ``data`` is a list of [inputs, labels]. + inputs, labels = data[0].to(local_device), data[1].to(local_device) + + # Try to convert to target_dtype if needed. + if target_dtype != None: + inputs = inputs.to(target_dtype) + + outputs = model_engine(inputs) + loss = criterion(outputs, labels) + + model_engine.backward(loss) + model_engine.step() + + # Print statistics + running_loss += loss.item() + if local_rank == 0 and i % args.log_interval == ( + args.log_interval - 1 + ): # Print every log_interval mini-batches. + print( + f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}" + ) + running_loss = 0.0 + print("Finished Training") + + ######################################################################## + # Step 4. Test the network on the test data. + ######################################################################## + test(model_engine, testset, local_device, target_dtype) + + +if __name__ == "__main__": + args = add_argument() + main(args) diff --git a/training/cifar/run_ds_moe.sh b/training/cifar/run_ds_moe.sh index b7dcb7fa7..f87a29628 100755 --- a/training/cifar/run_ds_moe.sh +++ b/training/cifar/run_ds_moe.sh @@ -15,7 +15,6 @@ deepspeed --num_nodes=${NUM_NODES}\ cifar10_deepspeed.py \ --log-interval 100 \ --deepspeed \ - --deepspeed_config ds_config.json \ --moe \ --ep-world-size ${EP_SIZE} \ --num-experts ${EXPERTS} \ diff --git a/training/cifar/run_ds_prmoe.sh b/training/cifar/run_ds_prmoe.sh index 72731b0d5..d9755a331 100644 --- a/training/cifar/run_ds_prmoe.sh +++ b/training/cifar/run_ds_prmoe.sh @@ -12,7 +12,6 @@ EXPERTS='2 4' deepspeed --num_nodes=${NUM_NODES} --num_gpus=${NUM_GPUS} cifar10_deepspeed.py \ --log-interval 100 \ --deepspeed \ - --deepspeed_config ds_config.json \ --moe \ --ep-world-size ${EP_SIZE} \ --num-experts ${EXPERTS} \