@@ -1085,6 +1085,102 @@ def test_distribution_and_structure(
         first_pad_idx = padding_mask.nonzero(as_tuple=True)[0][0].item()
         self.assertTrue(torch.all(row[first_pad_idx:] == pad_token_id))
 
+    def test_calculate_loss_and_backward_no_accumulation_steps_division(self):
+        """Test that the loss is NOT divided by accumulation_steps.
+
+        This test verifies the bug fix: the old code incorrectly divided the loss by
+        accumulation_steps, so the gradient magnitude depended on the batch-splitting
+        configuration rather than on the actual number of tokens.
+
+        With the fix, the loss should be the same regardless of the accumulation_steps
+        value (since we're using the same data and token count).
+        """
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+        batch_size = 4
+        seq_len = 16
+
+        local_logprobs = torch.randn(batch_size, seq_len)
+        old_logprobs = torch.randn(batch_size, seq_len)
+        ref_logprob = torch.randn(batch_size, seq_len)
+        advantages = torch.randn(batch_size, seq_len + 1)
+        response_masks_bool = torch.ones(batch_size, seq_len, dtype=torch.bool)
+        entropy = torch.randn(batch_size, seq_len)
+
+        args = grpo_fast.Args()
+        args.clip_lower = 0.2
+        args.clip_higher = 0.2
+        args.beta = 0.05
+        args.kl_estimator = "kl3"
+        args.masked_mean_axis = None
+        args.masked_mean_denominator = None
+        args.truncated_importance_sampling_ratio_cap = 0.0
+        args.record_entropy = False
+
+        mock_model_1 = Mock()
+        mock_model_1.backward = Mock()
+        loss_statistics_1 = grpo_fast.LossStatistics(num_batches=1, record_entropy=False)
+
+        grpo_fast.calculate_loss_and_backward(
+            mock_model_1,
+            0,
+            loss_statistics_1,
+            local_logprobs.clone(),
+            old_logprobs.clone(),
+            ref_logprob.clone(),
+            advantages.clone(),
+            response_masks_bool.clone(),
+            None,
+            entropy.clone(),
+            accumulation_steps=1,
+            local_step=0,
+            args=args,
+        )
+
+        loss_with_accum_1 = mock_model_1.backward.call_args[0][0].item()
+
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+        local_logprobs = torch.randn(batch_size, seq_len)
+        old_logprobs = torch.randn(batch_size, seq_len)
+        ref_logprob = torch.randn(batch_size, seq_len)
+        advantages = torch.randn(batch_size, seq_len + 1)
+        response_masks_bool = torch.ones(batch_size, seq_len, dtype=torch.bool)
+        entropy = torch.randn(batch_size, seq_len)
+
+        mock_model_4 = Mock()
+        mock_model_4.backward = Mock()
+        loss_statistics_4 = grpo_fast.LossStatistics(num_batches=1, record_entropy=False)
+
+        grpo_fast.calculate_loss_and_backward(
+            mock_model_4,
+            0,
+            loss_statistics_4,
+            local_logprobs.clone(),
+            old_logprobs.clone(),
+            ref_logprob.clone(),
+            advantages.clone(),
+            response_masks_bool.clone(),
+            None,
+            entropy.clone(),
+            accumulation_steps=4,
+            local_step=0,
+            args=args,
+        )
+
+        loss_with_accum_4 = mock_model_4.backward.call_args[0][0].item()
+
+        self.assertAlmostEqual(
+            loss_with_accum_1,
+            loss_with_accum_4,
+            places=5,
+            msg=f"Loss should be the same regardless of accumulation_steps. "
+            f"Got {loss_with_accum_1:.6f} (accum=1) vs {loss_with_accum_4:.6f} (accum=4). "
+            f"Old buggy code would have made the accum=4 loss 4x smaller.",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
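
For readers skimming the diff, the property this test pins down can be shown with a minimal, self-contained sketch (illustrative names only, nothing below is grpo_fast's API): when each micro-batch loss is normalized by a global token count, the gradient accumulated across micro-batches is independent of accumulation_steps, whereas an extra division by accumulation_steps would couple the gradient magnitude to the batch-splitting configuration.

```python
import torch


def per_token_losses(w, tokens):
    # Stand-in for a per-token policy loss: any differentiable function of w.
    return w * tokens


tokens = torch.arange(1.0, 17.0)  # 16 "tokens" in one global batch
n_tokens = tokens.numel()

# Reference: a single backward pass over the whole batch, normalized by token count.
w = torch.ones(1, requires_grad=True)
(per_token_losses(w, tokens).sum() / n_tokens).backward()
full_batch_grad = w.grad.item()

# Gradient accumulation over 4 micro-batches, still normalizing by the *global* token
# count and NOT dividing by accumulation_steps: the accumulated gradient matches the
# full-batch reference exactly, no matter how the batch was split.
w = torch.ones(1, requires_grad=True)
for micro in tokens.chunk(4):
    (per_token_losses(w, micro).sum() / n_tokens).backward()
assert abs(w.grad.item() - full_batch_grad) < 1e-6

# Dividing each micro-batch loss by accumulation_steps on top of this normalization
# would shrink the accumulated gradient by that factor, tying it to the splitting
# configuration, which is the behavior the test above guards against.
```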