@@ -623,7 +623,7 @@ def maybe_apply_importance_sampling(
623623 return unclipped_pg_loss , clipped_pg_loss
624624
625625
626- def calculate_loss (
626+ def calculate_loss_and_backward (
627627 model : deepspeed .DeepSpeedEngine ,
628628 i : int ,
629629 loss_statistics : LossStatistics ,
@@ -638,14 +638,13 @@ def calculate_loss(
638638 local_step : int ,
639639 args : Args ,
640640) -> int :
641- """Calculate and apply GRPO loss for a single minibatch.
641+ """Calculate GRPO loss and perform backward pass for a single minibatch.
642642
643643 Computes the policy gradient loss using the clipped surrogate objective from PPO,
644- combines it with a KL penalty term, performs the backward pass, and optionally
645- steps the optimizer.
644+ combines it with a KL penalty term, and performs the backward pass.
646645
647646 Args:
648- model: Model wrapper with backward() and step() methods (e.g., DeepSpeed engine)
647+ model: Model wrapper with backward() method (e.g., DeepSpeed engine)
649648 i: Minibatch index for tracking statistics
650649 loss_statistics: LossStatistics object to accumulate training metrics
651650 local_logprobs: Log probabilities from current policy (shape: [batch, seq_len])
@@ -688,8 +687,6 @@ def calculate_loss(
688687 )
689688 loss = loss / accumulation_steps
690689 model .backward (loss )
691- if (local_step + 1 ) % accumulation_steps == 0 :
692- model .step ()
693690
694691 with torch .no_grad ():
695692 loss_statistics .update_stats (
@@ -1298,7 +1295,7 @@ def train(
12981295 f"response_mask sum={ mb_response_masks_bool .sum ()} "
12991296 )
13001297
1301- local_step = calculate_loss (
1298+ local_step = calculate_loss_and_backward (
13021299 self .model ,
13031300 i ,
13041301 loss_statistics ,
@@ -1313,6 +1310,8 @@ def train(
13131310 local_step ,
13141311 args ,
13151312 )
1313+ if local_step % accumulation_steps == 0 :
1314+ self .model .step ()
13161315
13171316 local_metrics |= loss_statistics .to_dict ()
13181317 local_metrics ["lr" ] = self .scheduler .get_last_lr ()[0 ]
0 commit comments