Commit 3f716d6

Updated code
1 parent abe4387 commit 3f716d6

2 files changed: +12 −8 lines changed

CLAUDE.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -8,3 +8,8 @@
 - To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes.
 - Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`.
 - Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`.
+
+# Comments Policy
+- NEVER remove existing comments from code when making edits unless they are obviously outdated, in which case ALWAYS ask for permission.
+- Always preserve all existing comments, especially explanatory ones
+- Only add comments when they are needed for clarity
```

open_instruct/grpo_fast.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -623,7 +623,7 @@ def maybe_apply_importance_sampling(
     return unclipped_pg_loss, clipped_pg_loss
 
 
-def calculate_loss(
+def calculate_loss_and_backward(
     model: deepspeed.DeepSpeedEngine,
     i: int,
     loss_statistics: LossStatistics,
@@ -638,14 +638,13 @@ def calculate_loss(
     local_step: int,
     args: Args,
 ) -> int:
-    """Calculate and apply GRPO loss for a single minibatch.
+    """Calculate GRPO loss and perform backward pass for a single minibatch.
 
     Computes the policy gradient loss using the clipped surrogate objective from PPO,
-    combines it with a KL penalty term, performs the backward pass, and optionally
-    steps the optimizer.
+    combines it with a KL penalty term, and performs the backward pass.
 
     Args:
-        model: Model wrapper with backward() and step() methods (e.g., DeepSpeed engine)
+        model: Model wrapper with backward() method (e.g., DeepSpeed engine)
         i: Minibatch index for tracking statistics
         loss_statistics: LossStatistics object to accumulate training metrics
         local_logprobs: Log probabilities from current policy (shape: [batch, seq_len])
@@ -688,8 +687,6 @@ def calculate_loss(
     )
     loss = loss / accumulation_steps
     model.backward(loss)
-    if (local_step + 1) % accumulation_steps == 0:
-        model.step()
 
     with torch.no_grad():
         loss_statistics.update_stats(
@@ -1298,7 +1295,7 @@ def train(
                     f"response_mask sum={mb_response_masks_bool.sum()}"
                 )
 
-                local_step = calculate_loss(
+                local_step = calculate_loss_and_backward(
                     self.model,
                     i,
                     loss_statistics,
@@ -1313,6 +1310,8 @@ def train(
                     local_step,
                     args,
                 )
+                if local_step % accumulation_steps == 0:
+                    self.model.step()
 
                 local_metrics |= loss_statistics.to_dict()
                 local_metrics["lr"] = self.scheduler.get_last_lr()[0]
```
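In short, the commit narrows `calculate_loss_and_backward` to computing the loss and accumulating gradients, and moves the optimizer step into `train`, which now calls `self.model.step()` once every `accumulation_steps` minibatches. The modulo check also shifts from `(local_step + 1)` to `local_step`, which is consistent if the function returns the incremented counter (its `-> int` return type and the `local_step = calculate_loss_and_backward(...)` assignment suggest it does). Below is a minimal sketch of the resulting pattern, using a plain PyTorch model and optimizer as stand-ins for the DeepSpeed engine and the actual GRPO loss (both are assumptions, not the repo's code):

```python
import torch

def calculate_loss_and_backward(model, batch, accumulation_steps, local_step):
    """Compute the loss for one minibatch and accumulate gradients.

    Mirrors the refactor: this function only calls backward(); stepping
    the optimizer is now the caller's responsibility. The loss here is a
    placeholder for the real GRPO objective.
    """
    loss = model(batch).mean()        # stand-in for the clipped PG loss + KL penalty
    loss = loss / accumulation_steps  # scale so accumulated gradients average out
    loss.backward()                   # with DeepSpeed this would be model.backward(loss)
    return local_step + 1             # assumed: the counter comes back incremented

# Caller (cf. the new train() code): step once every `accumulation_steps` minibatches.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 2
local_step = 0

for batch in torch.randn(8, 3, 4):  # eight dummy minibatches
    local_step = calculate_loss_and_backward(model, batch, accumulation_steps, local_step)
    if local_step % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()  # DeepSpeed's step() clears gradients itself
```

Splitting the step out this way keeps the per-minibatch function free of optimizer state, so the accumulation boundary is visible in exactly one place in the training loop.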
