diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py
index 2a3038ddbd..a74648ddaf 100644
--- a/nemo_rl/algorithms/loss_functions.py
+++ b/nemo_rl/algorithms/loss_functions.py
@@ -135,6 +135,7 @@ def __call__(
         vocab_parallel_rank: Optional[int] = None,
         vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        logprob_chunk_size: Optional[int] = None,
     ) -> tuple[torch.Tensor, dict]:
         """Clipped Policy Gradient RL loss function."""
         token_mask = data["token_mask"][:, 1:]
@@ -171,12 +172,16 @@ def __call__(
                 tp_group=vocab_parallel_group,
                 inference_only=False,
                 cp_group=context_parallel_group,
+                chunk_size=logprob_chunk_size,
             )
             # slice off to the correct length to remove potential CP padding
             curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1]
         elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
             curr_logprobs = get_logprobs_from_vocab_parallel_logits(
-                next_token_logits, data["input_ids"], seq_index=seq_index
+                next_token_logits,
+                data["input_ids"],
+                seq_index=seq_index,
+                chunk_size=logprob_chunk_size,
             )
         else:
             next_token_logits_wo_last = next_token_logits[
@@ -373,6 +378,7 @@ def __call__(
         vocab_parallel_rank: Optional[int] = None,
         vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        logprob_chunk_size: Optional[int] = None,
         dpo_loss: bool = False,
         dpo_average_log_probs: bool = False,
     ) -> tuple[torch.Tensor, dict[str, Any]]:
@@ -398,12 +404,16 @@ def __call__(
                 tp_group=vocab_parallel_group,
                 inference_only=False,
                 cp_group=context_parallel_group,
+                chunk_size=logprob_chunk_size,
             )
             # slice off to the correct length to remove potential CP padding
             token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
         elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
             token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                next_token_logits, data["input_ids"], seq_index=seq_index
+                next_token_logits,
+                data["input_ids"],
+                seq_index=seq_index,
+                chunk_size=logprob_chunk_size,
             )
         else:
             next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
@@ -518,6 +528,7 @@ def __call__(
         data: BatchedDataDict[PreferenceLossDataDict],
         global_valid_seqs: Tensor,
         global_valid_toks: Tensor | None,
+        logprob_chunk_size: Optional[int] = None,
     ) -> tuple[torch.Tensor, dict[str, Any]]:
         sample_mask = data["sample_mask"]
@@ -633,6 +644,7 @@ def _dpo_loss(
         vocab_parallel_rank: Optional[int] = None,
         vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        logprob_chunk_size: Optional[int] = None,
     ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor
         token_mask = data["token_mask"][:, 1:]
@@ -652,12 +664,16 @@ def _dpo_loss(
                 tp_group=vocab_parallel_group,
                 inference_only=False,
                 cp_group=context_parallel_group,
+                chunk_size=logprob_chunk_size,
             )
             # slice off to the correct length to remove potential CP padding
             token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
         elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
             token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                next_token_logits, data["input_ids"], seq_index=seq_index
+                next_token_logits,
+                data["input_ids"],
+                seq_index=seq_index,
+                chunk_size=logprob_chunk_size,
             )
         else:
             next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
@@ -691,6 +707,7 @@ def __call__(  # type: ignore
         vocab_parallel_rank: Optional[int] = None,
         vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        logprob_chunk_size: Optional[int] = None,
     ) -> tuple[torch.Tensor, dict[str, Any]]:
         sft_loss_chosen = torch.tensor(0.0)
         if self.sft_loss_weight > 0:
@@ -705,6 +722,7 @@ def __call__(  # type: ignore
                 vocab_parallel_rank=vocab_parallel_rank,
                 vocab_parallel_group=vocab_parallel_group,
                 context_parallel_group=context_parallel_group,
+                logprob_chunk_size=logprob_chunk_size,
                 dpo_loss=True,
                 dpo_average_log_probs=self.sft_average_log_probs,
             )
@@ -727,6 +745,7 @@ def __call__(  # type: ignore
                 vocab_parallel_rank=vocab_parallel_rank,
                 vocab_parallel_group=vocab_parallel_group,
                 context_parallel_group=context_parallel_group,
+                logprob_chunk_size=logprob_chunk_size,
             )

         dpo_loss = (
@@ -768,6 +787,7 @@ def __call__(
         vocab_parallel_rank: Optional[int] = None,
         vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        logprob_chunk_size: Optional[int] = None,
     ) -> tuple[Tensor, dict[str, Any]]:
         """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding."""
         unpadded_cu_seqlens = self.cu_seqlens_q
@@ -818,6 +838,7 @@ def __call__(
                 vocab_parallel_rank=vocab_parallel_rank,
                 vocab_parallel_group=vocab_parallel_group,
                 context_parallel_group=context_parallel_group,
+                logprob_chunk_size=logprob_chunk_size,
             )
             loss_accum += loss
             for k, v in metrics.items():
@@ -867,6 +888,7 @@ def __call__(
         vocab_parallel_rank: Optional[int] = None,
         vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        logprob_chunk_size: Optional[int] = None,
     ) -> tuple[torch.Tensor, dict[str, Any]]:
         """Compute distillation loss between teacher and student logits."""
         # Basic shapes
@@ -942,6 +964,7 @@ def __call__(
         )

         S_local = int(logits_tensor.shape[1])
+        # TODO: hardcoded 1024 below ignores logprob_chunk_size.
         chunk_size = max(1, min(S_local, 1024))
         student_topk_logprobs = ChunkedDistributedGatherLogprob.apply(  # type: ignore
             logits_tensor,
diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py
index 87a0ddb1d5..14854e7c88 100644
--- a/nemo_rl/models/megatron/common.py
+++ b/nemo_rl/models/megatron/common.py
@@ -377,6 +377,11 @@ def forward_step_arbitrary_loss(
     loss_data = data_dict

+    if policy_cfg is not None:
+        logprob_chunk_size = policy_cfg.get("logprob_chunk_size", None)
+    else:
+        logprob_chunk_size = None
+
     loss_fn_wrapped = partial(
         loss_fn,
         data=loss_data,
@@ -385,6 +390,7 @@ def forward_step_arbitrary_loss(
         vocab_parallel_rank=get_tensor_model_parallel_rank(),
         vocab_parallel_group=get_tensor_model_parallel_group(),
         context_parallel_group=get_context_parallel_group(),
+        logprob_chunk_size=logprob_chunk_size,
     )

     if cp_normalize:
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 1a8ef6547c..321e61a5a1 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -807,6 +807,7 @@ def train(
                         mb,
                         global_valid_seqs,
                         global_valid_toks,
+                        logprob_chunk_size=self.cfg.get("logprob_chunk_size", None),
                     )

                     del logits
diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py
index 6d623af8a7..fefa683da6 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py
@@ -783,6 +783,7 @@ def train(
                         mb,
                         global_valid_seqs,
                         global_valid_toks,
+                        logprob_chunk_size=self.cfg.get("logprob_chunk_size", None),
                     )

                     del logits
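
Note for reviewers: `logprob_chunk_size` is read from the policy config via `cfg.get("logprob_chunk_size", None)` (so `None`, i.e. no chunking, remains the default) and bounds how much of the sequence is materialized as full-vocabulary log-probs at once. The snippet below is a minimal, self-contained sketch of that idea for a plain, non-tensor-parallel logits tensor; the function and argument names are illustrative only and do not mirror `get_logprobs_from_vocab_parallel_logits`.

```python
# Illustrative sketch only: sequence-chunked next-token log-prob extraction.
# Names here are hypothetical and not part of nemo_rl.
import torch


def chunked_next_token_logprobs(
    logits: torch.Tensor,      # [batch, seq_len, vocab]
    input_ids: torch.Tensor,   # [batch, seq_len]
    chunk_size: int | None = None,
) -> torch.Tensor:
    """Log-probs of each next token, computed over sequence chunks to cap peak memory."""
    # Position t predicts token t+1: drop the last logit and the first token.
    shifted_logits = logits[:, :-1, :]
    targets = input_ids[:, 1:]
    seq_len = shifted_logits.shape[1]
    if chunk_size is None:
        chunk_size = seq_len  # no chunking: single full-sequence pass

    pieces = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        # log_softmax materializes a [batch, chunk, vocab] float tensor; smaller
        # chunks trade a few extra kernel launches for lower peak memory.
        logprobs = torch.log_softmax(shifted_logits[:, start:end, :].float(), dim=-1)
        pieces.append(
            logprobs.gather(-1, targets[:, start:end].unsqueeze(-1)).squeeze(-1)
        )
    return torch.cat(pieces, dim=1)  # [batch, seq_len - 1]
```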