diff --git a/training/cifar/cifar10_deepspeed.py b/training/cifar/cifar10_deepspeed.py index 521a75cdf..9888544d5 100755 --- a/training/cifar/cifar10_deepspeed.py +++ b/training/cifar/cifar10_deepspeed.py @@ -1,4 +1,5 @@ import argparse +import os import deepspeed import torch @@ -279,6 +280,8 @@ def test(model_engine, testset, local_device, target_dtype, test_batch_size=4): def main(args): # Initialize DeepSpeed distributed backend. deepspeed.init_distributed() + _local_rank = int(os.environ.get("LOCAL_RANK")) + get_accelerator().set_device(_local_rank) ######################################################################## # Step1. Data Preparation.