Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions tests/pytorch/attention/test_cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def test_cp_rank_slicing_simple_case(self):
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=0
)

# Rank 0 should get indices [0,1] and [6,7]
Expand All @@ -534,7 +534,7 @@ def test_cp_rank_slicing_simple_case(self):
# Test rank 1
self._mock_distributed_env(cp_size=2, cp_rank=1)
input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=1
)

# Rank 1 should get indices [2,3] and [4,5]
Expand All @@ -561,7 +561,7 @@ def test_cp_rank_slicing_multiple_sequences(self):
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=0
)

# For each sequence, rank 0 gets first and last slices
Expand All @@ -584,7 +584,7 @@ def test_cp_rank_slicing_with_cp_size_1(self):

self._mock_distributed_env(cp_size=1, cp_rank=0)
input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
cu_seqlens, input_ids, labels, position_ids, cp_size=1, cp_rank=0
)

# With CP size = 1, should return original tensors
Expand All @@ -608,7 +608,7 @@ def test_cp_rank_slicing_sequence_dim_detection(self):

self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=0
)

# Should get indices [0,1] and [6,7] along dimension 0
Expand All @@ -635,7 +635,7 @@ def test_cp_rank_slicing_mixed_dimensions(self):
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=0
)

# Rank 0 should get indices [0,1] and [6,7]
Expand All @@ -650,7 +650,7 @@ def test_cp_rank_slicing_mixed_dimensions(self):
# Test rank 1
self._mock_distributed_env(cp_size=2, cp_rank=1)
input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
cu_seqlens, input_ids, labels, position_ids, cp_size=2, cp_rank=1
)

# Rank 1 should get indices [2,3] and [4,5]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

"""Context Parallelism."""
import os
from typing import List, Union, Tuple
from typing import List, Union, Optional, Tuple
import torch
import transformer_engine_torch as tex

Expand Down Expand Up @@ -4016,7 +4016,8 @@ def get_batch_on_this_cp_rank(
input_ids_padded: torch.Tensor,
labels_padded: torch.Tensor,
position_ids_padded: torch.Tensor,
cp_group: torch.distributed.ProcessGroup = None,
cp_size: Optional[int] = None,
cp_rank: Optional[int] = None,
qvk_format: str = "thd",
Comment on lines +4019 to 4021
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: breaking change - removed backward compatibility with cp_group parameter

Old signature had cp_group: torch.distributed.ProcessGroup = None and would call torch.distributed.get_world_size(group=cp_group) and torch.distributed.get_rank(group=cp_group) as fallback. New code removes this entirely, breaking existing callers. Either restore fallback logic or update PR description to mark as breaking change.

):
"""Slice batch input along sequence dimension into multiple chunks for THD format.
Expand All @@ -4026,14 +4027,18 @@ def get_batch_on_this_cp_rank(

Which are parallelized across GPUs in a context parallel group.
This version works with variable-length sequences using cumulative sequence lengths.

If cp_rank is provided, it will slice the batch for the provided rank.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you put this next to cp_group in the input list? Also, I think the better docstring would be

cp_rank: The rank for which the batch should be created. If None (default), will determine this process's current rank from torch.distributed.get_rank(group=cp_group)

but maybe the better option would be to have a cp_worldsize argument that's similarly an alternative option to cp_group? Currently you'd need to create a torch.distributed process group, which you may not want to do if you're preparing these ahead-of-time on a CPU

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put it next to cp_group, sure.

The function already takes in a cp_group, so callers have a distributed setup already; this is just a small addition that lets them provide the cp_rank.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's for context parallelism, I expect the user to already have their distributed process groups set up, right?

"""
if qvk_format not in ["thd", "bshd", "sbhd"]:
raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
if qvk_format == "thd":
# Get context parallel size and rank
cp_size = torch.distributed.get_world_size(group=cp_group)
if cp_size > 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: will fail with TypeError if cp_size is None

Suggested change
if cp_size > 1:
if cp_size is not None and cp_size > 1:

cp_rank = torch.distributed.get_rank(group=cp_group)
if not (0 <= cp_rank < cp_size):
raise ValueError(f"cp_rank must be in [0, {cp_size}), but received {cp_rank}.")
if cp_rank is None:
raise ValueError("cp_rank must be provided when cp_size > 1.")
Comment on lines +4038 to +4041
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: validation checks are in wrong order - will fail when cp_rank is None

Line 4038 checks range before line 4040 checks for None. If cp_rank=None, line 4038 will raise TypeError before reaching the None check.

Suggested change
if not (0 <= cp_rank < cp_size):
raise ValueError(f"cp_rank must be in [0, {cp_size}), but received {cp_rank}.")
if cp_rank is None:
raise ValueError("cp_rank must be provided when cp_size > 1.")
if cp_rank is None:
raise ValueError("cp_rank must be provided when cp_size > 1.")
if not (0 <= cp_rank < cp_size):
raise ValueError(f"cp_rank must be in [0, {cp_size}), but received {cp_rank}.")


# Calculate the chunk sizes for each sequence
total_slices_of_any_sequence = 2 * cp_size
Expand Down
Loading