
Commit 26ee333

updated test
1 parent 196a26b commit 26ee333

2 files changed: +102 -3 lines

open_instruct/grpo_fast.py

Lines changed: 6 additions & 3 deletions

@@ -682,10 +682,13 @@ def calculate_loss_and_backward(
         ref_logprobs_diff = (local_logprobs - ref_logprob).clamp(-40.0, 40.0)
         kl = loss_statistics.update_kl_estimates(i, ref_logprobs_diff, ratio, response_masks_bool, args)

-        loss = masked_mean(
-            pg_loss_max + (args.beta * kl), response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
+        total_loss = pg_loss_max + (args.beta * kl)
+        loss_sum = (total_loss * response_masks_bool).sum()
+        denominator = (
+            args.masked_mean_denominator if args.masked_mean_denominator is not None else response_masks_bool.sum()
         )
-        loss = loss / accumulation_steps
+        loss = loss_sum / denominator
+
         model.backward(loss)

         with torch.no_grad():
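
For context, here is a minimal, self-contained sketch of what this normalization change does. It is for illustration only: the old_loss/new_loss helpers are hypothetical, and the assumption that masked_mean with masked_mean_axis=None divides the masked sum by the unmasked-token count (or by an explicit masked_mean_denominator) is inferred from the diff.

import torch

def old_loss(total_loss, mask, accumulation_steps, denominator=None):
    # Assumed old behavior from the removed lines: masked mean over unmasked
    # tokens (or an explicit denominator), then an extra division by
    # accumulation_steps.
    denom = denominator if denominator is not None else mask.sum()
    return (total_loss * mask).sum() / denom / accumulation_steps

def new_loss(total_loss, mask, denominator=None):
    # New behavior from the added lines: masked sum divided by an explicit
    # denominator, falling back to the number of unmasked tokens.
    denom = denominator if denominator is not None else mask.sum()
    return (total_loss * mask).sum() / denom

total_loss = torch.randn(4, 16)
mask = torch.ones(4, 16, dtype=torch.bool)

# Same data, different accumulation_steps: the old loss scales with the
# batch-splitting configuration, the new loss depends only on the tokens.
print(old_loss(total_loss, mask, accumulation_steps=1).item())
print(old_loss(total_loss, mask, accumulation_steps=4).item())  # 4x smaller
print(new_loss(total_loss, mask).item())                        # unchanged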

open_instruct/test_grpo_fast.py

Lines changed: 96 additions & 0 deletions

@@ -1085,6 +1085,102 @@ def test_distribution_and_structure(
             first_pad_idx = padding_mask.nonzero(as_tuple=True)[0][0].item()
             self.assertTrue(torch.all(row[first_pad_idx:] == pad_token_id))

+    def test_calculate_loss_and_backward_no_accumulation_steps_division(self):
+        """Test that loss is NOT divided by accumulation_steps.
+
+        This test verifies the bug fix: the old code incorrectly divided loss by
+        accumulation_steps, which meant gradient magnitude depended on batch splitting
+        configuration rather than actual number of tokens.
+
+        With the fix, loss should be the same regardless of accumulation_steps value
+        (since we're using the same data and token count).
+        """
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+        batch_size = 4
+        seq_len = 16
+
+        local_logprobs = torch.randn(batch_size, seq_len)
+        old_logprobs = torch.randn(batch_size, seq_len)
+        ref_logprob = torch.randn(batch_size, seq_len)
+        advantages = torch.randn(batch_size, seq_len + 1)
+        response_masks_bool = torch.ones(batch_size, seq_len, dtype=torch.bool)
+        entropy = torch.randn(batch_size, seq_len)
+
+        args = grpo_fast.Args()
+        args.clip_lower = 0.2
+        args.clip_higher = 0.2
+        args.beta = 0.05
+        args.kl_estimator = "kl3"
+        args.masked_mean_axis = None
+        args.masked_mean_denominator = None
+        args.truncated_importance_sampling_ratio_cap = 0.0
+        args.record_entropy = False
+
+        mock_model_1 = Mock()
+        mock_model_1.backward = Mock()
+        loss_statistics_1 = grpo_fast.LossStatistics(num_batches=1, record_entropy=False)
+
+        grpo_fast.calculate_loss_and_backward(
+            mock_model_1,
+            0,
+            loss_statistics_1,
+            local_logprobs.clone(),
+            old_logprobs.clone(),
+            ref_logprob.clone(),
+            advantages.clone(),
+            response_masks_bool.clone(),
+            None,
+            entropy.clone(),
+            accumulation_steps=1,
+            local_step=0,
+            args=args,
+        )
+
+        loss_with_accum_1 = mock_model_1.backward.call_args[0][0].item()
+
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+        local_logprobs = torch.randn(batch_size, seq_len)
+        old_logprobs = torch.randn(batch_size, seq_len)
+        ref_logprob = torch.randn(batch_size, seq_len)
+        advantages = torch.randn(batch_size, seq_len + 1)
+        response_masks_bool = torch.ones(batch_size, seq_len, dtype=torch.bool)
+        entropy = torch.randn(batch_size, seq_len)
+
+        mock_model_4 = Mock()
+        mock_model_4.backward = Mock()
+        loss_statistics_4 = grpo_fast.LossStatistics(num_batches=1, record_entropy=False)
+
+        grpo_fast.calculate_loss_and_backward(
+            mock_model_4,
+            0,
+            loss_statistics_4,
+            local_logprobs.clone(),
+            old_logprobs.clone(),
+            ref_logprob.clone(),
+            advantages.clone(),
+            response_masks_bool.clone(),
+            None,
+            entropy.clone(),
+            accumulation_steps=4,
+            local_step=0,
+            args=args,
+        )
+
+        loss_with_accum_4 = mock_model_4.backward.call_args[0][0].item()
+
+        self.assertAlmostEqual(
+            loss_with_accum_1,
+            loss_with_accum_4,
+            places=5,
+            msg=f"Loss should be the same regardless of accumulation_steps. "
+            f"Got {loss_with_accum_1:.6f} (accum=1) vs {loss_with_accum_4:.6f} (accum=4). "
+            f"Old buggy code would have made accum=4 loss 4x smaller.",
+        )
+

 if __name__ == "__main__":
     unittest.main()
