@@ -1237,8 +1237,11 @@ def test_normalize_advantages_with_epsilon_zero_std():
12371237
12381238 result = normalize_advantages_with_epsilon (advantages , std , epsilon )
12391239
1240- # When std=0, result should be advantages / epsilon
1241- expected = torch .tensor ([[1.0 / epsilon ], [2.0 ], [3.0 / epsilon ]])
1240+ # When std=0 AND advantage!=0, normalization is skipped (advantages unchanged)
1241+ # When std>0, normal normalization occurs
1242+ expected = torch .tensor (
1243+ [[1.0 ], [2.0 ], [3.0 ]]
1244+ ) # Samples 0,2 unchanged; sample 1 normalized
12421245 assert torch .allclose (result , expected , rtol = 1e-5 )
12431246
12441247
@@ -1248,9 +1251,12 @@ def test_normalize_advantages_with_epsilon_all_zero_std():
12481251 std = torch .tensor ([0.0 , 0.0 , 0.0 ])
12491252 epsilon = 1e-8
12501253
1254+ # Save expected values BEFORE calling function (since it modifies in-place)
1255+ expected = advantages .clone ()
1256+
12511257 result = normalize_advantages_with_epsilon (advantages , std , epsilon )
12521258
1253- expected = advantages / epsilon
1259+ # When std=0 AND advantage!=0, normalization is skipped (all unchanged)
12541260 assert torch .allclose (result , expected , rtol = 1e-5 )
12551261
12561262
@@ -1281,3 +1287,62 @@ def test_normalize_advantages_with_epsilon_negative_advantages():
12811287
12821288 expected = torch .tensor ([[- 2.0 ], [2.0 ], [- 3.0 ]])
12831289 assert torch .allclose (result , expected , rtol = 1e-5 )
1290+
1291+
def test_normalize_advantages_with_zero_std_from_leave_one_out():
    """Zero std produced by a leave-one-out baseline skips normalization for that sample."""
    # Leave-one-out over rewards [1.0, 0.0, 0.0, 0.0]:
    #   sample 0:    baseline from [0, 0, 0] -> std = 0,      advantage = 1.0
    #   samples 1-3: baseline from [1, 0, 0] -> std ≈ 0.577,  advantage ≈ -0.333
    advs = torch.tensor([[1.0], [-0.333], [-0.333], [-0.333]])
    stds = torch.tensor([0.0, 0.577, 0.577, 0.577])
    eps = 1e-6

    # Snapshot expectations up front — the function mutates its input in-place.
    want_first = advs[0].clone()
    want_rest = advs[1:].clone() / (stds[1:].unsqueeze(-1) + eps)

    out = normalize_advantages_with_epsilon(advs, stds, eps)

    # std == 0 -> advantage passes through untouched (normalization skipped).
    assert torch.allclose(out[0], want_first, rtol=1e-5)
    # std > 0 -> standard epsilon-stabilized normalization.
    assert torch.allclose(out[1:], want_rest, rtol=1e-5)
1312+
1313+
def test_normalize_advantages_with_zero_std_and_zero_advantage():
    """A zero advantage paired with zero std is left unchanged, like any zero-std sample."""
    advs = torch.tensor([[0.0], [1.0], [0.0]])
    stds = torch.tensor([0.0, 0.0, 1.0])
    eps = 1e-6

    # Capture per-row expectations before the call; the function normalizes in-place.
    # Rows 0 and 1 have std == 0 and are skipped; row 2 is normalized with epsilon.
    want = [
        advs[0].clone(),
        advs[1].clone(),
        advs[2].clone() / (stds[2] + eps),
    ]

    out = normalize_advantages_with_epsilon(advs, stds, eps)

    for row, expected in enumerate(want):
        assert torch.allclose(out[row], expected, rtol=1e-5)
1335+
1336+
def test_normalize_advantages_with_small_nonzero_std():
    """Tiny but strictly positive std values are still normalized (no zero-threshold)."""
    advs = torch.tensor([[2.0], [3.0], [-1.0]])
    stds = torch.tensor([0.001, 0.01, 0.0001])  # all small, all > 0

    # Snapshot before the call since the function mutates in-place.
    # NOTE(review): 1e-6 mirrors the callee's default epsilon — confirm if that default changes.
    want = advs.clone() / (stds.unsqueeze(-1) + 1e-6)

    out = normalize_advantages_with_epsilon(advs, stds)

    # Every row has std > 0, so every row gets normalized.
    assert torch.allclose(out, want, rtol=1e-5)
0 commit comments