Skip to content

Commit

Permalink
Merge branch 'master' into olruwase/zero_inference
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Sep 22, 2023
2 parents d702342 + 9b3d898 commit e9b66f0
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed import get_accelerator

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
Expand Down Expand Up @@ -145,6 +146,9 @@ def parse_args():
parser.add_argument('--offload',
action='store_true',
help='Enable ZeRO Offload techniques.')
parser.add_argument('--dtype', type=str, default='fp16',
choices=['fp16', 'bf16'],
help = 'Training data type')
parser.add_argument(
'--zero_stage',
type=int,
Expand Down Expand Up @@ -190,17 +194,18 @@ def main():
args = parse_args()

if args.local_rank == -1:
device = torch.device("cuda")
device = torch.device(get_accelerator().device_name())
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
# torch.distributed.init_process_group(backend='nccl')
deepspeed.init_distributed()

args.global_rank = torch.distributed.get_rank()

ds_config = get_train_ds_config(offload=args.offload,
dtype=args.dtype,
stage=args.zero_stage,
enable_tensorboard=args.enable_tensorboard,
tb_path=args.tensorboard_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from utils.model.model_utils import create_hf_model
from utils.utils import load_hf_tokenizer
from deepspeed import get_accelerator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -194,7 +195,7 @@ def prompt_eval(args, model_baseline, model_fintuned, tokenizer, device,
def main():
args = parse_args()

device = torch.device("cuda:0")
device = torch.device(get_accelerator().device_name(0))

tokenizer = load_hf_tokenizer(args.model_name_or_path_baseline,
fast_tokenizer=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.accelerator import get_accelerator

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
Expand Down Expand Up @@ -144,6 +145,9 @@ def parse_args():
parser.add_argument('--offload',
action='store_true',
help='Enable ZeRO Offload techniques.')
parser.add_argument('--dtype', type=str, default='fp16',
choices=['fp16', 'bf16'],
help = 'Training data type')
parser.add_argument(
'--zero_stage',
type=int,
Expand Down Expand Up @@ -185,17 +189,18 @@ def main():
args = parse_args()

if args.local_rank == -1:
device = torch.device("cuda")
device = torch.device(get_accelerator().device_name())
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
# torch.distributed.init_process_group(backend='nccl')
deepspeed.init_distributed()

args.global_rank = torch.distributed.get_rank()

ds_config = get_train_ds_config(offload=args.offload,
dtype=args.dtype,
stage=args.zero_stage,
enable_tensorboard=args.enable_tensorboard,
tb_path=args.tensorboard_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from utils.model.model_utils import create_critic_model
from utils.utils import to_device
from utils.utils import load_hf_tokenizer
from deepspeed import get_accelerator


def parse_args():
Expand Down Expand Up @@ -100,7 +101,7 @@ def prepare_singlesample(prompt,
def run_pair_comparison():
args = parse_args()

device = torch.device("cuda:0")
device = torch.device(get_accelerator().device_name(0))

rm_model, tokenizer = load_stuff(args.model_name_or_path,
args.num_padding_at_beginning)
Expand Down Expand Up @@ -144,7 +145,7 @@ def run_pair_comparison():

def run_single_sample():
args = parse_args()
device = torch.device("cuda")
device = torch.device(get_accelerator().device_name())

rm_model, tokenizer = load_stuff(args.model_name_or_path,
args.num_padding_at_beginning)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer
from utils.module.lora import convert_lora_to_linear_layer
from utils.perf import print_throughput_step3
from deepspeed.accelerator import get_accelerator

writer = None

Expand Down Expand Up @@ -240,6 +241,9 @@ def parse_args():
parser.add_argument('--offload',
action='store_true',
help='Enable ZeRO Offload techniques.')
parser.add_argument('--dtype', type=str, default='fp16',
choices=['fp16', 'bf16'],
help = 'Training data type')
parser.add_argument(
'--offload_reference_model',
action='store_true',
Expand Down Expand Up @@ -417,10 +421,10 @@ def main():
args = parse_args()

if args.local_rank == -1:
device = torch.device("cuda")
device = torch.device(get_accelerator().device_name())
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
deepspeed.init_distributed()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
Expand All @@ -18,7 +19,8 @@

def print_all_ranks(tag, value, rank):
world_size = torch.distributed.get_world_size()
all_tensor = torch.zeros(world_size, dtype=torch.float32).cuda()
all_tensor = torch.zeros(world_size, dtype=torch.float32).to(
get_accelerator().current_device_name())
all_tensor[rank] = value
torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
print_rank_0(f'{tag} {all_tensor}', rank)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def _init_actor(self, actor_model_name_or_path):
# DS Config
ds_config = get_train_ds_config(
offload=self.args.offload,
dtype=self.args.dtype,
stage=self.args.actor_zero_stage,
enable_hybrid_engine=self.args.enable_hybrid_engine,
inference_tp_size=self.args.inference_tp_size,
Expand Down Expand Up @@ -139,6 +140,7 @@ def _init_ref(self, actor_model_name_or_path):
# If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory for ref model
zero_stage = 0
ds_config = get_eval_ds_config(self.args.offload_reference_model,
self.args.dtype,
zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
Expand All @@ -165,6 +167,7 @@ def _init_ema(self, actor_model_name_or_path):
# If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
zero_stage = 0
ds_config = get_eval_ds_config(self.args.offload_reference_model,
self.args.dtype,
zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
Expand All @@ -191,6 +194,7 @@ def _init_critic(self, critic_model_name_or_path):
stime = log_init("Critic")
ds_config = get_train_ds_config(
offload=self.args.offload,
dtype=self.args.dtype,
stage=self.args.critic_zero_stage,
enable_tensorboard=self.args.enable_tensorboard,
tb_path=self.args.tensorboard_path,
Expand All @@ -203,6 +207,7 @@ def _init_critic(self, critic_model_name_or_path):
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False,
dtype=self.args.dtype,
stage=self.args.critic_zero_stage)
# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
Expand Down Expand Up @@ -266,14 +271,15 @@ def _init_reward(self, critic_model_name_or_path):
zero_stage = 0

ds_config = get_eval_ds_config(offload=self.args.offload,
dtype=self.args.dtype,
stage=zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False, stage=zero_stage)
ds_eval_config = get_eval_ds_config(offload=False, dtype=self.args.dtype, stage=zero_stage)

# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import hashlib
from itertools import chain
from . import raw_datasets
from deepspeed.accelerator import get_accelerator


def get_raw_dataset(dataset_name, output_path, seed, local_rank):
Expand Down Expand Up @@ -281,7 +282,8 @@ def create_prompt_dataset(local_rank,
eval_fname = f"{output_path}/evaldata_{fname}.pt"

cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
buf_create_cache = torch.ByteTensor([not cache_found]).cuda()
buf_create_cache = torch.ByteTensor([not cache_found]).to(
get_accelerator().current_device_name())
torch.distributed.all_reduce(buf_create_cache)

if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
Expand Down
40 changes: 29 additions & 11 deletions applications/DeepSpeed-Chat/training/utils/ds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

GLOBAL_BATCH_SIZE = 32
MICRO_BATCH_SIZE = 4


def get_train_ds_config(offload,
dtype,
stage=2,
enable_hybrid_engine=False,
inference_tp_size=1,
Expand All @@ -24,6 +26,17 @@ def get_train_ds_config(offload,
tb_name=""):

device = "cpu" if offload else "none"
if dtype == "fp16":
data_type = "fp16"
dtype_config = {
"enabled": True,
"loss_scale_window": 100
}
elif dtype == "bf16":
data_type = "bfloat16"
dtype_config = {
"enabled": True
}
zero_opt_dict = {
"stage": stage,
"offload_param": {
Expand All @@ -39,18 +52,15 @@ def get_train_ds_config(offload,
}
if enable_mixed_precision_lora:
zero_opt_dict["zero_quantized_nontrainable_weights"] = True
if dist.get_world_size() != torch.cuda.device_count():
zero_opt_dict["zero_hpz_partition_size"] = torch.cuda.device_count(
)
if dist.get_world_size() != get_accelerator().device_count():
zero_opt_dict["zero_hpz_partition_size"] = get_accelerator(
).device_count()
return {
"train_batch_size": GLOBAL_BATCH_SIZE,
"train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {
"enabled": True,
"loss_scale_window": 100
},
data_type: dtype_config,
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False,
Expand All @@ -70,8 +80,18 @@ def get_train_ds_config(offload,
}


def get_eval_ds_config(offload, stage=0):
def get_eval_ds_config(offload, dtype, stage=0):
device = "cpu" if offload else "none"
if dtype == "fp16":
data_type = "fp16"
dtype_config = {
"enabled": True,
}
elif dtype == "bf16":
data_type = "bfloat16"
dtype_config = {
"enabled": True
}
zero_opt_dict = {
"stage": stage,
"stage3_param_persistence_threshold": 1e4,
Expand All @@ -85,9 +105,7 @@ def get_eval_ds_config(offload, stage=0):
"train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {
"enabled": True
},
data_type: dtype_config,
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False
Expand Down
3 changes: 2 additions & 1 deletion applications/DeepSpeed-Chat/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import json
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator
import torch.nn as nn


Expand Down Expand Up @@ -102,7 +103,7 @@ def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
get_accelerator().manual_seed_all(seed)


def get_all_reduce_mean(tensor):
Expand Down

0 comments on commit e9b66f0

Please sign in to comment.