From eb740911306e3892dfd264170483c62fece8af3f Mon Sep 17 00:00:00 2001
From: Jonathan Mitchell
Date: Fri, 14 Nov 2025 12:28:00 -0800
Subject: [PATCH 1/2] Enables specified cp rank slicing

Signed-off-by: Jonathan Mitchell
---
 .../dot_product_attention/context_parallel.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 00d609ab9e..6ee68c9825 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -4,7 +4,7 @@
 """Context Parallelism."""
 
 import os
-from typing import List, Union, Tuple
+from typing import List, Union, Optional, Tuple
 
 import torch
 import transformer_engine_torch as tex
@@ -4018,6 +4018,7 @@ def get_batch_on_this_cp_rank(
     position_ids_padded: torch.Tensor,
     cp_group: torch.distributed.ProcessGroup = None,
     qvk_format: str = "thd",
+    cp_rank: Optional[int] = None,
 ):
     """Slice batch input along sequence dimension into multiple chunks for THD format.
 
@@ -4026,6 +4027,8 @@
 
     Which are parallelized across GPUs in a context parallel group.
     This version works with variable-length sequences using cumulative sequence lengths.
+
+    If cp_rank is provided, it will slice the batch for the provided rank.
     """
     if qvk_format not in ["thd", "bshd", "sbhd"]:
         raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
@@ -4033,7 +4036,10 @@
         # Get context parallel size and rank
         cp_size = torch.distributed.get_world_size(group=cp_group)
         if cp_size > 1:
-            cp_rank = torch.distributed.get_rank(group=cp_group)
+            if cp_rank is None:
+                cp_rank = torch.distributed.get_rank(group=cp_group)
+            elif not (0 <= cp_rank < cp_size):
+                raise ValueError(f"cp_rank must be in [0, {cp_size}), but received {cp_rank}.")
 
         # Calculate the chunk sizes for each sequence
         total_slices_of_any_sequence = 2 * cp_size

From efd8da541de5da79390b33a958b993e0ff72c08d Mon Sep 17 00:00:00 2001
From: Jonathan Mitchell
Date: Mon, 17 Nov 2025 11:00:43 -0800
Subject: [PATCH 2/2] uses cp size

Signed-off-by: Jonathan Mitchell
---
 tests/pytorch/attention/test_cp_utils.py      | 14 +++++++-------
 .../dot_product_attention/context_parallel.py | 11 +++++------
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py
index 0dd5ba601e..5396ad2dba 100644
--- a/tests/pytorch/attention/test_cp_utils.py
+++ b/tests/pytorch/attention/test_cp_utils.py
@@ -519,7 +519,7 @@ def test_cp_rank_slicing_simple_case(self):
         # Test rank 0
         self._mock_distributed_env(cp_size=2, cp_rank=0)
         input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
-            cu_seqlens, input_ids, labels, position_ids
+            cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=0
         )
 
         # Rank 0 should get indices [0,1] and [6,7]
@@ -534,7 +534,7 @@
         # Test rank 1
         self._mock_distributed_env(cp_size=2, cp_rank=1)
         input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
-            cu_seqlens, input_ids, labels, position_ids
+            cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=1
         )
 
         # Rank 1 should get indices [2,3] and [4,5]
@@ -561,7 +561,7 @@
         # Test rank 0
         self._mock_distributed_env(cp_size=2, cp_rank=0)
         input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
-            cu_seqlens, input_ids, labels, position_ids
+            cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=0
         )
 
         # For each sequence, rank 0 gets first and last slices
@@ -584,7 +584,7 @@
 
         self._mock_distributed_env(cp_size=1, cp_rank=0)
         input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank(
-            cu_seqlens, input_ids, labels, position_ids
+            cu_seqlens, input_ids, labels, position_ids, cp_size=1, cp_rank=0
         )
 
         # With CP size = 1, should return original tensors
@@ -608,7 +608,7 @@
 
         self._mock_distributed_env(cp_size=2, cp_rank=0)
         input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
-            cu_seqlens, input_ids, labels, position_ids
+            cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=0
         )
 
         # Should get indices [0,1] and [6,7] along dimension 0
@@ -635,7 +635,7 @@
         # Test rank 0
         self._mock_distributed_env(cp_size=2, cp_rank=0)
         input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
-            cu_seqlens, input_ids, labels, position_ids
+            cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=0
         )
 
         # Rank 0 should get indices [0,1] and [6,7]
@@ -650,7 +650,7 @@
         # Test rank 1
         self._mock_distributed_env(cp_size=2, cp_rank=1)
         input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
-            cu_seqlens, input_ids, labels, position_ids
+            cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=1
         )
 
         # Rank 1 should get indices [2,3] and [4,5]
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 6ee68c9825..32c9495329 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -4016,9 +4016,9 @@ def get_batch_on_this_cp_rank(
     input_ids_padded: torch.Tensor,
     labels_padded: torch.Tensor,
     position_ids_padded: torch.Tensor,
-    cp_group: torch.distributed.ProcessGroup = None,
-    qvk_format: str = "thd",
+    cp_size: Optional[int] = None,
     cp_rank: Optional[int] = None,
+    qvk_format: str = "thd",
 ):
     """Slice batch input along sequence dimension into multiple chunks for THD format.
 
@@ -4034,12 +4034,11 @@
         raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
     if qvk_format == "thd":
         # Get context parallel size and rank
-        cp_size = torch.distributed.get_world_size(group=cp_group)
         if cp_size > 1:
             if cp_rank is None:
-                cp_rank = torch.distributed.get_rank(group=cp_group)
-            elif not (0 <= cp_rank < cp_size):
-                raise ValueError(f"cp_rank must be in [0, {cp_size}), but received {cp_rank}.")
+                raise ValueError("cp_rank must be provided when cp_size > 1.")
+            if not (0 <= cp_rank < cp_size):
+                raise ValueError(f"cp_rank must be in [0, {cp_size}), but received {cp_rank}.")
 
         # Calculate the chunk sizes for each sequence
         total_slices_of_any_sequence = 2 * cp_size
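
Usage sketch for the revised signature. This is illustrative only: the import path, tensor
shapes, and dtypes below are assumptions rather than values taken from the test suite; the
expected indices follow the single-sequence case exercised in test_cp_rank_slicing_simple_case.

import torch

# Assumed import path; callers may re-export this helper from elsewhere.
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
    get_batch_on_this_cp_rank,
)

# One padded sequence of 8 tokens in THD layout (shapes and dtypes are illustrative).
cu_seqlens_padded = torch.tensor([0, 8], dtype=torch.int32)
input_ids_padded = torch.arange(8)
labels_padded = torch.arange(8)
position_ids_padded = torch.arange(8)

# cp_size and cp_rank are now passed explicitly instead of being derived from cp_group,
# so no torch.distributed process group is needed to slice a batch for a given rank.
ids_r0, labels_r0, pos_r0 = get_batch_on_this_cp_rank(
    cu_seqlens_padded,
    input_ids_padded,
    labels_padded,
    position_ids_padded,
    cp_size=2,
    cp_rank=0,
)
# With 2 * cp_size = 4 slices of 2 tokens each, rank 0 keeps the first and last slice,
# i.e. token indices [0, 1] and [6, 7]; rank 1 would keep [2, 3] and [4, 5].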