 from torchrec.distributed.types import (
     DMPCollectionConfig,
     DMPCollectionContext,
+    EmbeddingModuleShardingPlan,
     EnumerableShardingSpec,
     ModuleSharder,
-    ParameterSharding,
     ShardedModule,
     ShardingEnv,
     ShardingEnv2D,
@@ -258,11 +258,12 @@ def __init__(
             device = torch.device("cpu")
         self.device: torch.device = device
 
-        if sharders is None:
-            sharders = get_default_sharders()
+        self.sharders: List[ModuleSharder[nn.modules.module.Module]] = (
+            get_default_sharders() if sharders is None else sharders
+        )
 
         self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
-            sharder.module_type: sharder for sharder in sharders
+            sharder.module_type: sharder for sharder in self.sharders
         }
 
         if data_parallel_wrapper is None:
@@ -279,9 +280,9 @@ def __init__(
             )
             pg = self._env.process_group
             if pg is not None:
-                plan = planner.collective_plan(module, sharders, pg)
+                plan = planner.collective_plan(module, self.sharders, pg)
             else:
-                plan = planner.plan(module, sharders)
+                plan = planner.plan(module, self.sharders)
         self._plan: ShardingPlan = plan
         self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
         self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
@@ -668,11 +669,12 @@ def _reset_parameters(module: nn.Module) -> None:
 
     def reshard(
         self,
-        sharded_module_fqn: str,
-        changed_shard_to_params: Dict[str, ParameterSharding],
-    ) -> None:
+        changed_shard_to_params: Dict[str, Tuple[float, EmbeddingModuleShardingPlan]],
+        sharded_module_fqn: Optional[str] = None,
+    ) -> float:
         """
         Reshards an already-sharded module in the DMP given a set of ParameterShardings to change placements.
+        Returns the data volume resharded.
 
         This method allows you to dynamically change the sharding strategy for a specific module
         without recreating the entire DMP. It's particularly useful for:
@@ -721,21 +723,53 @@ def reshard(
         - After resharding, the optimizer state is maintained for the module
         - The sharding plan is updated to reflect the new configuration
         """
-        steps = sharded_module_fqn.split(".")
-        sharded_module = self.module
-        for s in steps:
-            sharded_module = getattr(sharded_module, s)
-
-        assert isinstance(sharded_module, ShardedModule)
-        assert changed_shard_to_params is not None
-        sharder_key = sharded_module.unsharded_module_type
-        sharder = self._sharder_map[sharder_key]
-        assert hasattr(
-            sharder, "reshard"
-        ), "reshard is not implemented for this sharder"
+        sharder = None
+        sharded_module = None
+
+        if sharded_module_fqn is None:
+            named_modules_queue = [("", self.module)]
+            while named_modules_queue:
+                child_path, child_module = named_modules_queue.pop(0)
+                if isinstance(child_module, ShardedModule):
+                    sharder_key = child_module.unsharded_module_type
+                    sharder = self._sharder_map.get(sharder_key, None)
+                if not sharder:
+                    for n, m in child_module.named_children():
+                        if child_path != "":
+                            named_modules_queue.append((child_path + "." + n, m))
+                        else:
+                            named_modules_queue.append((n, m))
+                    continue
+                if hasattr(sharder, "reshard"):
+                    sharded_module = child_module
+                    sharded_module_fqn = child_path
+                    break
+        else:  # Parse the fqn to identify module to be resharded
+            steps = sharded_module_fqn.split(".")
+            sharded_module = self.module
+            for s in steps:
+                sharded_module = getattr(sharded_module, s)
+
+            # TODO: consider sharding unsharded module
+            assert isinstance(
+                sharded_module, ShardedModule
+            ), "Given module is unsharded"
+            assert changed_shard_to_params is not None
+            sharder_key = sharded_module.unsharded_module_type
+            sharder = self._sharder_map[sharder_key]
+            assert hasattr(
+                sharder, "reshard"
+            ), "reshard is not implemented for this sharder"
+
+        assert sharder is not None, "Could not find sharder to reshard"
+        assert (
+            sharded_module is not None and sharded_module_fqn is not None
+        ), "Could not find sharded_module to reshard"
+        data_volume, delta_plan = changed_shard_to_params[sharded_module_fqn]
+
         sharded_module = sharder.reshard(  # pyre-ignore
             sharded_module,
-            changed_shard_to_params,
+            delta_plan,
             self._env,
             self.device,
         )
@@ -745,7 +779,7 @@ def reshard(
             self._dmp_wrapped_module.module  # pyre-ignore
         )
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
-        return sharded_module
+        return data_volume
 
 
 class DMPCollection(DistributedModelParallel):
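
For reference, the reworked `reshard` entry point could be exercised as in the sketch below. This is a minimal illustration, not code from this PR: `model` is assumed to be an already-sharded `DistributedModelParallel`, and the FQN `"sparse.ebc"` plus the `make_delta_plan` helper are hypothetical placeholders for whatever produces the delta `EmbeddingModuleShardingPlan` and the data volume it moves.

```python
from typing import Dict, Tuple

from torchrec.distributed.types import EmbeddingModuleShardingPlan

# Hypothetical helper (not part of this PR): computes the delta sharding plan
# for the module at the given FQN and the data volume that plan would move.
data_volume, delta_plan = make_delta_plan(model, "sparse.ebc")  # placeholder

# New argument shape: FQN -> (data volume, delta EmbeddingModuleShardingPlan).
changed: Dict[str, Tuple[float, EmbeddingModuleShardingPlan]] = {
    "sparse.ebc": (data_volume, delta_plan),
}

# The FQN is now optional; if omitted, reshard() walks the module tree and
# picks the first ShardedModule whose sharder implements reshard().
moved = model.reshard(changed, sharded_module_fqn="sparse.ebc")
assert moved == data_volume  # reshard() now returns the data volume moved
```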