[Device] Support DeepSpeedChat running on devices other than CUDA #736

Merged (2 commits, Sep 22, 2023). Diff shown across all commits.
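This PR applies one pattern across the DeepSpeed-Chat training and evaluation scripts: hard-coded "cuda" device strings and torch.cuda.* calls are replaced with DeepSpeed's accelerator abstraction, get_accelerator(), so the same code can run on any accelerator backend DeepSpeed supports. The diffs import the helper both as `from deepspeed import get_accelerator` and as `from deepspeed.accelerator import get_accelerator`; both resolve to the same function in recent DeepSpeed releases. A minimal sketch of the device-selection pattern the PR applies everywhere (the helper name pick_device is illustrative, not from the PR):

    # Sketch of the device-selection pattern; `pick_device` is illustrative.
    import torch
    from deepspeed.accelerator import get_accelerator

    def pick_device(local_rank: int = -1) -> torch.device:
        if local_rank == -1:
            # Single-process run: use the default device of whatever
            # accelerator DeepSpeed was built for ("cuda", "xpu", "npu", ...).
            return torch.device(get_accelerator().device_name())
        # Distributed run: bind this process to its local device first.
        get_accelerator().set_device(local_rank)
        return torch.device(get_accelerator().device_name(), local_rank)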
File 1 (path not captured):
@@ -21,6 +21,7 @@
 
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from deepspeed import get_accelerator
 
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
@@ -190,10 +191,10 @@ def main():
     args = parse_args()
 
     if args.local_rank == -1:
-        device = torch.device("cuda")
+        device = torch.device(get_accelerator().device_name())
     else:
-        torch.cuda.set_device(args.local_rank)
-        device = torch.device("cuda", args.local_rank)
+        get_accelerator().set_device(args.local_rank)
+        device = torch.device(get_accelerator().device_name(), args.local_rank)
     # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
     # torch.distributed.init_process_group(backend='nccl')
     deepspeed.init_distributed()
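The commented-out torch.distributed.init_process_group(backend='nccl') line is the companion to this change: deepspeed.init_distributed() selects a communication backend appropriate for the active accelerator rather than hard-coding NCCL. A hedged sketch; communication_backend_name() is part of DeepSpeed's accelerator interface, and the printed value is an example for a CUDA build:

    import deepspeed
    from deepspeed.accelerator import get_accelerator

    # Portable replacement for init_process_group(backend='nccl'):
    # DeepSpeed picks the backend that matches the active accelerator.
    deepspeed.init_distributed()

    # The backend it uses can be inspected via the accelerator interface:
    print(get_accelerator().communication_backend_name())  # e.g. "nccl" on CUDA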
File 2 (path not captured):
@@ -15,6 +15,7 @@
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
 from utils.model.model_utils import create_hf_model
 from utils.utils import load_hf_tokenizer
+from deepspeed import get_accelerator
 
 logger = logging.getLogger(__name__)
 
@@ -194,7 +195,7 @@ def prompt_eval(args, model_baseline, model_fintuned, tokenizer, device,
 def main():
     args = parse_args()
 
-    device = torch.device("cuda:0")
+    device = torch.device(get_accelerator().device_name(0))
 
     tokenizer = load_hf_tokenizer(args.model_name_or_path_baseline,
                                   fast_tokenizer=True)
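Note the two forms of the replacement: device_name() with no argument returns the bare device type, while device_name(0) returns an indexed device string, mirroring torch.device("cuda") versus torch.device("cuda:0"). A small sketch; the commented values assume a CUDA build:

    import torch
    from deepspeed.accelerator import get_accelerator

    acc = get_accelerator()
    device_any = torch.device(acc.device_name())    # torch.device("cuda") on a CUDA build
    device_zero = torch.device(acc.device_name(0))  # torch.device("cuda:0") on a CUDA build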
File 3 (path not captured):
@@ -19,6 +19,7 @@
 
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from deepspeed.accelerator import get_accelerator
 
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
@@ -185,10 +186,10 @@ def main():
     args = parse_args()
 
     if args.local_rank == -1:
-        device = torch.device("cuda")
+        device = torch.device(get_accelerator().device_name())
     else:
-        torch.cuda.set_device(args.local_rank)
-        device = torch.device("cuda", args.local_rank)
+        get_accelerator().set_device(args.local_rank)
+        device = torch.device(get_accelerator().device_name(), args.local_rank)
     # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
     # torch.distributed.init_process_group(backend='nccl')
     deepspeed.init_distributed()
File 4 (path not captured):
@@ -14,6 +14,7 @@
 from utils.model.model_utils import create_critic_model
 from utils.utils import to_device
 from utils.utils import load_hf_tokenizer
+from deepspeed import get_accelerator
 
 
 def parse_args():
@@ -100,7 +101,7 @@ def prepare_singlesample(prompt,
 def run_pair_comparison():
     args = parse_args()
 
-    device = torch.device("cuda:0")
+    device = torch.device(get_accelerator().device_name(0))
 
     rm_model, tokenizer = load_stuff(args.model_name_or_path,
                                      args.num_padding_at_beginning)
@@ -144,7 +145,7 @@ def run_pair_comparison():
 
 def run_single_sample():
     args = parse_args()
-    device = torch.device("cuda")
+    device = torch.device(get_accelerator().device_name())
 
     rm_model, tokenizer = load_stuff(args.model_name_or_path,
                                      args.num_padding_at_beginning)
File 5 (path not captured):
@@ -44,6 +44,7 @@
 from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer
 from utils.module.lora import convert_lora_to_linear_layer
 from utils.perf import print_throughput_step3
+from deepspeed.accelerator import get_accelerator
 
 writer = None
 
@@ -417,10 +418,10 @@ def main():
     args = parse_args()
 
     if args.local_rank == -1:
-        device = torch.device("cuda")
+        device = torch.device(get_accelerator().device_name())
     else:
-        torch.cuda.set_device(args.local_rank)
-        device = torch.device("cuda", args.local_rank)
+        get_accelerator().set_device(args.local_rank)
+        device = torch.device(get_accelerator().device_name(), args.local_rank)
     # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
     deepspeed.init_distributed()
 
File 6 (path not captured):
@@ -9,6 +9,7 @@
 import time
 import deepspeed
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from deepspeed.accelerator import get_accelerator
 
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
@@ -18,7 +19,8 @@
 
 def print_all_ranks(tag, value, rank):
     world_size = torch.distributed.get_world_size()
-    all_tensor = torch.zeros(world_size, dtype=torch.float32).cuda()
+    all_tensor = torch.zeros(world_size, dtype=torch.float32).to(
+        get_accelerator().current_device_name())
     all_tensor[rank] = value
     torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
     print_rank_0(f'{tag} {all_tensor}', rank)
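current_device_name() differs from device_name(): it returns the name of the device the calling process is currently bound to, index included, which is what a per-rank allocation needs before a collective. The two idioms side by side, as a sketch:

    import torch
    from deepspeed.accelerator import get_accelerator

    # Old, CUDA-only placement:
    # t = torch.zeros(4, dtype=torch.float32).cuda()
    # New, backend-agnostic placement on this process's current device:
    t = torch.zeros(4, dtype=torch.float32).to(
        get_accelerator().current_device_name())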
File 7 (path not captured):
@@ -15,6 +15,7 @@
 import hashlib
 from itertools import chain
 from . import raw_datasets
+from deepspeed.accelerator import get_accelerator
 
 
 def get_raw_dataset(dataset_name, output_path, seed, local_rank):
@@ -281,7 +282,8 @@ def create_prompt_dataset(local_rank,
     eval_fname = f"{output_path}/evaldata_{fname}.pt"
 
     cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
-    buf_create_cache = torch.ByteTensor([not cache_found]).cuda()
+    buf_create_cache = torch.ByteTensor([not cache_found]).to(
+        get_accelerator().current_device_name())
     torch.distributed.all_reduce(buf_create_cache)
 
     if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
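The buf_create_cache idiom deserves a note: each rank reports whether it found the cached dataset files, the flags are summed across ranks, and the cache is rebuilt if any rank missed it; the flag tensor must live on a device the communication backend can reduce, hence the accelerator call. A self-contained sketch of the idiom, assuming torch.distributed is already initialized (the function name and file paths are illustrative, not from the PR):

    import os
    import torch
    import torch.distributed as dist
    from deepspeed.accelerator import get_accelerator

    def cache_needs_building(train_fname, eval_fname):
        # 1 if this rank is missing the cache files, 0 otherwise.
        cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
        flag = torch.ByteTensor([not cache_found]).to(
            get_accelerator().current_device_name())
        dist.all_reduce(flag)    # sum of misses across all ranks
        return flag.item() != 0  # rebuild if any rank lacked the cache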
File 8: applications/DeepSpeed-Chat/training/utils/ds_utils.py (7 changes: 4 additions, 3 deletions)
@@ -5,6 +5,7 @@
 
 import torch
 import deepspeed.comm as dist
+from deepspeed.accelerator import get_accelerator
 
 GLOBAL_BATCH_SIZE = 32
 MICRO_BATCH_SIZE = 4
@@ -39,9 +40,9 @@ def get_train_ds_config(offload,
     }
     if enable_mixed_precision_lora:
         zero_opt_dict["zero_quantized_nontrainable_weights"] = True
-        if dist.get_world_size() != torch.cuda.device_count():
-            zero_opt_dict["zero_hpz_partition_size"] = torch.cuda.device_count(
-            )
+        if dist.get_world_size() != get_accelerator().device_count():
+            zero_opt_dict["zero_hpz_partition_size"] = get_accelerator(
+            ).device_count()
     return {
         "train_batch_size": GLOBAL_BATCH_SIZE,
         "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
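get_accelerator().device_count() is the generic counterpart of torch.cuda.device_count(): the number of devices visible on the local node. When it differs from the global world size, the run typically spans multiple nodes, and the code pins the hpZ (hierarchical partitioning) group size to one node's devices. A sketch of just that check, assuming the distributed backend is already initialized:

    import deepspeed.comm as dist
    from deepspeed.accelerator import get_accelerator

    zero_opt_dict = {}
    local_devices = get_accelerator().device_count()  # devices on this node
    if dist.get_world_size() != local_devices:
        # World size exceeds one node's devices: keep the secondary (hpZ)
        # partition within a single node.
        zero_opt_dict["zero_hpz_partition_size"] = local_devices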
File 9: applications/DeepSpeed-Chat/training/utils/utils.py (3 changes: 2 additions, 1 deletion)
@@ -10,6 +10,7 @@
 import json
 import deepspeed
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from deepspeed.accelerator import get_accelerator
 import torch.nn as nn
 
 
@@ -102,7 +103,7 @@ def set_random_seed(seed):
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
+        get_accelerator().manual_seed_all(seed)
 
 
 def get_all_reduce_mean(tensor):
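manual_seed_all(seed) is the accelerator-generic counterpart of torch.cuda.manual_seed_all(seed): it seeds the RNG on every device of the active backend. A sketch of the resulting portable seeding helper, mirroring the function shown in the diff:

    import random
    import numpy as np
    import torch
    from deepspeed.accelerator import get_accelerator

    def set_random_seed(seed):
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)                  # CPU / default device RNG
            get_accelerator().manual_seed_all(seed)  # all accelerator devices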