Skip to content

Commit 775fc34

Browse files
authored
fix: Incompatible configuration between reward normalization and the leave-one-out baseline (#1519)
Signed-off-by: Felipe Vieira Frujeri <ffrujeri@nvidia.com>
1 parent 74b9b17 commit 775fc34

File tree

2 files changed

+80
-7
lines changed

2 files changed

+80
-7
lines changed

nemo_rl/algorithms/grpo.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
def normalize_advantages_with_epsilon(
    advantages: torch.Tensor,
    std: torch.Tensor,
    epsilon: float = 1e-6,
) -> torch.Tensor:
    """Normalize advantages by standard deviation, skipping samples with zero std.

    When std is exactly zero (from a leave-one-out baseline over identical
    rewards), normalization is skipped for those samples to prevent numerical
    instability. This makes normalize_rewards compatible with
    use_leave_one_out_baseline.

    This implementation is purely functional: unlike an in-place masked
    assignment, it never mutates the input tensor, so callers' tensors (and
    any autograd graph through them) are left intact.

    Args:
        advantages: Tensor of shape (batch_size, 1) containing advantage values
        std: Tensor of shape (batch_size,) containing standard deviation values
        epsilon: Small value to avoid division by very small std, defaults to 1e-6

    Returns:
        Normalized advantages tensor of same shape as input advantages
    """
    # Broadcast the per-sample mask to (batch_size, 1) so it lines up with advantages.
    non_zero_std_mask = (std > 0).unsqueeze(-1)
    # Safe everywhere because epsilon > 0; values at masked-out positions are discarded.
    normalized = advantages / (std.unsqueeze(-1) + epsilon)
    # Keep the raw advantage wherever std == 0 (skip normalization there).
    return torch.where(non_zero_std_mask, normalized, advantages)

654662

655663
def dynamic_sampling(

tests/unit/algorithms/test_grpo.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,8 +1237,11 @@ def test_normalize_advantages_with_epsilon_zero_std():
12371237

12381238
result = normalize_advantages_with_epsilon(advantages, std, epsilon)
12391239

1240-
# When std=0, result should be advantages / epsilon
1241-
expected = torch.tensor([[1.0 / epsilon], [2.0], [3.0 / epsilon]])
1240+
# When std=0, normalization is skipped (advantages unchanged)
1241+
# When std>0, normal normalization occurs
1242+
expected = torch.tensor(
1243+
[[1.0], [2.0], [3.0]]
1244+
) # Samples 0,2 unchanged; sample 1 normalized
12421245
assert torch.allclose(result, expected, rtol=1e-5)
12431246

12441247

@@ -1248,9 +1251,12 @@ def test_normalize_advantages_with_epsilon_all_zero_std():
12481251
std = torch.tensor([0.0, 0.0, 0.0])
12491252
epsilon = 1e-8
12501253

1254+
# Save expected values BEFORE calling function (since it modifies in-place)
1255+
expected = advantages.clone()
1256+
12511257
result = normalize_advantages_with_epsilon(advantages, std, epsilon)
12521258

1253-
expected = advantages / epsilon
1259+
# When std=0, normalization is skipped (all samples unchanged)
12541260
assert torch.allclose(result, expected, rtol=1e-5)
12551261

12561262

@@ -1281,3 +1287,62 @@ def test_normalize_advantages_with_epsilon_negative_advantages():
12811287

12821288
expected = torch.tensor([[-2.0], [2.0], [-3.0]])
12831289
assert torch.allclose(result, expected, rtol=1e-5)
1290+
1291+
1292+
def test_normalize_advantages_with_zero_std_from_leave_one_out():
    """Zero std from a leave-one-out baseline skips normalization for that sample."""
    # Leave-one-out over rewards [1, 0, 0, 0]: the first sample's baseline is
    # built from three identical rewards (std = 0, advantage = 1.0), while the
    # remaining samples see std ~= 0.577 and advantage ~= -0.333.
    adv = torch.tensor([[1.0], [-0.333], [-0.333], [-0.333]])
    stds = torch.tensor([0.0, 0.577, 0.577, 0.577])
    eps = 1e-6

    # Snapshot expectations up front so the test stays valid even if the
    # implementation updates its input in place.
    want_first = adv[0].clone()
    want_rest = adv[1:].clone() / (stds[1:].unsqueeze(-1) + eps)

    out = normalize_advantages_with_epsilon(adv, stds, eps)

    # std == 0 -> left untouched; std > 0 -> divided by (std + epsilon).
    assert torch.allclose(out[0], want_first, rtol=1e-5)
    assert torch.allclose(out[1:], want_rest, rtol=1e-5)
1312+
1313+
1314+
def test_normalize_advantages_with_zero_std_and_zero_advantage():
    """A zero advantage paired with zero std is left unchanged."""
    adv = torch.tensor([[0.0], [1.0], [0.0]])
    stds = torch.tensor([0.0, 0.0, 1.0])
    eps = 1e-6

    # Capture expectations before the call in case the input is updated in place.
    want_row0 = adv[0].clone()
    want_row1 = adv[1].clone()
    want_row2 = adv[2].clone() / (stds[2] + eps)

    out = normalize_advantages_with_epsilon(adv, stds, eps)

    # Rows 0 and 1 have std == 0 and are skipped regardless of advantage value.
    assert torch.allclose(out[0], want_row0, rtol=1e-5)
    assert torch.allclose(out[1], want_row1, rtol=1e-5)

    # Row 2 has std > 0 and is normalized with epsilon.
    assert torch.allclose(out[2], want_row2, rtol=1e-5)
1335+
1336+
1337+
def test_normalize_advantages_with_small_nonzero_std():
    """Small but strictly positive std values are still normalized (no threshold)."""
    adv = torch.tensor([[2.0], [3.0], [-1.0]])
    stds = torch.tensor([0.001, 0.01, 0.0001])  # tiny, but every entry > 0

    # Capture the expectation before the call in case of in-place updates.
    want = adv.clone() / (stds.unsqueeze(-1) + 1e-6)

    out = normalize_advantages_with_epsilon(adv, stds)

    # Every row has std > 0, so every row must be normalized.
    assert torch.allclose(out, want, rtol=1e-5)

0 commit comments

Comments
 (0)