From 7a26d0ac3fe7abc2d359ec88eac3b27f5c9c42b4 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Wed, 10 Dec 2025 20:56:10 -0800
Subject: [PATCH] use module object id to cache the sharded modules (#3591)

Summary:
# context
* TorchRec relies on `nn.Module.named_children()` to traverse the model and find the sparse modules to shard.
* In the normal case, every sparse module appears only once in the model hierarchy, i.e., it has **only one** parent module.
* However, in some corner cases a sparse module might have multiple parent modules. This can confuse the TorchRec sharder because of its traversal logic: the same sparse module has multiple FQNs and is therefore sharded multiple times (one sharded module is created per FQN).

{F1983896519}

# solution
* cache the sharded module keyed by the original sparse module's object id
* when a sparse module has multiple FQNs, a sharded module is created only on the first visit during the `named_children` traversal; later visits reuse the cached sharded module

# changes
* the change is protected by a KillSwitch: [enable_module_id_cache_for_dmp_shard_modules](https://www.internalfb.com/intern/justknobs/?name=pytorch%2Ftorchrec#enable_module_id_cache_for_dmp_shard_modules)
https://fb.workplace.com/groups/429376538334034/permalink/1343336826937996/

Reviewed By: malaybag, iamzainhuda

Differential Revision: D88218200
---
 torchrec/distributed/model_parallel.py       |  56 ++++-
 .../test_model_parallel_nccl_single_rank.py  | 205 +++++++++++++++++-
 2 files changed, 255 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 8fdcb94f1..d9b0e671f 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -385,7 +385,13 @@ def copy(
         return copy_dmp
 
     def _init_dmp(self, module: nn.Module) -> nn.Module:
-        return self._shard_modules_impl(module)
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:enable_module_id_cache_for_dmp_shard_modules"
+        ):
+            module_id_cache: Optional[Dict[int, ShardedModule]] = {}
+        else:
+            module_id_cache = None
+        return self._shard_modules_impl(module, module_id_cache=module_id_cache)
 
     def _init_delta_tracker(
         self, delta_tracker_config: DeltaTrackerConfig, module: nn.Module
@@ -435,28 +441,48 @@ def _shard_modules_impl(
         self,
         module: nn.Module,
         path: str = "",
+        module_id_cache: Optional[Dict[int, ShardedModule]] = None,
     ) -> nn.Module:
         # pre-sharded module
         if isinstance(module, ShardedModule):
             return module
 
+        if module_id_cache is not None:
+            module_id = id(module)
+            if module_id in module_id_cache:
+                """
+                This is likely due to a single sparse module being used in multiple places in the model,
+                which results in multiple FQNs for the same sparse module. The dedup logic is applied on
+                the sharded module, i.e., multiple FQNs will refer to the same sharded module, just as
+                they do for the eager-mode sparse module. However, there could still be issues in other
+                places where the model is traversed via `named_children()`, since the same sparse module
+                will still be visited multiple times.
+                """
+                logger.error(
+                    f"Module {path} is already in the cache (already replaced by a sharded module)"
+                )
+                return module_id_cache[module_id]
+
         # shardable module
         module_sharding_plan = self._plan.get_plan_for_module(path)
         if module_sharding_plan:
             sharder_key = type(module)
-            module = self._sharder_map[sharder_key].shard(
+            sharded_module = self._sharder_map[sharder_key].shard(
                 module,
                 module_sharding_plan,
                 self._env,
                 self.device,
                 path,
             )
-            return module
+            if module_id_cache is not None:
+                module_id_cache[module_id] = sharded_module
+            return sharded_module
 
         for name, child in module.named_children():
             child = self._shard_modules_impl(
                 child,
                 path + "." + name if path else name,
+                module_id_cache,
             )
             setattr(module, name, child)
 
@@ -1002,12 +1028,29 @@ def _shard_modules_impl(
         self,
         module: nn.Module,
         path: str = "",
+        module_id_cache: Optional[Dict[int, ShardedModule]] = None,
     ) -> nn.Module:
 
         # pre-sharded module
         if isinstance(module, ShardedModule):
             return module
 
+        if module_id_cache is not None:
+            module_id = id(module)
+            if module_id in module_id_cache:
+                """
+                This is likely due to a single sparse module being used in multiple places in the model,
+                which results in multiple FQNs for the same sparse module. The dedup logic is applied on
+                the sharded module, i.e., multiple FQNs will refer to the same sharded module, just as
+                they do for the eager-mode sparse module. However, there could still be issues in other
+                places where the model is traversed via `named_children()`, since the same sparse module
+                will still be visited multiple times.
+                """
+                logger.error(
+                    f"Module {path} is already in the cache (already replaced by a sharded module)"
+                )
+                return module_id_cache[module_id]
+
         # shardable module
         module_sharding_plan = self._plan.get_plan_for_module(path)
         if module_sharding_plan:
@@ -1027,19 +1070,22 @@ def _shard_modules_impl(
                     )
                     break
 
-            module = self._sharder_map[sharder_key].shard(
+            sharded_module = self._sharder_map[sharder_key].shard(
                 module,
                 module_sharding_plan,
                 env,
                 self.device,
                 path,
             )
-            return module
+            if module_id_cache is not None:
+                module_id_cache[module_id] = sharded_module
+            return sharded_module
 
         for name, child in module.named_children():
             child = self._shard_modules_impl(
                 child,
                 path + "." + name if path else name,
+                module_id_cache,
             )
             setattr(module, name, child)
 
diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py b/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py
index 0ea359f89..1c607e487 100644
--- a/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py
+++ b/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py
@@ -7,15 +7,218 @@
 
 # pyre-strict
 
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.test_utils.test_model_parallel_base import (
     ModelParallelSparseOnlyBase,
     ModelParallelStateDictBase,
 )
+from torchrec.distributed.types import ShardedModule
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
 class ModelParallelStateDictTestNccl(ModelParallelStateDictBase):
     pass
 
 
+class SparseArch(nn.Module):
+    def __init__(
+        self,
+        ebc: EmbeddingBagCollection,
+        ec: EmbeddingCollection,
+    ) -> None:
+        super().__init__()
+        self.ebc = ebc
+        self.ec = ec
+
+    def forward(self, features: KeyedJaggedTensor) -> tuple[torch.Tensor, torch.Tensor]:
+        ebc_out = self.ebc(features)
+        ec_out = self.ec(features)
+        return ebc_out.values(), ec_out.values()
+
+
+# Create a model with two sparse architectures sharing the same modules
+class TwoSparseArchModel(nn.Module):
+    def __init__(
+        self,
+        sparse1: SparseArch,
+        sparse2: SparseArch,
+    ) -> None:
+        super().__init__()
+        # Both architectures share the same EBC and EC instances
+        self.sparse1 = sparse1
+        self.sparse2 = sparse2
+
+    def forward(
+        self, features: KeyedJaggedTensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        ebc1_out, ec1_out = self.sparse1(features)
+        ebc2_out, ec2_out = self.sparse2(features)
+
+        return ebc1_out, ec1_out, ebc2_out, ec2_out
+
+
 class ModelParallelSparseOnlyTestNccl(ModelParallelSparseOnlyBase):
-    pass
+    def test_shared_sparse_module_in_multiple_parents(self) -> None:
+        """
+        Test that the module ID cache correctly handles the same sparse module
+        being used in multiple parent modules. This tests the caching behavior
+        when a single EmbeddingBagCollection and EmbeddingCollection are shared
+        across two different parent sparse architectures.
+        """
+
+        # Setup: Create shared embedding modules that will be reused
+        ebc = EmbeddingBagCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingBagConfig(
+                    name="ebc_table",
+                    embedding_dim=64,
+                    num_embeddings=100,
+                    feature_names=["ebc_feature"],
+                ),
+            ],
+        )
+        ec = EmbeddingCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingConfig(
+                    name="ec_table",
+                    embedding_dim=32,
+                    num_embeddings=50,
+                    feature_names=["ec_feature"],
+                ),
+            ],
+        )
+
+        # Create the model with shared modules
+        sparse1 = SparseArch(ebc, ec)
+        sparse2 = SparseArch(ebc, ec)
+        model = TwoSparseArchModel(sparse1, sparse2)
+
+        # Execute: Shard the model with DistributedModelParallel
+        dmp = DistributedModelParallel(model, device=self.device)
+
+        # Assert: Verify that the shared modules are properly handled
+        self.assertIsNotNone(dmp.module)
+
+        # Verify that the same module instances are reused (cached behavior)
+        wrapped_module = dmp.module
+        self.assertIs(
+            wrapped_module.sparse1.ebc,
+            wrapped_module.sparse2.ebc,
+            "ebc1 and ebc2 should be the same sharded instance",
+        )
+        self.assertIs(
+            wrapped_module.sparse1.ec,
+            wrapped_module.sparse2.ec,
+            "ec1 and ec2 should be the same sharded instance",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ebc,
+            ShardedModule,
+            "ebc1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ec,
+            ShardedModule,
+            "ec1 should be sharded",
+        )
+
+    def test_shared_sparse_module_in_multiple_parents_negative(self) -> None:
+        """
+        Test that when module ID caching is disabled (module_id_cache=None),
+        the same module instance gets sharded multiple times, resulting in
+        different sharded instances. This validates the behavior without caching.
+        """
+
+        def mock_init_dmp(
+            self_dmp: DistributedModelParallel, module: nn.Module
+        ) -> nn.Module:
+            """Override _init_dmp to always set module_id_cache to None"""
+            # Call _shard_modules_impl with module_id_cache=None (caching disabled)
+            return self_dmp._shard_modules_impl(module, module_id_cache=None)
+
+        # Setup: Create shared embedding modules that will be reused
+        ebc = EmbeddingBagCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingBagConfig(
+                    name="ebc_table",
+                    embedding_dim=64,
+                    num_embeddings=100,
+                    feature_names=["ebc_feature"],
+                ),
+            ],
+        )
+        ec = EmbeddingCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingConfig(
+                    name="ec_table",
+                    embedding_dim=32,
+                    num_embeddings=50,
+                    feature_names=["ec_feature"],
+                ),
+            ],
+        )
+
+        # Create the model with shared modules
+        sparse1 = SparseArch(ebc, ec)
+        sparse2 = SparseArch(ebc, ec)
+        model = TwoSparseArchModel(sparse1, sparse2)
+
+        # Execute: Mock _init_dmp to disable caching, then shard the model
+        with patch.object(
+            DistributedModelParallel,
+            "_init_dmp",
+            mock_init_dmp,
+        ):
+            dmp = DistributedModelParallel(model, device=self.device)
+
+        # Assert: Verify that modules are NOT cached (different instances)
+        self.assertIsNotNone(dmp.module)
+        wrapped_module = dmp.module
+
+        # Without caching, the same module should be sharded twice,
+        # resulting in different sharded instances
+        self.assertIsNot(
+            wrapped_module.sparse1.ebc,
+            wrapped_module.sparse2.ebc,
+            "Without caching, ebc1 and ebc2 should be different sharded instances",
+        )
+        self.assertIsNot(
+            wrapped_module.sparse1.ec,
+            wrapped_module.sparse2.ec,
+            "Without caching, ec1 and ec2 should be different sharded instances",
+        )
+
+        # Both should still be properly sharded, just not cached
+        self.assertIsInstance(
+            wrapped_module.sparse1.ebc,
+            ShardedModule,
+            "ebc1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ec,
+            ShardedModule,
+            "ec1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse2.ebc,
+            ShardedModule,
+            "ebc2 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse2.ec,
+            ShardedModule,
+            "ec2 should be sharded",
+        )
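
For illustration, here is a minimal, self-contained sketch of the dedup-by-object-id idea that `_shard_modules_impl` applies above. It is not TorchRec's implementation: `Wrapped` is a hypothetical stand-in for a sharded module, `wrap_recursively` stands in for the `named_children()` traversal, and "shardable" is faked as `isinstance(child, nn.Embedding)`.

```python
import torch.nn as nn


class Wrapped(nn.Module):
    """Hypothetical stand-in for a ShardedModule produced by a sharder."""

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner


def wrap_recursively(module: nn.Module, cache: dict) -> nn.Module:
    # `cache` maps id(original child) -> wrapper, mirroring the role of
    # `module_id_cache` in this patch: wrap on the first FQN, reuse afterwards.
    for name, child in module.named_children():
        if isinstance(child, nn.Embedding):  # pretend Embedding is "shardable"
            key = id(child)
            if key not in cache:
                cache[key] = Wrapped(child)  # first visit: wrap once and cache
            setattr(module, name, cache[key])  # later visits: reuse the wrapper
        else:
            wrap_recursively(child, cache)
    return module


# One embedding instance reachable under two FQNs (root.a.emb and root.b.emb).
shared = nn.Embedding(10, 4)
parent_a, parent_b = nn.Module(), nn.Module()
parent_a.emb = shared
parent_b.emb = shared
root = nn.Module()
root.a, root.b = parent_a, parent_b

wrap_recursively(root, cache={})
assert root.a.emb is root.b.emb  # both FQNs now point at the same wrapper
assert isinstance(root.a.emb, Wrapped)
```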
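A note on the design choice, under the assumptions stated above: keying the cache by `id()` is reasonable here because the cache is created inside a single `_init_dmp` call and only lives for that one sharding pass, so every original sparse module it references is still alive and its id cannot be recycled; a cache that outlived the original modules would need weak references or another keying scheme. Without the cache (the negative test above), the shared table ends up as two independent sharded instances, which is exactly the duplication the patch guards against.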