
Commit 51a0b5b
[bug fixes]
kyegomez committed Jan 29, 2025
1 parent 1c808b9 commit 51a0b5b
Showing 2 changed files with 58 additions and 10 deletions.
64 changes: 55 additions & 9 deletions agentgym/r1_pipeline.py
@@ -125,6 +125,7 @@ def __init__(
         tokenizer_name: str = "None",
         use_prebuilt_reward_funcs: bool = True,
         only_grpo: bool = False,
+        use_vllm: bool = False,
         *args,
         **kwargs,
     ):
@@ -152,6 +153,8 @@ def __init__(
         self.grpo_args = grpo_args
         self.use_prebuilt_reward_funcs = use_prebuilt_reward_funcs
         self.only_grpo = only_grpo
+        self.use_vllm = use_vllm
+
         self.saved_model_file_path = f"{self.output_dir}/{model_name}_{generate_model_uuid()}.pth"

         self.check_for_flash_attention()
@@ -166,11 +169,11 @@ def __init__(
             peft_config=(
                 self.peft_config if sft_lora_only is True else None
             ),
-            use_liger=(
-                self.liger_kernel_on
-                if liger_kernel_on is True
-                else False
-            ),
+            # use_liger=(
+            #     self.liger_kernel_on
+            #     if liger_kernel_on is True
+            #     else False
+            # ),
             *args,
             **kwargs,
         )
@@ -221,9 +224,52 @@ def check_for_flash_attention(self):
             raise e

     def load_grpo_args(self):
-        config = GRPOArgs(**self.grpo_args)
-        training_args = GRPOConfig(**config)
-        return training_args
+        args = GRPOArgs()
+
+        # Ensure all necessary attributes are set with default values if they are None
+        args.output_dir = args.output_dir or '/tmp'
+        args.run_name = args.run_name or 'default_run_name'
+        args.learning_rate = args.learning_rate or 5e-5
+        args.adam_beta1 = args.adam_beta1 or 0.9
+        args.adam_beta2 = args.adam_beta2 or 0.999
+        args.weight_decay = args.weight_decay or 0.01
+        args.warmup_ratio = args.warmup_ratio or 0.1
+        args.lr_scheduler_type = args.lr_scheduler_type or 'linear'
+        args.logging_steps = args.logging_steps or 500
+        args.bf16 = args.bf16 if args.bf16 is not None else False
+        args.per_device_train_batch_size = args.per_device_train_batch_size or 8
+        args.gradient_accumulation_steps = args.gradient_accumulation_steps or 1
+        args.num_generations = args.num_generations or 1
+        args.max_prompt_length = args.max_prompt_length or 512
+        args.max_completion_length = args.max_completion_length or 128
+        args.num_train_epochs = args.num_train_epochs or 3
+        args.save_steps = args.save_steps or 1000
+        args.max_grad_norm = args.max_grad_norm or 1.0
+        args.report_to = args.report_to or 'none'
+        args.log_on_each_node = args.log_on_each_node if args.log_on_each_node is not None else False
+
+        return GRPOConfig(
+            output_dir=args.output_dir,
+            run_name=args.run_name,
+            learning_rate=args.learning_rate,
+            adam_beta1=args.adam_beta1,
+            adam_beta2=args.adam_beta2,
+            weight_decay=args.weight_decay,
+            warmup_ratio=args.warmup_ratio,
+            lr_scheduler_type=args.lr_scheduler_type,
+            logging_steps=args.logging_steps,
+            bf16=args.bf16,
+            per_device_train_batch_size=args.per_device_train_batch_size,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            num_generations=args.num_generations,
+            max_prompt_length=args.max_prompt_length,
+            max_completion_length=args.max_completion_length,
+            num_train_epochs=args.num_train_epochs,
+            save_steps=args.save_steps,
+            max_grad_norm=args.max_grad_norm,
+            report_to=args.report_to,
+            log_on_each_node=args.log_on_each_node,
+        )

     def grpo_train(self, model_path: str, *args, **kwargs):
         try:
@@ -236,7 +282,7 @@ def grpo_train(self, model_path: str, *args, **kwargs):
         )

         trainer = GRPOTrainer(
-            model=model_path,
+            model=self.sft_model,
             processing_class=self.tokenizer,
             reward_funcs=reward_funcs,
             args=training_args,
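For context on the load_grpo_args change: the removed version unpacked a GRPOArgs instance into GRPOConfig with **, which raises a TypeError unless GRPOArgs implements the mapping protocol; the fix resolves defaults field by field and builds the config from explicit keyword arguments. A minimal sketch of that pattern, assuming GRPOArgs is a plain dataclass with optional fields (the class and field names below are illustrative stand-ins):

from dataclasses import dataclass
from typing import Optional

# Illustrative stand-in for GRPOArgs; the real class lives in agentgym.
@dataclass
class IllustrativeGRPOArgs:
    learning_rate: Optional[float] = None
    bf16: Optional[bool] = None

args = IllustrativeGRPOArgs()

# Old behavior: SomeConfig(**args) raises TypeError, because ** unpacking
# requires a mapping and a plain dataclass instance is not one.

# New pattern: resolve defaults per field before constructing the config.
args.learning_rate = args.learning_rate or 5e-5             # plain default
args.bf16 = args.bf16 if args.bf16 is not None else False   # preserves an explicit False

print(args)  # IllustrativeGRPOArgs(learning_rate=5e-05, bf16=False)

The `is not None` form matters for booleans (and would for zeros), where a bare `or` would silently overwrite a legitimate falsy value.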
4 changes: 3 additions & 1 deletion example.py
@@ -2,9 +2,11 @@

 r1_pipeline = R1Pipeline(
     sft_model="gpt2",
+    tokenizer_name="gpt2",
     sft_dataset="stanfordnlp/imdb",
     sft_args=SFTConfig(output_dir="/tmp"),
-    only_grpo=True
+    only_grpo=True,
+    model_name="gpt-2"
 )

 r1_pipeline.run()
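One note on the new model_name argument: __init__ (first file above) interpolates it into saved_model_file_path, which is presumably why the example now passes it explicitly. A hypothetical illustration of the resulting path, assuming output_dir resolves to "/tmp" and treating generate_model_uuid() as an opaque helper:

output_dir = "/tmp"
model_name = "gpt-2"
uuid_part = "1234abcd"  # stand-in for generate_model_uuid()
saved_model_file_path = f"{output_dir}/{model_name}_{uuid_part}.pth"
print(saved_model_file_path)  # /tmp/gpt-2_1234abcd.pth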