Commit

2025-01-22 nightly release (dd5457c)
pytorchbot committed Jan 22, 2025
1 parent dcc60e8 commit 61a849d
Showing 12 changed files with 147 additions and 134 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/unittest_ci.yml
@@ -64,6 +64,9 @@ jobs:
           python-tag: "py312"
           cuda-tag: "cu124"
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: ${{ matrix.os }}
       timeout: 30
3 changes: 3 additions & 0 deletions .github/workflows/unittest_ci_cpu.yml
@@ -36,6 +36,9 @@ jobs:
           python-version: '3.12'
           python-tag: "py312"
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: ${{ matrix.os }}
       timeout: 15
13 changes: 11 additions & 2 deletions torchrec/distributed/embedding.py
@@ -26,6 +26,7 @@
 )
 
 import torch
+from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -90,6 +91,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1198,8 +1200,15 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: KeyedJaggedTensor,
+        features: TypeUnion[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        need_permute: bool = True
+        if isinstance(features, TensorDict):
+            feature_keys = list(features.keys())  # pyre-ignore[6]
+            if self._features_order:
+                feature_keys = [feature_keys[i] for i in self._features_order]
+                need_permute = False
+            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
@@ -1209,7 +1218,7 @@ def input_dist(
             unpadded_features = features
             features = pad_vbe_kjt_lengths(unpadded_features)
 
-        if self._features_order:
+        if need_permute and self._features_order:
             features = features.permute(
                 self._features_order,
                 # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
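
For context on the new code path: when input_dist receives a TensorDict, it reorders the keys up front and converts them to a KeyedJaggedTensor via maybe_td_to_kjt. The standalone sketch below mimics that conversion; the nested "values"/"lengths" layout is an assumption for illustration, not something this diff confirms about torchrec/sparse/tensor_dict.py.

# Sketch of a TensorDict -> KeyedJaggedTensor conversion in the spirit of
# maybe_td_to_kjt. Assumption: each entry holds flat "values"/"lengths"
# tensors; the real helper's layout may differ.
from typing import List, Optional

import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def td_to_kjt_sketch(
    td: TensorDict, keys: Optional[List[str]] = None
) -> KeyedJaggedTensor:
    keys = keys if keys is not None else list(td.keys())
    # Concatenate per-key jagged components in key order, as a KJT expects.
    values = torch.cat([td[k]["values"] for k in keys])
    lengths = torch.cat([td[k]["lengths"] for k in keys])
    return KeyedJaggedTensor(keys=keys, values=values, lengths=lengths)


td = TensorDict(
    {
        "f1": {"values": torch.tensor([1, 2, 3]), "lengths": torch.tensor([2, 1])},
        "f2": {"values": torch.tensor([4]), "lengths": torch.tensor([0, 1])},
    },
    batch_size=[],
)
kjt = td_to_kjt_sketch(td, ["f1", "f2"])  # batch of 2 jagged bags per key
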
5 changes: 4 additions & 1 deletion torchrec/distributed/embedding_sharding.py
@@ -274,7 +274,10 @@ def bucketize_kjt_before_all2all(
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
         block_bucketize_pos=(
-            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
+            [
+                _fx_wrap_tensor_to_device_dtype(pos, kjt.values())
+                for pos in block_bucketize_row_pos
+            ]
             if block_bucketize_row_pos is not None
             else None
         ),
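
The fix above applies the device/dtype alignment to each tensor in block_bucketize_row_pos individually, matching kjt.values() rather than kjt.lengths(). A minimal sketch of the per-tensor behavior, with the FX-wrapped helper approximated here since its body is not part of this diff:

# Approximation of the per-tensor alignment implied by the new list
# comprehension; _fx_wrap_tensor_to_device_dtype itself is assumed here.
from typing import List

import torch


def to_device_dtype_of(t: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Move `t` to the device and dtype of the reference tensor.
    return t.to(device=ref.device, dtype=ref.dtype)


values = torch.tensor([10, 11, 12], dtype=torch.int64)  # stands in for kjt.values()
block_bucketize_row_pos: List[torch.Tensor] = [
    torch.tensor([0, 2, 4], dtype=torch.int32),
    torch.tensor([0, 1, 3], dtype=torch.int32),
]
# New behavior: each per-feature position tensor is converted on its own.
block_bucketize_pos = [to_device_dtype_of(pos, values) for pos in block_bucketize_row_pos]
assert all(p.dtype == torch.int64 for p in block_bucketize_pos)
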
16 changes: 12 additions & 4 deletions torchrec/distributed/embeddingbag.py
@@ -27,6 +27,7 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -94,6 +95,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -656,9 +658,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,8 +1189,16 @@ def _create_inverse_indices_permute_indices(
 
     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        if isinstance(features, TensorDict):
+            feature_keys = list(features.keys())  # pyre-ignore[6]
+            if len(self._features_order) > 0:
+                feature_keys = [feature_keys[i] for i in self._features_order]
+                self._has_features_permute = False  # feature_keys are in order
+            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
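
As in embedding.py, the TensorDict branch applies self._features_order to the key list before conversion, so the resulting KJT already arrives in the order the input dist expects and the later KJT-level permute is skipped. A tiny illustration of that reordering (names are illustrative, not torchrec APIs):

# Reordering keys up front makes the later features.permute(...) unnecessary.
feature_keys = ["f2", "f0", "f1"]   # keys as they arrive in the TensorDict
features_order = [1, 2, 0]          # precomputed order expected downstream
reordered = [feature_keys[i] for i in features_order]
assert reordered == ["f0", "f1", "f2"]
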
19 changes: 12 additions & 7 deletions torchrec/distributed/test_utils/multi_process.py
@@ -9,6 +9,7 @@
 
 #!/usr/bin/env python3
 
+import logging
 import multiprocessing
 import os
 import unittest
@@ -24,11 +25,6 @@
 )
 
 
-# AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
-# Therefore we use spawn for HIP runtime until AMD fixes the issue
-_MP_INIT_MODE = "forkserver" if torch.version.hip is None else "spawn"
-
-
 class MultiProcessContext:
     def __init__(
         self,
@@ -98,6 +94,15 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None:
 
 
 class MultiProcessTestBase(unittest.TestCase):
+    def __init__(
+        self, methodName: str = "runTest", mp_init_mode: str = "forkserver"
+    ) -> None:
+        super().__init__(methodName)
+
+        # AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
+        # Therefore we use spawn for HIP runtime until AMD fixes the issue
+        self._mp_init_mode: str = mp_init_mode if torch.version.hip is None else "spawn"
+        logging.info(f"Using {self._mp_init_mode} for multiprocessing")
 
     @seed_and_log
     def setUp(self) -> None:
@@ -131,7 +136,7 @@ def _run_multi_process_test(
         # pyre-ignore
         **kwargs,
     ) -> None:
-        ctx = multiprocessing.get_context(_MP_INIT_MODE)
+        ctx = multiprocessing.get_context(self._mp_init_mode)
         processes = []
         for rank in range(world_size):
             kwargs["rank"] = rank
@@ -157,7 +162,7 @@ def _run_multi_process_test_per_rank(
         world_size: int,
         kwargs_per_rank: List[Dict[str, Any]],
     ) -> None:
-        ctx = multiprocessing.get_context(_MP_INIT_MODE)
+        ctx = multiprocessing.get_context(self._mp_init_mode)
         processes = []
         for rank in range(world_size):
             kwargs = {}
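
With the module-level _MP_INIT_MODE constant removed, each test case now picks its multiprocessing start method through the constructor, and the HIP-specific "spawn" override still applies. A sketch of how a hypothetical subclass would opt into "spawn":

# Hypothetical subclass: forwards a non-default start method to the base class.
from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase


class SpawnModeTest(MultiProcessTestBase):
    def __init__(self, methodName: str = "runTest") -> None:
        super().__init__(methodName, mp_init_mode="spawn")
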
32 changes: 26 additions & 6 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -147,6 +147,7 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -177,9 +178,9 @@
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    for _ in range(num_inputs):
-        inputs.append(
-            (
+    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
+        for _ in range(num_inputs):
+            inputs.append(
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -188,8 +189,26 @@
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-                if generate == ModelInput.generate_variable_batch_input
-                else cast(ModelInputCallable, generate)(
+            )
+    elif generate == ModelInput.generate:
+        for _ in range(num_inputs):
+            inputs.append(
+                ModelInput.generate(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                    input_type=input_type,
+                )
+            )
+    else:
+        for _ in range(num_inputs):
+            inputs.append(
+                cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -200,7 +219,6 @@
                     long_indices=long_indices,
                 )
             )
-        )
     return (model, inputs)


@@ -297,6 +315,7 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -319,6 +338,7 @@
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
+            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
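
Distilled, the new control flow routes TensorDict tests away from the variable-batch generator and threads input_type into ModelInput.generate. A runnable toy version of that dispatch (generators stubbed out with strings; only the branching mirrors the diff):

# Toy dispatch mirroring gen_model_and_input's new branching.
from typing import Callable


def generate_variable_batch_input(**kwargs: object) -> str:
    return "variable-batch KJT input"


def generate(input_type: str = "kjt", **kwargs: object) -> str:
    return f"standard input as {input_type}"


def dispatch(generate_fn: Callable[..., str], input_type: str) -> str:
    if input_type == "kjt" and generate_fn is generate_variable_batch_input:
        return generate_fn()  # variable-batch path, KJT only
    elif generate_fn is generate:
        return generate_fn(input_type=input_type)  # honors "kjt" or "td"
    else:
        return generate_fn()  # custom generator, unchanged


assert dispatch(generate, "td") == "standard input as td"
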
41 changes: 41 additions & 0 deletions torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -376,3 +376,44 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
+
+
+@skip_if_asan_class
+class TDSequenceModelParallelTest(SequenceModelParallelTest):
+
+    def test_sharding_variable_batch(self) -> None:
+        pass
+
+    def _test_sharding(
+        self,
+        sharders: List[TestEmbeddingCollectionSharder],
+        backend: str = "gloo",
+        world_size: int = 2,
+        local_size: Optional[int] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
+        qcomms_config: Optional[QCommsConfig] = None,
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ] = None,
+        variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=sharding_single_rank_test,
+            world_size=world_size,
+            local_size=local_size,
+            model_class=model_class,
+            tables=self.tables,
+            embedding_groups=self.embedding_groups,
+            sharders=sharders,
+            optim=EmbOptimType.EXACT_SGD,
+            backend=backend,
+            constraints=constraints,
+            qcomms_config=qcomms_config,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=True,
+            input_type="td",
+        )
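
The subclass above reruns the whole SequenceModelParallelTest matrix with TensorDict inputs, disabling only the variable-batch case. A minimal standalone illustration of the same inherit-and-override reuse pattern (toy classes, not the torchrec test harness):

import unittest


class BaseSuite(unittest.TestCase):
    input_type: str = "kjt"

    def test_roundtrip(self) -> None:
        # Shared body; behavior is parameterized by the class attribute.
        self.assertIn(self.input_type, ("kjt", "td"))


class TDSuite(BaseSuite):
    input_type = "td"  # every inherited test now runs with TD inputs

    def test_unsupported_case(self) -> None:
        pass  # no-op override, mirroring test_sharding_variable_batch above
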