Skip to content

Commit

Permalink
Merge branch 'master' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Oct 29, 2024
2 parents 95573dd + 130fb58 commit 9cd25dd
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions training/cifar/cifar10_deepspeed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os

import deepspeed
import torch
Expand Down Expand Up @@ -279,6 +280,8 @@ def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
def main(args):
# Initialize DeepSpeed distributed backend.
deepspeed.init_distributed()
_local_rank = int(os.environ.get("LOCAL_RANK"))
get_accelerator().set_device(_local_rank)

########################################################################
# Step1. Data Preparation.
Expand Down

0 comments on commit 9cd25dd

Please sign in to comment.