support bf16 for RLHF training (#733)
ys950902 authored Sep 22, 2023
Parent db56381 · Commit 9b3d898
Showing 5 changed files with 43 additions and 9 deletions.
=== File 1 of 5 ===
@@ -146,6 +146,9 @@ def parse_args():
     parser.add_argument('--offload',
                         action='store_true',
                         help='Enable ZeRO Offload techniques.')
+    parser.add_argument('--dtype', type=str, default='fp16',
+                        choices=['fp16', 'bf16'],
+                        help='Training data type')
     parser.add_argument(
         '--zero_stage',
         type=int,
@@ -202,6 +205,7 @@ def main():
     args.global_rank = torch.distributed.get_rank()

     ds_config = get_train_ds_config(offload=args.offload,
+                                    dtype=args.dtype,
                                     stage=args.zero_stage,
                                     enable_tensorboard=args.enable_tensorboard,
                                     tb_path=args.tensorboard_path,
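The same three-line argparse addition appears in each of the training entry points changed by this commit. As a standalone sketch (not part of the commit itself), this is how the new flag behaves: argparse supplies 'fp16' when the flag is omitted and rejects anything outside the two supported precisions.

    import argparse

    # Mirror of the --dtype flag added in the hunk above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dtype', type=str, default='fp16',
                        choices=['fp16', 'bf16'],
                        help='Training data type')

    print(parser.parse_args([]).dtype)                   # fp16 (default)
    print(parser.parse_args(['--dtype', 'bf16']).dtype)  # bf16
    # parser.parse_args(['--dtype', 'fp32']) would exit with "invalid choice"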
=== File 2 of 5 ===
@@ -145,6 +145,9 @@ def parse_args():
     parser.add_argument('--offload',
                         action='store_true',
                         help='Enable ZeRO Offload techniques.')
+    parser.add_argument('--dtype', type=str, default='fp16',
+                        choices=['fp16', 'bf16'],
+                        help='Training data type')
     parser.add_argument(
         '--zero_stage',
         type=int,
@@ -197,6 +200,7 @@ def main():
     args.global_rank = torch.distributed.get_rank()

     ds_config = get_train_ds_config(offload=args.offload,
+                                    dtype=args.dtype,
                                     stage=args.zero_stage,
                                     enable_tensorboard=args.enable_tensorboard,
                                     tb_path=args.tensorboard_path,
=== File 3 of 5 ===
@@ -241,6 +241,9 @@ def parse_args():
     parser.add_argument('--offload',
                         action='store_true',
                         help='Enable ZeRO Offload techniques.')
+    parser.add_argument('--dtype', type=str, default='fp16',
+                        choices=['fp16', 'bf16'],
+                        help='Training data type')
     parser.add_argument(
         '--offload_reference_model',
         action='store_true',
=== File 4 of 5 ===
@@ -66,6 +66,7 @@ def _init_actor(self, actor_model_name_or_path):
         # DS Config
         ds_config = get_train_ds_config(
             offload=self.args.offload,
+            dtype=self.args.dtype,
             stage=self.args.actor_zero_stage,
             enable_hybrid_engine=self.args.enable_hybrid_engine,
             inference_tp_size=self.args.inference_tp_size,
@@ -139,6 +140,7 @@ def _init_ref(self, actor_model_name_or_path):
         # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory for ref model
         zero_stage = 0
         ds_config = get_eval_ds_config(self.args.offload_reference_model,
+                                       self.args.dtype,
                                        zero_stage)
         ds_config[
             'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
@@ -165,6 +167,7 @@ def _init_ema(self, actor_model_name_or_path):
         # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
         zero_stage = 0
         ds_config = get_eval_ds_config(self.args.offload_reference_model,
+                                       self.args.dtype,
                                        zero_stage)
         ds_config[
             'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
@@ -191,6 +194,7 @@ def _init_critic(self, critic_model_name_or_path):
         stime = log_init("Critic")
         ds_config = get_train_ds_config(
             offload=self.args.offload,
+            dtype=self.args.dtype,
             stage=self.args.critic_zero_stage,
             enable_tensorboard=self.args.enable_tensorboard,
             tb_path=self.args.tensorboard_path,
@@ -203,6 +207,7 @@ def _init_critic(self, critic_model_name_or_path):
         ) * self.args.gradient_accumulation_steps

         ds_eval_config = get_eval_ds_config(offload=False,
+                                            dtype=self.args.dtype,
                                             stage=self.args.critic_zero_stage)
         # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
         ds_eval_config[
@@ -266,14 +271,15 @@ def _init_reward(self, critic_model_name_or_path):
         zero_stage = 0

         ds_config = get_eval_ds_config(offload=self.args.offload,
+                                       dtype=self.args.dtype,
                                        stage=zero_stage)
         ds_config[
             'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
         ds_config[
             'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
         ) * self.args.gradient_accumulation_steps

-        ds_eval_config = get_eval_ds_config(offload=False, stage=zero_stage)
+        ds_eval_config = get_eval_ds_config(offload=False, dtype=self.args.dtype, stage=zero_stage)

         # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
         ds_eval_config[
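Taken together, the hunks above thread a single --dtype setting through every model the RLHF engine builds: actor and critic via get_train_ds_config, and the reference, EMA, and reward models via get_eval_ds_config. A minimal sketch of that flow follows; the Args class is a hypothetical stand-in for the parsed CLI namespace, and the import path is assumed from the repository layout.

    from utils.ds_utils import get_train_ds_config, get_eval_ds_config

    class Args:
        # Hypothetical stand-in for argparse results; only the fields
        # used in this sketch are filled in.
        offload = False
        offload_reference_model = False
        dtype = 'bf16'
        actor_zero_stage = 3

    args = Args()
    actor_cfg = get_train_ds_config(offload=args.offload, dtype=args.dtype,
                                    stage=args.actor_zero_stage)
    ref_cfg = get_eval_ds_config(args.offload_reference_model, args.dtype, 0)
    # With dtype='bf16', both configs carry a "bfloat16" section instead of "fp16".
    assert 'bfloat16' in actor_cfg and 'bfloat16' in ref_cfg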
=== File 5 of 5: applications/DeepSpeed-Chat/training/utils/ds_utils.py (25 additions, 8 deletions) ===
@@ -12,6 +12,7 @@


 def get_train_ds_config(offload,
+                        dtype,
                         stage=2,
                         enable_hybrid_engine=False,
                         inference_tp_size=1,
@@ -25,6 +26,17 @@ def get_train_ds_config(offload,
                         tb_name=""):

     device = "cpu" if offload else "none"
+    if dtype == "fp16":
+        data_type = "fp16"
+        dtype_config = {
+            "enabled": True,
+            "loss_scale_window": 100
+        }
+    elif dtype == "bf16":
+        data_type = "bfloat16"
+        dtype_config = {
+            "enabled": True
+        }
     zero_opt_dict = {
         "stage": stage,
         "offload_param": {
@@ -48,10 +60,7 @@ def get_train_ds_config(offload,
         "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
         "steps_per_print": 10,
         "zero_optimization": zero_opt_dict,
-        "fp16": {
-            "enabled": True,
-            "loss_scale_window": 100
-        },
+        data_type: dtype_config,
         "gradient_clipping": 1.0,
         "prescale_gradients": False,
         "wall_clock_breakdown": False,
@@ -71,8 +80,18 @@
     }


-def get_eval_ds_config(offload, stage=0):
+def get_eval_ds_config(offload, dtype, stage=0):
     device = "cpu" if offload else "none"
+    if dtype == "fp16":
+        data_type = "fp16"
+        dtype_config = {
+            "enabled": True,
+        }
+    elif dtype == "bf16":
+        data_type = "bfloat16"
+        dtype_config = {
+            "enabled": True
+        }
     zero_opt_dict = {
         "stage": stage,
         "stage3_param_persistence_threshold": 1e4,
@@ -86,9 +105,7 @@ def get_eval_ds_config(offload, stage=0):
         "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
         "steps_per_print": 10,
         "zero_optimization": zero_opt_dict,
-        "fp16": {
-            "enabled": True
-        },
+        data_type: dtype_config,
         "gradient_clipping": 1.0,
         "prescale_gradients": False,
         "wall_clock_breakdown": False
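The dtype dispatch introduced above reduces to a small pure function. Below is a self-contained sketch (not from the commit) with an explicit error branch the committed helpers omit; they rely on argparse's choices=['fp16', 'bf16'] upstream, and an unexpected value would surface as a NameError on the unbound data_type instead. Note that only the fp16 branch carries loss_scale_window: fp16 training needs dynamic loss scaling, while bf16's fp32-sized exponent range lets DeepSpeed train without it.

    def precision_section(dtype):
        # Returns the DeepSpeed config key and its settings for the precision.
        if dtype == "fp16":
            return "fp16", {"enabled": True, "loss_scale_window": 100}
        elif dtype == "bf16":
            return "bfloat16", {"enabled": True}
        raise ValueError(f"unsupported dtype: {dtype}")

    for d in ("fp16", "bf16"):
        key, cfg = precision_section(d)
        print({key: cfg})
    # {'fp16': {'enabled': True, 'loss_scale_window': 100}}
    # {'bfloat16': {'enabled': True}}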
