
Commit

deepspeed-chat: support explicit configuration of dropout (#746)
Currently, only a disable_dropout configuration is supported.
However, some models (e.g. Bloom) default to dropout=0 in their model config,
so a disable-only flag cannot be used to set a non-zero dropout for them.
Therefore, modify the code to support explicit dropout configuration,
and update the existing training scripts accordingly.

Change-Id: I5ee96a77ca2b58d9787573a48009e2af36a270b0

Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
3 people authored Oct 3, 2023
1 parent 2f99dcd commit 4bf1924
Showing 22 changed files with 66 additions and 40 deletions.
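
For context, the sketch below illustrates the new command-line interface. The two add_argument calls mirror the step3_rlhf_finetuning/main.py diff further down; the sample invocation and assertions are illustrative assumptions, not part of the commit.

# Hedged sketch: float-valued dropout flags replacing the old boolean switches.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--actor_dropout", type=float, default=None,
                    help="If actor dropout configured, use it. "
                    "Otherwise, keep the default dropout configuration "
                    "of the actor model.")
parser.add_argument("--critic_dropout", type=float, default=None,
                    help="If critic dropout configured, use it. "
                    "Otherwise, keep the default dropout configuration "
                    "of the critic model.")

# Old behavior: --disable_actor_dropout forced the actor dropout to 0.0.
# New behavior: --actor_dropout 0.0 has the same effect, and any other value
# (e.g. 0.1) can now be set explicitly, which matters for models such as
# Bloom whose config already defaults to dropout=0.
args = parser.parse_args(["--actor_dropout", "0.0"])
assert args.actor_dropout == 0.0 and args.critic_dropout is None
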
@@ -139,9 +139,12 @@ def parse_args():
parser.add_argument('--gradient_checkpointing',
action='store_true',
help='Enable HF gradient checkpointing for model.')
- parser.add_argument('--disable_dropout',
- action='store_true',
- help='Disable the dropout of the model.')
+ parser.add_argument(
+ "--dropout",
+ type=float,
+ default=None,
+ help="If dropout configured, use it. "
+ "Otherwise, keep the default dropout configuration of the model.")
# deepspeed features
parser.add_argument('--offload',
action='store_true',
@@ -229,7 +232,7 @@ def main():
args.model_name_or_path,
tokenizer,
ds_config,
- disable_dropout=args.disable_dropout)
+ dropout=args.dropout)

if args.lora_dim > 0:
model = convert_linear_layer_to_lora(model, args.lora_module_name,
@@ -138,9 +138,12 @@ def parse_args():
'--gradient_checkpointing',
action='store_true',
help='Enable HF gradient checkpointing for Actor model.')
- parser.add_argument('--disable_dropout',
- action='store_true',
- help='Disable the dropout of the model.')
+ parser.add_argument(
+ "--dropout",
+ type=float,
+ default=None,
+ help="If dropout configured, use it. "
+ "Otherwise, keep the default dropout configuration of the model.")
# deepspeed features
parser.add_argument('--offload',
action='store_true',
@@ -223,7 +226,7 @@ def main():
tokenizer,
ds_config,
args.num_padding_at_beginning,
- disable_dropout=args.disable_dropout)
+ dropout=args.dropout)

if args.lora_dim > 0:
rm_model = convert_linear_layer_to_lora(rm_model,
@@ -43,8 +43,11 @@ def load_stuff(model_name_or_path, num_padding_at_beginning):

tokenizer = load_hf_tokenizer(model_name_or_path, fast_tokenizer=True)
tokenizer.pad_token = tokenizer.eos_token
- model = create_critic_model(model_name_or_path, tokenizer, None,
- num_padding_at_beginning, True)
+ model = create_critic_model(model_name_or_path,
+ tokenizer,
+ None,
+ num_padding_at_beginning,
+ dropout=0.)

return model, tokenizer

@@ -23,7 +23,7 @@ deepspeed main.py \
--max_seq_len 512 \
--learning_rate 5e-5 \
--weight_decay 0.1 \
- --disable_dropout \
+ --dropout 0.0 \
--num_train_epochs 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
@@ -14,7 +14,7 @@ fi
mkdir -p $OUTPUT

deepspeed --num_gpus 1 main.py --model_name_or_path facebook/opt-350m \
- --num_padding_at_beginning 1 --weight_decay 0.1 --disable_dropout --gradient_accumulation_steps 4 --zero_stage $ZERO_STAGE \
+ --num_padding_at_beginning 1 --weight_decay 0.1 --dropout 0.0 --gradient_accumulation_steps 4 --zero_stage $ZERO_STAGE \
--enable_tensorboard \
--tensorboard_path $OUTPUT \
--deepspeed --output_dir $OUTPUT &> $OUTPUT/training.log
@@ -24,7 +24,7 @@ deepspeed main.py \
--learning_rate 5e-5 \
--weight_decay 0.1 \
--num_train_epochs 1 \
- --disable_dropout \
+ --dropout 0.0 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--num_warmup_steps 0 \
@@ -30,7 +30,7 @@ cmd="deepspeed main.py \
--learning_rate 5e-5 \
--weight_decay 0.1 \
--num_train_epochs 1 \
- --disable_dropout \
+ --dropout 0.0 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--num_warmup_steps 0 \
20 changes: 14 additions & 6 deletions applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
@@ -268,12 +268,20 @@ def parse_args():
'--critic_gradient_checkpointing',
action='store_true',
help='Enable HF gradient checkpointing for Critic model.')
- parser.add_argument('--disable_actor_dropout',
- action='store_true',
- help='Disable the dropout of the actor model.')
- parser.add_argument('--disable_critic_dropout',
- action='store_true',
- help='Disable the dropout of the critical model.')
+ parser.add_argument(
+ "--actor_dropout",
+ type=float,
+ default=None,
+ help="If actor dropout configured, use it. "
+ "Otherwise, keep the default dropout configuration of the actor model."
+ )
+ parser.add_argument(
+ "--critic_dropout",
+ type=float,
+ default=None,
+ help="If critic dropout configured, use it. "
+ "Otherwise, keep the default dropout configuration of the critic model."
+ )
## LoRA for efficient training setting
parser.add_argument("--actor_lora_dim",
type=int,
@@ -92,7 +92,7 @@ def _init_actor(self, actor_model_name_or_path):
model_name_or_path=actor_model_name_or_path,
tokenizer=self.tokenizer,
ds_config=ds_config,
- disable_dropout=self.args.disable_actor_dropout)
+ dropout=self.args.actor_dropout)

# LoRA
if self.args.actor_lora_dim > 0:
@@ -221,7 +221,7 @@ def _init_critic(self, critic_model_name_or_path):
ds_config=ds_eval_config,
num_padding_at_beginning=self.args.num_padding_at_beginning,
rlhf_training=True,
- disable_dropout=self.args.disable_critic_dropout,
+ dropout=self.args.critic_dropout,
zero_stage=self.args.critic_zero_stage)

# LoRA
@@ -295,7 +295,7 @@ def _init_reward(self, critic_model_name_or_path):
ds_config=ds_eval_config,
num_padding_at_beginning=self.args.num_padding_at_beginning,
rlhf_training=True,
- disable_dropout=self.args.disable_critic_dropout,
+ dropout=self.args.critic_dropout,
zero_stage=zero_stage)

reward_engine, *_ = deepspeed.initialize(model=reward_model,
@@ -46,7 +46,7 @@ deepspeed --master_port 12346 main.py \
--actor_gradient_checkpointing \
--critic_gradient_checkpointing \
--offload_reference_model \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
--actor_zero_stage $ACTOR_ZERO_STAGE \
@@ -46,7 +46,7 @@ deepspeed --master_port 12346 main.py \
--actor_gradient_checkpointing \
--critic_gradient_checkpointing \
--offload_reference_model \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
--actor_zero_stage $ACTOR_ZERO_STAGE \
@@ -46,7 +46,7 @@ deepspeed --master_port 12346 main.py \
--actor_gradient_checkpointing \
--critic_gradient_checkpointing \
--offload_reference_model \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
--actor_zero_stage $ACTOR_ZERO_STAGE \
@@ -51,7 +51,7 @@ deepspeed --master_port 12346 main.py \
--actor_zero_stage $ACTOR_ZERO_STAGE \
--critic_zero_stage $CRITIC_ZERO_STAGE \
--actor_gradient_checkpointing \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--actor_lora_dim 128 \
--actor_lora_module_name decoder.layers. \
--output_dir $OUTPUT \
@@ -23,5 +23,5 @@ deepspeed --num_gpus 1 main.py \
--actor_model_name_or_path $ACTOR_MODEL_PATH --critic_model_name_or_path $CRITIC_MODEL_PATH \
--actor_zero_stage $ACTOR_ZERO_STAGE --critic_zero_stage $CRITIC_ZERO_STAGE \
--num_padding_at_beginning 1 --gradient_accumulation_steps 2 \
- --deepspeed --actor_lora_dim 128 --enable_hybrid_engine --actor_gradient_checkpointing --disable_actor_dropout \
+ --deepspeed --actor_lora_dim 128 --enable_hybrid_engine --actor_gradient_checkpointing --actor_dropout 0.0 \
--output_dir $OUTPUT &> $OUTPUT/training.log
@@ -43,7 +43,7 @@ deepspeed --num_gpus 1 main.py \
--actor_lora_dim 128 \
--actor_gradient_checkpointing \
--critic_gradient_checkpointing \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--enable_hybrid_engine \
--output_dir $OUTPUT \
&> $OUTPUT/training.log
@@ -50,7 +50,7 @@ deepspeed --master_port 12346 main.py \
--num_train_epochs 1 \
--lr_scheduler_type cosine \
--gradient_accumulation_steps 1 \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
--enable_hybrid_engine \
@@ -38,7 +38,7 @@ deepspeed --master_port 12346 main.py \
--gradient_accumulation_steps 1 \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
${ACTOR_ZERO_STAGE} \
${CRITIC_ZERO_STAGE} \
--actor_lora_dim 128 \
@@ -48,7 +48,7 @@ deepspeed --master_port 12346 main.py \
--actor_zero_stage $ACTOR_ZERO_STAGE \
--critic_zero_stage $CRITIC_ZERO_STAGE \
--actor_gradient_checkpointing \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--actor_lora_dim 128 \
--actor_lora_module_name decoder.layers. \
--output_dir $OUTPUT \
@@ -38,7 +38,7 @@ deepspeed --master_port 12346 main.py \
--lr_scheduler_type cosine \
--gradient_accumulation_steps 1 \
--actor_gradient_checkpointing \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
${ACTOR_ZERO_STAGE} \
@@ -44,7 +44,7 @@ deepspeed --master_port 12346 main.py \
--lr_scheduler_type cosine \
--gradient_accumulation_steps 1 \
--actor_gradient_checkpointing \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
--enable_hybrid_engine \
@@ -85,7 +85,7 @@ cmd="deepspeed --num_nodes=1 main.py \
--critic_weight_decay 0 \
--num_warmup_steps 100 \
--deepspeed --seed 1234 \
- --disable_actor_dropout \
+ --actor_dropout 0.0 \
--print_answers \
--actor_zero_stage ${ACTOR_ZERO_STAGE} \
--critic_zero_stage ${CRITIC_ZERO_STAGE} \
19 changes: 14 additions & 5 deletions applications/DeepSpeed-Chat/training/utils/model/model_utils.py
@@ -16,15 +16,24 @@
from ..utils import load_state_dict_into_model


+ def configure_dropout(model_config, dropout):
+ if dropout is not None:
+ for key in ('dropout', 'attention_dropout', 'hidden_dropout',
+ 'activation_dropout'):
+ if hasattr(model_config, key):
+ print(f"Setting model_config.{key} to {dropout}")
+ setattr(model_config, key, dropout)


def create_hf_model(model_class,
model_name_or_path,
tokenizer,
ds_config=None,
rlhf_training=False,
- disable_dropout=False):
+ dropout=None):
model_config = AutoConfig.from_pretrained(model_name_or_path)
- if disable_dropout:
- model_config.dropout = 0.0
+ configure_dropout(model_config, dropout)

# Note: dschf is defined in function scope to avoid global effects
# https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
@@ -54,7 +63,7 @@ def create_critic_model(model_name_or_path,
ds_config,
num_padding_at_beginning=0,
rlhf_training=False,
- disable_dropout=False,
+ dropout=None,
zero_stage=0):
# OPT model family always put a padding token at the beginning of the sequence,
# we did not see this in other models but not sure if it is a general rule
@@ -63,7 +72,7 @@

start = time.time()
critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
- ds_config, rlhf_training, disable_dropout)
+ ds_config, rlhf_training, dropout)
end = time.time()
if torch.distributed.get_rank() == 0:
print(f"> Creating model from_config took {end - start} seconds")
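
To make the semantics concrete, here is a small self-contained sketch of the new configure_dropout helper added in model_utils.py above. The helper body is copied from the hunk; DummyConfig is an assumed stand-in for a Hugging Face model config (e.g. a Bloom-style config whose dropout fields default to 0), not part of the commit.

def configure_dropout(model_config, dropout):
    # Copied from the model_utils.py hunk above: only override the config
    # when an explicit dropout value is given; None keeps the model defaults.
    if dropout is not None:
        for key in ('dropout', 'attention_dropout', 'hidden_dropout',
                    'activation_dropout'):
            if hasattr(model_config, key):
                print(f"Setting model_config.{key} to {dropout}")
                setattr(model_config, key, dropout)


class DummyConfig:
    """Assumed stand-in for a Bloom-like HF config that ships with dropout=0."""

    def __init__(self):
        self.hidden_dropout = 0.0
        self.attention_dropout = 0.0


cfg = DummyConfig()
configure_dropout(cfg, None)  # no --dropout given: model defaults are kept
assert cfg.hidden_dropout == 0.0

configure_dropout(cfg, 0.1)   # --dropout 0.1: dropout is enabled explicitly
assert cfg.hidden_dropout == 0.1 and cfg.attention_dropout == 0.1
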
