Commit 5af1b1b

Manual Resharding Handler (#3398)
Summary: Pull Request resolved: #3398. X-link: #3398

* `DistributedModelParallel`: stores the sharders passed into the planner so they can be reused at resharding time.
* Removed hardcoded module FQNs and plan keys, since they were not compatible with all models. The handler now leverages how the planner/enumerator identifies sharders/EBCs to find the right plan.

Differential Revision: D83392188
1 parent ef3e5ca commit 5af1b1b
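
The diffs below enable an end-to-end manual resharding flow roughly like the following sketch. This is a non-authoritative illustration based only on this commit: `model` is assumed to be an already-initialized `DistributedModelParallel`, and `new_module_plans` is assumed to be a planner-produced `Dict[str, EmbeddingModuleShardingPlan]` with the desired new placements; neither name appears in the commit.

    from torchrec.distributed.sharding.dynamic_sharding import (
        output_sharding_plans_delta,
    )

    # Current per-module plans held by the DMP, keyed by module FQN.
    old_module_plans = model.plan.plan

    # Delta between old and new plans, only the tables whose placement changed:
    # {fqn: (data_volume, delta EmbeddingModuleShardingPlan)}.
    delta_plans = output_sharding_plans_delta(
        old_module_plans, new_module_plans, return_data_volume=True
    )

    # With this commit the target FQN is optional; DMP locates the first
    # ShardedModule whose sharder implements `reshard` and reshards it.
    moved = model.reshard(changed_shard_to_params=delta_plans)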

File tree

7 files changed: +191 / -205 lines

torchrec/distributed/benchmark/benchmark_resharding_handler.py

Lines changed: 0 additions & 169 deletions
This file was deleted.

torchrec/distributed/embeddingbag.py

Lines changed: 2 additions & 2 deletions
@@ -1722,7 +1722,7 @@ def compute_and_output_dist(
 
     def update_shards(
         self,
-        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        changed_sharding_params: EmbeddingModuleShardingPlan,  # NOTE: only delta
         env: ShardingEnv,
         device: Optional[torch.device],
     ) -> None:
@@ -1964,7 +1964,7 @@ def shardable_parameters(
     def reshard(
         self,
         sharded_module: ShardedEmbeddingBagCollection,
-        changed_shard_to_params: Dict[str, ParameterSharding],
+        changed_shard_to_params: EmbeddingModuleShardingPlan,
         env: ShardingEnv,
         device: Optional[torch.device] = None,
     ) -> ShardedEmbeddingBagCollection:

torchrec/distributed/model_parallel.py

Lines changed: 57 additions & 23 deletions
@@ -37,9 +37,9 @@
 from torchrec.distributed.types import (
     DMPCollectionConfig,
     DMPCollectionContext,
+    EmbeddingModuleShardingPlan,
     EnumerableShardingSpec,
     ModuleSharder,
-    ParameterSharding,
     ShardedModule,
     ShardingEnv,
     ShardingEnv2D,
@@ -258,11 +258,12 @@ def __init__(
             device = torch.device("cpu")
         self.device: torch.device = device
 
-        if sharders is None:
-            sharders = get_default_sharders()
+        self.sharders: List[ModuleSharder[nn.modules.module.Module]] = (
+            get_default_sharders() if sharders is None else sharders
+        )
 
         self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
-            sharder.module_type: sharder for sharder in sharders
+            sharder.module_type: sharder for sharder in self.sharders
         }
 
         if data_parallel_wrapper is None:
@@ -279,9 +280,9 @@ def __init__(
             )
             pg = self._env.process_group
             if pg is not None:
-                plan = planner.collective_plan(module, sharders, pg)
+                plan = planner.collective_plan(module, self.sharders, pg)
             else:
-                plan = planner.plan(module, sharders)
+                plan = planner.plan(module, self.sharders)
         self._plan: ShardingPlan = plan
         self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
         self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
@@ -668,11 +669,12 @@ def _reset_parameters(module: nn.Module) -> None:
 
     def reshard(
         self,
-        sharded_module_fqn: str,
-        changed_shard_to_params: Dict[str, ParameterSharding],
-    ) -> None:
+        changed_shard_to_params: Dict[str, Tuple[float, EmbeddingModuleShardingPlan]],
+        sharded_module_fqn: Optional[str] = None,
+    ) -> float:
         """
         Reshards an already-sharded module in the DMP given a set of ParameterShardings to change placements.
+        Returns the data volume resharded
 
         This method allows you to dynamically change the sharding strategy for a specific module
         without recreating the entire DMP. It's particularly useful for:
@@ -721,21 +723,53 @@ def reshard(
         - After resharding, the optimizer state is maintained for the module
         - The sharding plan is updated to reflect the new configuration
         """
-        steps = sharded_module_fqn.split(".")
-        sharded_module = self.module
-        for s in steps:
-            sharded_module = getattr(sharded_module, s)
-
-        assert isinstance(sharded_module, ShardedModule)
-        assert changed_shard_to_params is not None
-        sharder_key = sharded_module.unsharded_module_type
-        sharder = self._sharder_map[sharder_key]
-        assert hasattr(
-            sharder, "reshard"
-        ), "reshard is not implemented for this sharder"
+        sharder = None
+        sharded_module = None
+
+        if sharded_module_fqn is None:
+            named_modules_queue = [("", self.module)]
+            while named_modules_queue:
+                child_path, child_module = named_modules_queue.pop(0)
+                if isinstance(child_module, ShardedModule):
+                    sharder_key = child_module.unsharded_module_type
+                    sharder = self._sharder_map.get(sharder_key, None)
+                if not sharder:
+                    for n, m in child_module.named_children():
+                        if child_path != "":
+                            named_modules_queue.append((child_path + "." + n, m))
+                        else:
+                            named_modules_queue.append((n, m))
+                    continue
+                if hasattr(sharder, "reshard"):
+                    sharded_module = child_module
+                    sharded_module_fqn = child_path
+                    break
+        else:  # Parse the fqn to identify module to be resharded
+            steps = sharded_module_fqn.split(".")
+            sharded_module = self.module
+            for s in steps:
+                sharded_module = getattr(sharded_module, s)
+
+            # TODO: consider sharding unsharded module
+            assert isinstance(
+                sharded_module, ShardedModule
+            ), "Given module is unsharded"
+            assert changed_shard_to_params is not None
+            sharder_key = sharded_module.unsharded_module_type
+            sharder = self._sharder_map[sharder_key]
+            assert hasattr(
+                sharder, "reshard"
+            ), "reshard is not implemented for this sharder"
+
+        assert sharder is not None, "Could not find sharder to reshard"
+        assert (
+            sharded_module is not None and sharded_module_fqn is not None
+        ), "Could not find sharded_module to reshard"
+        data_volume, delta_plan = changed_shard_to_params[sharded_module_fqn]
+
         sharded_module = sharder.reshard(  # pyre-ignore
             sharded_module,
-            changed_shard_to_params,
+            delta_plan,
             self._env,
             self.device,
         )
@@ -745,7 +779,7 @@ def reshard(
             self._dmp_wrapped_module.module  # pyre-ignore
         )
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
-        return sharded_module
+        return data_volume
 
 
 class DMPCollection(DistributedModelParallel):
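
Based on the updated signature above, `DistributedModelParallel.reshard` can now be driven in two ways. A minimal sketch follows; the FQN "sparse.ebc" is a placeholder, and `delta_plans` is assumed to come from `output_sharding_plans_delta` (introduced in the diff below).

    # 1) Explicit FQN: reshard the module at that path in the wrapped model.
    moved = model.reshard(
        changed_shard_to_params=delta_plans,
        sharded_module_fqn="sparse.ebc",
    )

    # 2) No FQN: DMP walks its children breadth-first and picks the first
    #    ShardedModule whose registered sharder implements `reshard`.
    moved = model.reshard(changed_shard_to_params=delta_plans)

    # In both cases the return value is the data volume taken from the matching
    # (data_volume, delta_plan) entry, and the DMP's sharding plan entry for
    # that FQN is updated to the module's new sharding plan.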

torchrec/distributed/sharding/dynamic_sharding.py

Lines changed: 29 additions & 2 deletions
@@ -17,6 +17,7 @@
 from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo
 from torchrec.distributed.types import (
     EmbeddingModuleShardingPlan,
+    ModuleShardingPlan,
     ParameterSharding,
     ShardedModule,
     ShardedTensor,
@@ -869,7 +870,7 @@ def update_module_sharding_plan(
 
 
 # Utils
-def output_sharding_plan_delta(
+def output_sharding_plan_delta_single(
     old_plan: EmbeddingModuleShardingPlan,
     new_plan: EmbeddingModuleShardingPlan,
     return_data_volume: bool = False,
@@ -880,7 +881,7 @@ def output_sharding_plan_delta(
     have the same number of parameters/tables.
 
     This is useful for Dynamic Sharding since Resharding API takes in only the
-    ParameterSharding or shards that needs to be moved.
+    ParameterSharding or shards that needs to be moved. Takes in EmbeddingModuleShardingPlan.
     """
     assert len(old_plan) == len(new_plan)
     diff = EmbeddingModuleShardingPlan(
@@ -900,3 +901,29 @@
             )  # Asumming float datatype
 
     return (data_volume, diff)
+
+
+def output_sharding_plans_delta(
+    old_plan: Dict[str, EmbeddingModuleShardingPlan],
+    new_plan: Dict[str, EmbeddingModuleShardingPlan],
+    return_data_volume: bool = False,
+) -> Dict[str, Tuple[float, EmbeddingModuleShardingPlan]]:
+    """
+    Compute and return a new sharding plan that is the delta
+    between new and old embedding module plans. Assumes that the old and new plan
+    have the same number of parameters/tables.
+
+    This is useful for Dynamic Sharding since Resharding API takes in only the
+    ParameterSharding or shards that needs to be moved. Takes in a Dict
+    which is the format of DMP sharding plans.
+    """
+    delta_plans: Dict[str, Tuple[float, EmbeddingModuleShardingPlan]] = {}
+    for key, plan in old_plan.items():
+        assert (
+            key in new_plan
+        ), f"Found mismatch between old and new plans, key: {key} not found in new plan"
+
+        delta_plans[key] = output_sharding_plan_delta_single(
+            plan, new_plan[key], return_data_volume
+        )
+    return delta_plans
torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 2 additions & 0 deletions
@@ -249,6 +249,7 @@ def _test_dynamic_sharding(
         lengths_dtype: torch.dtype = torch.int64,
         sharding_type: ShardingType = None,  # pyre-ignore
         random_seed: int = 0,
+        skip_passing_resharding_fqn: bool = False,
     ) -> None:
         """
         Tests the reshard API with dynamic_sharding_test, which creates 2 identical models
@@ -297,6 +298,7 @@ def _test_dynamic_sharding(
             lengths_dtype=lengths_dtype,
             random_seed=random_seed,
             sharding_type=sharding_type,
+            skip_passing_resharding_fqn=skip_passing_resharding_fqn,
         )
 
 