nemo_rl/algorithms/loss_functions.py (26 additions, 3 deletions)
@@ -135,6 +135,7 @@ def __call__(
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ logprob_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, dict]:
"""Clipped Policy Gradient RL loss function."""
token_mask = data["token_mask"][:, 1:]
@@ -171,12 +172,16 @@ def __call__(
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
+ chunk_size=logprob_chunk_size,
)
# slice off to the correct length to remove potential CP padding
curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
curr_logprobs = get_logprobs_from_vocab_parallel_logits(
- next_token_logits, data["input_ids"], seq_index=seq_index
+ next_token_logits,
+ data["input_ids"],
+ seq_index=seq_index,
+ chunk_size=logprob_chunk_size,
)
else:
next_token_logits_wo_last = next_token_logits[
@@ -373,6 +378,7 @@ def __call__(
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ logprob_chunk_size: Optional[int] = None,
dpo_loss: bool = False,
dpo_average_log_probs: bool = False,
) -> tuple[torch.Tensor, dict[str, Any]]:
@@ -398,12 +404,16 @@ def __call__(
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
+ chunk_size=logprob_chunk_size,
)
# slice off to the correct length to remove potential CP padding
token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
- next_token_logits, data["input_ids"], seq_index=seq_index
+ next_token_logits,
+ data["input_ids"],
+ seq_index=seq_index,
+ chunk_size=logprob_chunk_size,
)
else:
next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token
@@ -518,6 +528,7 @@ def __call__(
data: BatchedDataDict[PreferenceLossDataDict],
global_valid_seqs: Tensor,
global_valid_toks: Tensor | None,
+ logprob_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
sample_mask = data["sample_mask"]

@@ -633,6 +644,7 @@ def _dpo_loss(
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ logprob_chunk_size: Optional[int] = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor
token_mask = data["token_mask"][:, 1:]
@@ -652,12 +664,16 @@ def _dpo_loss(
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
+ chunk_size=logprob_chunk_size,
)
# slice off to the correct length to remove potential CP padding
token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
- next_token_logits, data["input_ids"], seq_index=seq_index
+ next_token_logits,
+ data["input_ids"],
+ seq_index=seq_index,
+ chunk_size=logprob_chunk_size,
)
else:
next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token
@@ -691,6 +707,7 @@ def __call__( # type: ignore
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ logprob_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
sft_loss_chosen = torch.tensor(0.0)
if self.sft_loss_weight > 0:
@@ -705,6 +722,7 @@ def __call__( # type: ignore
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
+ logprob_chunk_size=logprob_chunk_size,
dpo_loss=True,
dpo_average_log_probs=self.sft_average_log_probs,
)
@@ -727,6 +745,7 @@ def __call__( # type: ignore
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
+ logprob_chunk_size=logprob_chunk_size,
)

dpo_loss = (
@@ -768,6 +787,7 @@ def __call__(
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ logprob_chunk_size: Optional[int] = None,
) -> tuple[Tensor, dict[str, Any]]:
"""Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding."""
unpadded_cu_seqlens = self.cu_seqlens_q
@@ -818,6 +838,7 @@ def __call__(
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
+ logprob_chunk_size=logprob_chunk_size,
)
loss_accum += loss
for k, v in metrics.items():
@@ -867,6 +888,7 @@ def __call__(
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ logprob_chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Compute distillation loss between teacher and student logits."""
# Basic shapes
@@ -942,6 +964,7 @@ def __call__(
)

S_local = int(logits_tensor.shape[1])
+ # TODO: hardcoded 1024 below ignores logprob_chunk_size.
chunk_size = max(1, min(S_local, 1024))
student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore
logits_tensor,
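
Note (not part of the diff): logprob_chunk_size ends up as the chunk_size argument of get_logprobs_from_vocab_parallel_logits and of the Megatron log-probability helper, letting them take the full-vocabulary log-softmax over a bounded slice at a time instead of over the whole tensor at once. The sketch below illustrates one way to do that (chunking along the sequence dimension) on a single device with plain [batch, seq, vocab] logits; it skips the tensor-/context-parallel handling the real helpers perform, and its names are illustrative rather than the repository API.

import torch
import torch.nn.functional as F
from typing import Optional

def chunked_next_token_logprobs(
    logits: torch.Tensor,       # [batch, seq_len, vocab]
    input_ids: torch.Tensor,    # [batch, seq_len]
    chunk_size: Optional[int] = None,
) -> torch.Tensor:
    """Log-probabilities of each next token, computed chunk-by-chunk over the sequence."""
    logits = logits[:, :-1, :]   # position t scores the token at position t + 1
    targets = input_ids[:, 1:]
    seq_len = logits.shape[1]
    if chunk_size is None:
        chunk_size = seq_len     # no chunking: a single full-sequence pass
    pieces = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        # The [batch, chunk, vocab] float32 log-softmax is the memory peak; chunking bounds it.
        logprobs = F.log_softmax(logits[:, start:end, :].float(), dim=-1)
        pieces.append(logprobs.gather(-1, targets[:, start:end].unsqueeze(-1)).squeeze(-1))
    return torch.cat(pieces, dim=1)  # [batch, seq_len - 1]

With chunk_size=None, the default every call site in this diff passes when the config key is unset, the whole sequence is processed in one pass, matching the previous behaviour.
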
nemo_rl/models/megatron/common.py (6 additions, 0 deletions)
@@ -377,6 +377,11 @@ def forward_step_arbitrary_loss(

loss_data = data_dict

+ if policy_cfg is not None:
+     logprob_chunk_size = policy_cfg.get("logprob_chunk_size", None)
+ else:
+     logprob_chunk_size = None

loss_fn_wrapped = partial(
loss_fn,
data=loss_data,
@@ -385,6 +390,7 @@ def forward_step_arbitrary_loss(
vocab_parallel_rank=get_tensor_model_parallel_rank(),
vocab_parallel_group=get_tensor_model_parallel_group(),
context_parallel_group=get_context_parallel_group(),
+ logprob_chunk_size=logprob_chunk_size,
)

if cp_normalize:
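
Note (not part of the diff): in the Megatron path the new setting is read once from the policy config and bound into the loss callable with functools.partial, alongside the data and parallel-group arguments. A minimal, self-contained sketch of that wiring, using a hypothetical config value and a stub loss function:

from functools import partial
from typing import Any, Optional

def loss_fn(*, data: Any, logprob_chunk_size: Optional[int] = None) -> None:
    # Stub standing in for the loss callables patched above; only the new keyword is shown.
    print(f"computing loss with logprob_chunk_size={logprob_chunk_size}")

policy_cfg = {"logprob_chunk_size": 256}  # hypothetical fragment of the policy config

# A missing key, or no policy config at all, yields None, i.e. the previous unchunked behaviour.
logprob_chunk_size = (
    policy_cfg.get("logprob_chunk_size", None) if policy_cfg is not None else None
)

loss_fn_wrapped = partial(loss_fn, data={}, logprob_chunk_size=logprob_chunk_size)
loss_fn_wrapped()  # -> computing loss with logprob_chunk_size=256

The same cfg.get("logprob_chunk_size", None) lookup appears in both DTensor policy workers below.
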
nemo_rl/models/policy/dtensor_policy_worker.py (1 addition, 0 deletions)
@@ -807,6 +807,7 @@ def train(
mb,
global_valid_seqs,
global_valid_toks,
+ logprob_chunk_size=self.cfg.get("logprob_chunk_size", None),
)
del logits

nemo_rl/models/policy/dtensor_policy_worker_v2.py (1 addition, 0 deletions)
@@ -783,6 +783,7 @@ def train(
mb,
global_valid_seqs,
global_valid_toks,
+ logprob_chunk_size=self.cfg.get("logprob_chunk_size", None),
)
del logits

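
Note (not part of the diff): the motivation for chunking is the temporary float32 [batch, sequence, vocab] buffer that the log-softmax materializes on top of the logits; logprob_chunk_size caps that buffer at [batch, chunk, vocab]. Back-of-envelope arithmetic with assumed sizes (real numbers depend on the model vocabulary, sequence length, and microbatch):

batch, seq_len, vocab = 1, 4096, 128_256  # assumed sizes, for illustration only
chunk = 256
bytes_fp32 = 4

full_pass = batch * seq_len * vocab * bytes_fp32  # ~2.0 GiB per microbatch
chunked = batch * chunk * vocab * bytes_fp32      # ~125 MiB with logprob_chunk_size=256
print(f"full: {full_pass / 2**30:.1f} GiB, chunked: {chunked / 2**20:.0f} MiB")

The distillation branch still uses its own hardcoded 1024-token chunk, as the TODO added above notes.
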