From 9b3d8984d8211e87f778da9125c836b3d8c9b90f Mon Sep 17 00:00:00 2001 From: YiSheng5 Date: Sat, 23 Sep 2023 01:03:07 +0800 Subject: [PATCH] support bf16 for RLHF training (#733) --- .../step1_supervised_finetuning/main.py | 4 +++ .../step2_reward_model_finetuning/main.py | 4 +++ .../training/step3_rlhf_finetuning/main.py | 3 ++ .../step3_rlhf_finetuning/rlhf_engine.py | 8 ++++- .../DeepSpeed-Chat/training/utils/ds_utils.py | 33 ++++++++++++++----- 5 files changed, 43 insertions(+), 9 deletions(-) diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py index 2ce344fb4..cd956fb50 100755 --- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py @@ -146,6 +146,9 @@ def parse_args(): parser.add_argument('--offload', action='store_true', help='Enable ZeRO Offload techniques.') + parser.add_argument('--dtype', type=str, default='fp16', + choices=['fp16', 'bf16'], + help = 'Training data type') parser.add_argument( '--zero_stage', type=int, @@ -202,6 +205,7 @@ def main(): args.global_rank = torch.distributed.get_rank() ds_config = get_train_ds_config(offload=args.offload, + dtype=args.dtype, stage=args.zero_stage, enable_tensorboard=args.enable_tensorboard, tb_path=args.tensorboard_path, diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py index b3cca5b44..0bb78cd23 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py @@ -145,6 +145,9 @@ def parse_args(): parser.add_argument('--offload', action='store_true', help='Enable ZeRO Offload techniques.') + parser.add_argument('--dtype', type=str, default='fp16', + choices=['fp16', 'bf16'], + help = 'Training data type') parser.add_argument( '--zero_stage', type=int, @@ -197,6 +200,7 @@ def main(): args.global_rank = torch.distributed.get_rank() ds_config = get_train_ds_config(offload=args.offload, + dtype=args.dtype, stage=args.zero_stage, enable_tensorboard=args.enable_tensorboard, tb_path=args.tensorboard_path, diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index 6caf89b5e..9263b6d8d 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -241,6 +241,9 @@ def parse_args(): parser.add_argument('--offload', action='store_true', help='Enable ZeRO Offload techniques.') + parser.add_argument('--dtype', type=str, default='fp16', + choices=['fp16', 'bf16'], + help = 'Training data type') parser.add_argument( '--offload_reference_model', action='store_true', diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py index e9234f2b4..187a36efe 100755 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py @@ -66,6 +66,7 @@ def _init_actor(self, actor_model_name_or_path): # DS Config ds_config = get_train_ds_config( offload=self.args.offload, + dtype=self.args.dtype, stage=self.args.actor_zero_stage, enable_hybrid_engine=self.args.enable_hybrid_engine, inference_tp_size=self.args.inference_tp_size, @@ -139,6 +140,7 @@ def _init_ref(self, actor_model_name_or_path): # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory for ref model zero_stage = 0 ds_config = get_eval_ds_config(self.args.offload_reference_model, + self.args.dtype, zero_stage) ds_config[ 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size @@ -165,6 +167,7 @@ def _init_ema(self, actor_model_name_or_path): # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory zero_stage = 0 ds_config = get_eval_ds_config(self.args.offload_reference_model, + self.args.dtype, zero_stage) ds_config[ 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size @@ -191,6 +194,7 @@ def _init_critic(self, critic_model_name_or_path): stime = log_init("Critic") ds_config = get_train_ds_config( offload=self.args.offload, + dtype=self.args.dtype, stage=self.args.critic_zero_stage, enable_tensorboard=self.args.enable_tensorboard, tb_path=self.args.tensorboard_path, @@ -203,6 +207,7 @@ def _init_critic(self, critic_model_name_or_path): ) * self.args.gradient_accumulation_steps ds_eval_config = get_eval_ds_config(offload=False, + dtype=self.args.dtype, stage=self.args.critic_zero_stage) # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine. ds_eval_config[ @@ -266,6 +271,7 @@ def _init_reward(self, critic_model_name_or_path): zero_stage = 0 ds_config = get_eval_ds_config(offload=self.args.offload, + dtype=self.args.dtype, stage=zero_stage) ds_config[ 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size @@ -273,7 +279,7 @@ def _init_reward(self, critic_model_name_or_path): 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( ) * self.args.gradient_accumulation_steps - ds_eval_config = get_eval_ds_config(offload=False, stage=zero_stage) + ds_eval_config = get_eval_ds_config(offload=False, dtype=self.args.dtype, stage=zero_stage) # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine. ds_eval_config[ diff --git a/applications/DeepSpeed-Chat/training/utils/ds_utils.py b/applications/DeepSpeed-Chat/training/utils/ds_utils.py index 065b3fcba..f042283e7 100644 --- a/applications/DeepSpeed-Chat/training/utils/ds_utils.py +++ b/applications/DeepSpeed-Chat/training/utils/ds_utils.py @@ -12,6 +12,7 @@ def get_train_ds_config(offload, + dtype, stage=2, enable_hybrid_engine=False, inference_tp_size=1, @@ -25,6 +26,17 @@ def get_train_ds_config(offload, tb_name=""): device = "cpu" if offload else "none" + if dtype == "fp16": + data_type = "fp16" + dtype_config = { + "enabled": True, + "loss_scale_window": 100 + } + elif dtype == "bf16": + data_type = "bfloat16" + dtype_config = { + "enabled": True + } zero_opt_dict = { "stage": stage, "offload_param": { @@ -48,10 +60,7 @@ def get_train_ds_config(offload, "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE, "steps_per_print": 10, "zero_optimization": zero_opt_dict, - "fp16": { - "enabled": True, - "loss_scale_window": 100 - }, + data_type: dtype_config, "gradient_clipping": 1.0, "prescale_gradients": False, "wall_clock_breakdown": False, @@ -71,8 +80,18 @@ def get_train_ds_config(offload, } -def get_eval_ds_config(offload, stage=0): +def get_eval_ds_config(offload, dtype, stage=0): device = "cpu" if offload else "none" + if dtype == "fp16": + data_type = "fp16" + dtype_config = { + "enabled": True, + } + elif dtype == "bf16": + data_type = "bfloat16" + dtype_config = { + "enabled": True + } zero_opt_dict = { "stage": stage, "stage3_param_persistence_threshold": 1e4, @@ -86,9 +105,7 @@ def get_eval_ds_config(offload, stage=0): "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE, "steps_per_print": 10, "zero_optimization": zero_opt_dict, - "fp16": { - "enabled": True - }, + data_type: dtype_config, "gradient_clipping": 1.0, "prescale_gradients": False, "wall_clock_breakdown": False