From 43315c02dbfcec1220eaa37201ae6cec664d8b5b Mon Sep 17 00:00:00 2001 From: Felicity Liao <11263993+aporialiao@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:36:45 -0700 Subject: [PATCH] Manual Resharding Handler (#3398) Summary: Pull Request resolved: https://github.com/meta-pytorch/torchrec/pull/3398 X-link: https://github.com/pytorch/torchrec/pull/3398 * `DistributedModelParallel`: stores sharders passed into planner, this will be reused at resharding * Removed hardcoding for module FQNs and plan keys since not compatible with all models. Leveraging how planner/enumerator identifies sharders/EBCs to find the right plan Reviewed By: isururanawaka Differential Revision: D83392188 --- .../benchmark/benchmark_resharding_handler.py | 169 ------------------ torchrec/distributed/embeddingbag.py | 4 +- torchrec/distributed/model_parallel.py | 80 ++++++--- .../distributed/sharding/dynamic_sharding.py | 31 +++- .../test_utils/test_model_parallel.py | 2 + .../distributed/test_utils/test_sharding.py | 12 +- .../tests/test_dynamic_sharding.py | 98 +++++++++- 7 files changed, 191 insertions(+), 205 deletions(-) delete mode 100644 torchrec/distributed/benchmark/benchmark_resharding_handler.py diff --git a/torchrec/distributed/benchmark/benchmark_resharding_handler.py b/torchrec/distributed/benchmark/benchmark_resharding_handler.py deleted file mode 100644 index 7effc98a2..000000000 --- a/torchrec/distributed/benchmark/benchmark_resharding_handler.py +++ /dev/null @@ -1,169 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import logging -import random -from typing import List, Optional - -import torch -import torch.distributed as dist -import torch.nn as nn -from torchrec.distributed.embeddingbag import EmbeddingBagCollection - -from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta - -from torchrec.distributed.sharding_plan import ( - column_wise, - construct_module_sharding_plan, - table_wise, -) - -from torchrec.distributed.test_utils.test_sharding import generate_rank_placements -from torchrec.distributed.types import EmbeddingModuleShardingPlan - -logger: logging.Logger = logging.getLogger(__name__) - - -class ReshardingHandler: - """ - Handles the resharding of a training module by generating and applying different sharding plans. - """ - - def __init__(self, train_module: nn.Module, num_plans: int) -> None: - """ - Initializes the ReshardingHandler with a training module and the number of sharding plans to generate. - - Args: - train_module (nn.Module): The training module to be resharded. - num_plans (int): The number of sharding plans to generate. 
- """ - self._train_module = train_module - if not hasattr(train_module, "_module"): - raise RuntimeError("Incorrect train module") - - if not hasattr(train_module._module, "plan"): - raise RuntimeError("sharding plan cannot be found") - - # Pyre-ignore - plan = train_module._module.plan.plan - key = "main_module.sparse_arch.embedding_bag_collection" - module = ( - # Pyre-ignore - train_module._module.module.main_module.sparse_arch.embedding_bag_collection - ) - self._resharding_plans: List[EmbeddingModuleShardingPlan] = [] - world_size = dist.get_world_size() - - # TODO: change this logic when, proper planner is integrated - if key in plan: - ebc = plan[key] - num_tables = len(ebc) - ranks_per_tables = [1 for _ in range(num_tables)] - ranks_per_tables_for_CW = [] - for index, table_config in enumerate(module._embedding_bag_configs): - # CW sharding - valid_candidates = [ - i - for i in range(1, world_size + 1) - if table_config.embedding_dim % i == 0 - ] - rng = random.Random(index) - ranks_per_tables_for_CW.append(rng.choice(valid_candidates)) - - lightweight_ebc = EmbeddingBagCollection( - tables=module._embedding_bag_configs, - device=torch.device( - "meta" - ), # Use meta device to avoid actual memory allocation - ) - meta_device = torch.device("meta") - for i in range(num_plans): - new_ranks = generate_rank_placements( - world_size, num_tables, ranks_per_tables, i - ) - new_ranks_cw = generate_rank_placements( - world_size, num_tables, ranks_per_tables_for_CW, i - ) - new_per_param_sharding = {} - for i, (talbe_id, param) in enumerate(ebc.items()): - if param.sharding_type == "column_wise": - cw_gen = column_wise( - ranks=new_ranks_cw[i], - compute_kernel=param.compute_kernel, - ) - new_per_param_sharding[talbe_id] = cw_gen - else: - tw_gen = table_wise( - rank=new_ranks[i][0], - compute_kernel=param.compute_kernel, - ) - new_per_param_sharding[talbe_id] = tw_gen - - new_plan = construct_module_sharding_plan( - lightweight_ebc, - per_param_sharding=new_per_param_sharding, - world_size=world_size, - # Pyre-ignore - device_type=meta_device, - ) - self._resharding_plans.append(new_plan) - else: - raise RuntimeError(f"Plan does not have key: {key}") - - def step(self, batch_no: int) -> float: - """ - Executes a step in the training process by selecting and applying a sharding plan. - - Args: - batch_no (int): The current batch number. - - Returns: - float: The data volume of the sharding plan delta. 
- """ - # Pyre-ignore - plan = self._train_module._module.plan.plan - key = "main_module.sparse_arch.embedding_bag_collection" - - # Use the current step as a seed to ensure all ranks get the same random number - # but it changes on each call - - rng = random.Random(batch_no) - index = rng.randint(0, len(self._resharding_plans) - 1) - logger.info(f"Selected resharding plan index {index} for step {batch_no}") - # Get the selected plan - selected_plan = self._resharding_plans[index] - - # Fix the device mismatch by updating the placement device in the sharding spec - # This is necessary because the plan was created with meta device but needs to be applied on CUDA - for _, param_sharding in selected_plan.items(): - sharding_spec = param_sharding.sharding_spec - if sharding_spec is not None: - # pyre-ignore - for shard in sharding_spec.shards: - placement = shard.placement - rank: Optional[int] = placement.rank() - assert rank is not None - current_device = ( - torch.cuda.current_device() - if rank == torch.distributed.get_rank() - else rank % torch.cuda.device_count() - ) - shard.placement = torch.distributed._remote_device( - f"rank:{rank}/cuda:{current_device}" - ) - - data_volume, delta_plan = output_sharding_plan_delta( - plan[key], selected_plan, True - ) - # Pyre-ignore - self._train_module.module.reshard( - sharded_module_fqn="main_module.sparse_arch.embedding_bag_collection", - changed_shard_to_params=delta_plan, - ) - return data_volume diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index debc78dc6..7c822051d 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1722,7 +1722,7 @@ def compute_and_output_dist( def update_shards( self, - changed_sharding_params: Dict[str, ParameterSharding], # NOTE: only delta + changed_sharding_params: EmbeddingModuleShardingPlan, # NOTE: only delta env: ShardingEnv, device: Optional[torch.device], ) -> None: @@ -1964,7 +1964,7 @@ def shardable_parameters( def reshard( self, sharded_module: ShardedEmbeddingBagCollection, - changed_shard_to_params: Dict[str, ParameterSharding], + changed_shard_to_params: EmbeddingModuleShardingPlan, env: ShardingEnv, device: Optional[torch.device] = None, ) -> ShardedEmbeddingBagCollection: diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 334b706c9..7c9c45824 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -37,9 +37,9 @@ from torchrec.distributed.types import ( DMPCollectionConfig, DMPCollectionContext, + EmbeddingModuleShardingPlan, EnumerableShardingSpec, ModuleSharder, - ParameterSharding, ShardedModule, ShardingEnv, ShardingEnv2D, @@ -258,11 +258,12 @@ def __init__( device = torch.device("cpu") self.device: torch.device = device - if sharders is None: - sharders = get_default_sharders() + self.sharders: List[ModuleSharder[nn.modules.module.Module]] = ( + get_default_sharders() if sharders is None else sharders + ) self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = { - sharder.module_type: sharder for sharder in sharders + sharder.module_type: sharder for sharder in self.sharders } if data_parallel_wrapper is None: @@ -279,9 +280,9 @@ def __init__( ) pg = self._env.process_group if pg is not None: - plan = planner.collective_plan(module, sharders, pg) + plan = planner.collective_plan(module, self.sharders, pg) else: - plan = planner.plan(module, sharders) + plan = planner.plan(module, self.sharders) self._plan: 
ShardingPlan = plan self._dmp_wrapped_module: nn.Module = self._init_dmp(module) self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module) @@ -668,11 +669,12 @@ def _reset_parameters(module: nn.Module) -> None: def reshard( self, - sharded_module_fqn: str, - changed_shard_to_params: Dict[str, ParameterSharding], - ) -> None: + changed_shard_to_params: Dict[str, Tuple[float, EmbeddingModuleShardingPlan]], + sharded_module_fqn: Optional[str] = None, + ) -> float: """ Reshards an already-sharded module in the DMP given a set of ParameterShardings to change placements. + Returns the data volume resharded This method allows you to dynamically change the sharding strategy for a specific module without recreating the entire DMP. It's particularly useful for: @@ -721,21 +723,53 @@ def reshard( - After resharding, the optimizer state is maintained for the module - The sharding plan is updated to reflect the new configuration """ - steps = sharded_module_fqn.split(".") - sharded_module = self.module - for s in steps: - sharded_module = getattr(sharded_module, s) - - assert isinstance(sharded_module, ShardedModule) - assert changed_shard_to_params is not None - sharder_key = sharded_module.unsharded_module_type - sharder = self._sharder_map[sharder_key] - assert hasattr( - sharder, "reshard" - ), "reshard is not implemented for this sharder" + sharder = None + sharded_module = None + + if sharded_module_fqn is None: + named_modules_queue = [("", self.module)] + while named_modules_queue: + child_path, child_module = named_modules_queue.pop(0) + if isinstance(child_module, ShardedModule): + sharder_key = child_module.unsharded_module_type + sharder = self._sharder_map.get(sharder_key, None) + if not sharder: + for n, m in child_module.named_children(): + if child_path != "": + named_modules_queue.append((child_path + "." 
+ n, m)) + else: + named_modules_queue.append((n, m)) + continue + if hasattr(sharder, "reshard"): + sharded_module = child_module + sharded_module_fqn = child_path + break + else: # Parse the fqn to identify module to be resharded + steps = sharded_module_fqn.split(".") + sharded_module = self.module + for s in steps: + sharded_module = getattr(sharded_module, s) + + # TODO: consider sharding unsharded module + assert isinstance( + sharded_module, ShardedModule + ), "Given module is unsharded" + assert changed_shard_to_params is not None + sharder_key = sharded_module.unsharded_module_type + sharder = self._sharder_map[sharder_key] + assert hasattr( + sharder, "reshard" + ), "reshard is not implemented for this sharder" + + assert sharder is not None, "Could not find sharder to reshard" + assert ( + sharded_module is not None and sharded_module_fqn is not None + ), "Could not find sharded_module to reshard" + data_volume, delta_plan = changed_shard_to_params[sharded_module_fqn] + sharded_module = sharder.reshard( # pyre-ignore sharded_module, - changed_shard_to_params, + delta_plan, self._env, self.device, ) @@ -745,7 +779,7 @@ def reshard( self._dmp_wrapped_module.module # pyre-ignore ) self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan - return sharded_module + return data_volume class DMPCollection(DistributedModelParallel): diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py index e401b041e..f99a47185 100644 --- a/torchrec/distributed/sharding/dynamic_sharding.py +++ b/torchrec/distributed/sharding/dynamic_sharding.py @@ -17,6 +17,7 @@ from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, + ModuleShardingPlan, ParameterSharding, ShardedModule, ShardedTensor, @@ -869,7 +870,7 @@ def update_module_sharding_plan( # Utils -def output_sharding_plan_delta( +def output_sharding_plan_delta_single( old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan, return_data_volume: bool = False, @@ -880,7 +881,7 @@ def output_sharding_plan_delta( have the same number of parameters/tables. This is useful for Dynamic Sharding since Resharding API takes in only the - ParameterSharding or shards that needs to be moved. + ParameterSharding or shards that needs to be moved. Takes in EmbeddingModuleShardingPlan. """ assert len(old_plan) == len(new_plan) diff = EmbeddingModuleShardingPlan( @@ -900,3 +901,29 @@ def output_sharding_plan_delta( ) # Asumming float datatype return (data_volume, diff) + + +def output_sharding_plans_delta( + old_plan: Dict[str, EmbeddingModuleShardingPlan], + new_plan: Dict[str, EmbeddingModuleShardingPlan], + return_data_volume: bool = False, +) -> Dict[str, Tuple[float, EmbeddingModuleShardingPlan]]: + """ + Compute and return a new sharding plan that is the delta + between new and old embedding module plans. Assumes that the old and new plan + have the same number of parameters/tables. + + This is useful for Dynamic Sharding since Resharding API takes in only the + ParameterSharding or shards that needs to be moved. Takes in a Dict + which is the format of DMP sharding plans. 
+ """ + delta_plans: Dict[str, Tuple[float, EmbeddingModuleShardingPlan]] = {} + for key, plan in old_plan.items(): + assert ( + key in new_plan + ), f"Found mismatch between old and new plans, key: {key} not found in new plan" + + delta_plans[key] = output_sharding_plan_delta_single( + plan, new_plan[key], return_data_volume + ) + return delta_plans diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index 107b927af..94ba5514b 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -249,6 +249,7 @@ def _test_dynamic_sharding( lengths_dtype: torch.dtype = torch.int64, sharding_type: ShardingType = None, # pyre-ignore random_seed: int = 0, + skip_passing_resharding_fqn: bool = False, ) -> None: """ Tests the reshard API with dynamic_sharding_test, which creates 2 identical models @@ -297,6 +298,7 @@ def _test_dynamic_sharding( lengths_dtype=lengths_dtype, random_seed=random_seed, sharding_type=sharding_type, + skip_passing_resharding_fqn=skip_passing_resharding_fqn, ) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index c955249d2..6e2cd35ca 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -38,7 +38,7 @@ ParameterConstraints, Topology, ) -from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta +from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plans_delta from torchrec.distributed.sharding_plan import ( construct_module_sharding_plan, get_sharding_constructor_from_type, @@ -341,6 +341,7 @@ def dynamic_sharding_test( lengths_dtype: torch.dtype = torch.int64, sharding_type: ShardingType = None, # pyre-ignore random_seed: int = 0, + skip_passing_resharding_fqn: bool = False, ) -> None: """ Test case for dynamic sharding: @@ -590,8 +591,8 @@ def dynamic_sharding_test( exclude_predfix="sparse.pooled_embedding_arch.embedding_modules._itp_iter", ) - _, new_module_sharding_plan_delta = output_sharding_plan_delta( - plan.plan["sparse.ebc"], new_module_sharding_plan # pyre-ignore + new_module_sharding_plan_delta = output_sharding_plans_delta( + plan.plan, plan_1.plan # pyre-ignore ) dense_m1_optim = KeyedOptimizerWrapper( @@ -615,7 +616,10 @@ def dynamic_sharding_test( True, ) - local_m1_dmp.reshard("sparse.ebc", new_module_sharding_plan_delta) + local_m1_dmp.reshard( + sharded_module_fqn=None if skip_passing_resharding_fqn else "sparse.ebc", + changed_shard_to_params=new_module_sharding_plan_delta, + ) # Must recreate local_m1_opt, because current local_m1_opt is a copy of underlying fused_opt diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index 11dd86a12..9b06db0b6 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -11,7 +11,7 @@ import random import unittest -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional import hypothesis.strategies as st @@ -30,7 +30,9 @@ from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig -from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta 
+from torchrec.distributed.sharding.dynamic_sharding import ( + output_sharding_plan_delta_single, +) from torchrec.distributed.sharding_plan import ( column_wise, @@ -284,7 +286,7 @@ def _test_ebc_resharding( device=ctx.device, ) - _, new_module_sharding_plan_delta = output_sharding_plan_delta( + _, new_module_sharding_plan_delta = output_sharding_plan_delta_single( module_sharding_plan, new_module_sharding_plan ) @@ -545,6 +547,7 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared): data_type=st.sampled_from([DataType.FP16, DataType.FP32]), random_seed=st.integers(0, 1000), world_size=st.sampled_from([2, 4]), + skip_passing_resharding_fqn=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) def test_sharding( @@ -556,6 +559,7 @@ def test_sharding( data_type: DataType, random_seed: int, world_size: int, + skip_passing_resharding_fqn: bool, ) -> None: """ Tests resharding from DMP module interface with conditional optimizer selection: @@ -618,11 +622,12 @@ def test_sharding( sharding_type=sharding_type_e, random_seed=random_seed, world_size=world_size, + skip_passing_resharding_fqn=skip_passing_resharding_fqn, ) class SingleRankDynamicShardingUtilsTest(unittest.TestCase): - def test_output_sharding_plan_delta(self) -> None: + def test_output_sharding_plan_delta_single(self) -> None: """ Tests output_sharding_plan_delta function """ @@ -673,7 +678,7 @@ def test_output_sharding_plan_delta(self) -> None: device_type="cuda" if torch.cuda.is_available() else "cpu", ) - _, new_module_sharding_plan_delta = output_sharding_plan_delta( + _, new_module_sharding_plan_delta = output_sharding_plan_delta_single( module_sharding_plan, new_module_sharding_plan ) @@ -696,3 +701,86 @@ def test_output_sharding_plan_delta(self) -> None: ) # NOTE there are other attributes to test for equivalence in ParameterSharding type # but the ones included here are the most important. + + def test_output_sharding_plans_delta(self) -> None: + """ + Tests output_sharding_plan_delta_single with multiple tables and sharding types. 
+ """ + num_tables = 3 + world_size = 4 + data_type = DataType.FP32 + embedding_dim = 16 + num_embeddings = 8 + + # Generate random ranks for table-wise and column-wise sharding + # Table 0: Table-wise, Table 1: Column-wise, Table 2: Table-wise + ranks_per_tables = [1, 2, 1] + old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables) + new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables) + # Ensure at least one table changes ranks + while new_ranks == old_ranks: + new_ranks = generate_rank_placements( + world_size, num_tables, ranks_per_tables + ) + + per_param_sharding = {} + new_per_param_sharding = {} + + # Table 0: Table-wise, Table 1: Column-wise, Table 2: Table-wise + per_param_sharding[table_name(0)] = table_wise(rank=old_ranks[0][0]) + per_param_sharding[table_name(1)] = column_wise(ranks=old_ranks[1]) + per_param_sharding[table_name(2)] = table_wise(rank=old_ranks[2][0]) + + new_per_param_sharding[table_name(0)] = table_wise(rank=new_ranks[0][0]) + new_per_param_sharding[table_name(1)] = column_wise(ranks=new_ranks[1]) + new_per_param_sharding[table_name(2)] = table_wise(rank=new_ranks[2][0]) + + embedding_bag_config = generate_embedding_bag_config( + data_type, num_tables, embedding_dim, num_embeddings + ) + + module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + per_param_sharding=per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + ) + + new_module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + per_param_sharding=new_per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + ) + + _, new_module_sharding_plan_delta = output_sharding_plan_delta_single( + module_sharding_plan, new_module_sharding_plan + ) + + # The delta should only contain tables whose sharding/ranks changed + for t_name, new_sharding in new_module_sharding_plan.items(): + old_sharding = module_sharding_plan[t_name] + if ( + new_sharding.ranks != old_sharding.ranks + or new_sharding.sharding_type != old_sharding.sharding_type + or new_sharding.compute_kernel != old_sharding.compute_kernel + ): + assert t_name in new_module_sharding_plan_delta + assert ( + new_module_sharding_plan_delta[t_name].ranks == new_sharding.ranks + ) + assert ( + new_module_sharding_plan_delta[t_name].sharding_type + == new_sharding.sharding_type + ) + assert ( + new_module_sharding_plan_delta[t_name].compute_kernel + == new_sharding.compute_kernel + ) + else: + assert t_name not in new_module_sharding_plan_delta + + # The delta should not contain more keys than the new plan + assert len(new_module_sharding_plan_delta) <= len(new_module_sharding_plan)