169 changes: 0 additions & 169 deletions torchrec/distributed/benchmark/benchmark_resharding_handler.py

This file was deleted.

4 changes: 2 additions & 2 deletions torchrec/distributed/embeddingbag.py
@@ -1722,7 +1722,7 @@ def compute_and_output_dist(

def update_shards(
self,
changed_sharding_params: Dict[str, ParameterSharding], # NOTE: only delta
changed_sharding_params: EmbeddingModuleShardingPlan, # NOTE: only delta
env: ShardingEnv,
device: Optional[torch.device],
) -> None:
@@ -1964,7 +1964,7 @@ def shardable_parameters(
def reshard(
self,
sharded_module: ShardedEmbeddingBagCollection,
changed_shard_to_params: Dict[str, ParameterSharding],
changed_shard_to_params: EmbeddingModuleShardingPlan,
env: ShardingEnv,
device: Optional[torch.device] = None,
) -> ShardedEmbeddingBagCollection:
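
The signature change above replaces the bare `Dict[str, ParameterSharding]` with `EmbeddingModuleShardingPlan`, which in torchrec is itself a mapping from table name to `ParameterSharding`, so callers now hand the sharder only the delta plan object. A minimal sketch of what such a delta might look like; the table name `t0`, the target rank, and the `ebc_sharder`/`sharded_ebc`/`env` objects in the commented call are hypothetical, not part of this PR:

```python
from torchrec.distributed.types import EmbeddingModuleShardingPlan, ParameterSharding

# A delta plan carries only the tables whose placement changes ("only delta").
# Sketch values: table "t0" is to be moved to rank 1 as a table-wise shard.
delta_plan = EmbeddingModuleShardingPlan(
    {
        "t0": ParameterSharding(
            sharding_type="table_wise",
            compute_kernel="fused",
            ranks=[1],
        )
    }
)

# The sharder-level reshard now takes the plan object directly (objects below are placeholders):
# sharded_ebc = ebc_sharder.reshard(
#     sharded_module=sharded_ebc,
#     changed_shard_to_params=delta_plan,
#     env=env,
#     device=torch.device("cuda"),
# )
```
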
80 changes: 57 additions & 23 deletions torchrec/distributed/model_parallel.py
@@ -37,9 +37,9 @@
from torchrec.distributed.types import (
DMPCollectionConfig,
DMPCollectionContext,
EmbeddingModuleShardingPlan,
EnumerableShardingSpec,
ModuleSharder,
ParameterSharding,
ShardedModule,
ShardingEnv,
ShardingEnv2D,
@@ -258,11 +258,12 @@ def __init__(
device = torch.device("cpu")
self.device: torch.device = device

if sharders is None:
sharders = get_default_sharders()
self.sharders: List[ModuleSharder[nn.modules.module.Module]] = (
get_default_sharders() if sharders is None else sharders
)

self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
sharder.module_type: sharder for sharder in sharders
sharder.module_type: sharder for sharder in self.sharders
}

if data_parallel_wrapper is None:
@@ -279,9 +280,9 @@ def __init__(
)
pg = self._env.process_group
if pg is not None:
plan = planner.collective_plan(module, sharders, pg)
plan = planner.collective_plan(module, self.sharders, pg)
else:
plan = planner.plan(module, sharders)
plan = planner.plan(module, self.sharders)
self._plan: ShardingPlan = plan
self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
@@ -668,11 +669,12 @@ def _reset_parameters(module: nn.Module) -> None:

def reshard(
self,
sharded_module_fqn: str,
changed_shard_to_params: Dict[str, ParameterSharding],
) -> None:
changed_shard_to_params: Dict[str, Tuple[float, EmbeddingModuleShardingPlan]],
sharded_module_fqn: Optional[str] = None,
) -> float:
"""
Reshards an already-sharded module in the DMP given a set of ParameterShardings to change placements.
Returns the data volume resharded.

This method allows you to dynamically change the sharding strategy for a specific module
without recreating the entire DMP. It's particularly useful for:
@@ -721,21 +723,53 @@ def reshard(
- After resharding, the optimizer state is maintained for the module
- The sharding plan is updated to reflect the new configuration
"""
steps = sharded_module_fqn.split(".")
sharded_module = self.module
for s in steps:
sharded_module = getattr(sharded_module, s)

assert isinstance(sharded_module, ShardedModule)
assert changed_shard_to_params is not None
sharder_key = sharded_module.unsharded_module_type
sharder = self._sharder_map[sharder_key]
assert hasattr(
sharder, "reshard"
), "reshard is not implemented for this sharder"
sharder = None
sharded_module = None

if sharded_module_fqn is None:
named_modules_queue = [("", self.module)]
while named_modules_queue:
child_path, child_module = named_modules_queue.pop(0)
if isinstance(child_module, ShardedModule):
sharder_key = child_module.unsharded_module_type
sharder = self._sharder_map.get(sharder_key, None)
if not sharder:
for n, m in child_module.named_children():
if child_path != "":
named_modules_queue.append((child_path + "." + n, m))
else:
named_modules_queue.append((n, m))
continue
if hasattr(sharder, "reshard"):
sharded_module = child_module
sharded_module_fqn = child_path
break
else: # Parse the fqn to identify module to be resharded
steps = sharded_module_fqn.split(".")
sharded_module = self.module
for s in steps:
sharded_module = getattr(sharded_module, s)

# TODO: consider sharding unsharded module
assert isinstance(
sharded_module, ShardedModule
), "Given module is unsharded"
assert changed_shard_to_params is not None
sharder_key = sharded_module.unsharded_module_type
sharder = self._sharder_map[sharder_key]
assert hasattr(
sharder, "reshard"
), "reshard is not implemented for this sharder"

assert sharder is not None, "Could not find sharder to reshard"
assert (
sharded_module is not None and sharded_module_fqn is not None
), "Could not find sharded_module to reshard"
data_volume, delta_plan = changed_shard_to_params[sharded_module_fqn]

sharded_module = sharder.reshard( # pyre-ignore
sharded_module,
changed_shard_to_params,
delta_plan,
self._env,
self.device,
)
@@ -745,7 +779,7 @@
self._dmp_wrapped_module.module # pyre-ignore
)
self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
return sharded_module
return data_volume


class DMPCollection(DistributedModelParallel):
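
Taken together, the new `reshard` accepts the per-module delta dictionary and returns the data volume moved; when `sharded_module_fqn` is omitted it walks the wrapped module breadth-first for the first `ShardedModule` whose sharder implements `reshard`. A rough usage sketch, assuming `model` is an existing `DistributedModelParallel` instance and `delta_plans` already has the `Dict[str, Tuple[float, EmbeddingModuleShardingPlan]]` shape (e.g. from `output_sharding_plans_delta` in the next file); the FQN string is illustrative:

```python
# Explicit form: name the sharded module to reshard (FQN is illustrative).
moved = model.reshard(
    changed_shard_to_params=delta_plans,
    sharded_module_fqn="sparse.ebc",
)

# Implicit form: omit the FQN and let reshard() discover the first
# ShardedModule whose sharder implements reshard().
moved = model.reshard(changed_shard_to_params=delta_plans)

print(f"data volume resharded: {moved}")
```
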
31 changes: 29 additions & 2 deletions torchrec/distributed/sharding/dynamic_sharding.py
@@ -17,6 +17,7 @@
from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo
from torchrec.distributed.types import (
EmbeddingModuleShardingPlan,
ModuleShardingPlan,
ParameterSharding,
ShardedModule,
ShardedTensor,
@@ -869,7 +870,7 @@ def update_module_sharding_plan(


# Utils
def output_sharding_plan_delta(
def output_sharding_plan_delta_single(
old_plan: EmbeddingModuleShardingPlan,
new_plan: EmbeddingModuleShardingPlan,
return_data_volume: bool = False,
Expand All @@ -880,7 +881,7 @@ def output_sharding_plan_delta(
have the same number of parameters/tables.

This is useful for Dynamic Sharding since Resharding API takes in only the
ParameterSharding or shards that needs to be moved.
ParameterSharding or shards that need to be moved. Takes in an EmbeddingModuleShardingPlan.
"""
assert len(old_plan) == len(new_plan)
diff = EmbeddingModuleShardingPlan(
@@ -900,3 +901,29 @@
) # Assuming float datatype

return (data_volume, diff)


def output_sharding_plans_delta(
old_plan: Dict[str, EmbeddingModuleShardingPlan],
new_plan: Dict[str, EmbeddingModuleShardingPlan],
return_data_volume: bool = False,
) -> Dict[str, Tuple[float, EmbeddingModuleShardingPlan]]:
"""
Compute and return a new sharding plan that is the delta
between new and old embedding module plans. Assumes that the old and new plan
have the same number of parameters/tables.

This is useful for Dynamic Sharding since Resharding API takes in only the
ParameterSharding or shards that needs to be moved. Takes in a Dict
which is the format of DMP sharding plans.
"""
delta_plans: Dict[str, Tuple[float, EmbeddingModuleShardingPlan]] = {}
for key, plan in old_plan.items():
assert (
key in new_plan
), f"Found mismatch between old and new plans, key: {key} not found in new plan"

delta_plans[key] = output_sharding_plan_delta_single(
plan, new_plan[key], return_data_volume
)
return delta_plans
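
Since the helper's input matches the `Dict[str, EmbeddingModuleShardingPlan]` layout of per-module DMP plans, the intended flow is roughly the sketch below; `model` and `propose_new_plans` are assumptions (an existing DMP instance and a planner step producing the target plans), not names defined in this diff:

```python
from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plans_delta

# Current per-module plans held by the DMP; for embedding modules these
# entries are EmbeddingModuleShardingPlan instances.
old_plans = dict(model.plan.plan)

# Hypothetical planner step that returns the desired target plans in the same format.
new_plans = propose_new_plans(old_plans)

# Each value pairs the estimated data volume to move with the per-table delta plan.
delta_plans = output_sharding_plans_delta(old_plans, new_plans, return_data_volume=True)

# Feed the deltas straight into DMP.reshard(), which returns the volume actually moved.
data_moved = model.reshard(changed_shard_to_params=delta_plans)
```
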
2 changes: 2 additions & 0 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -249,6 +249,7 @@ def _test_dynamic_sharding(
lengths_dtype: torch.dtype = torch.int64,
sharding_type: ShardingType = None, # pyre-ignore
random_seed: int = 0,
skip_passing_resharding_fqn: bool = False,
) -> None:
"""
Tests the reshard API with dynamic_sharding_test, which creates 2 identical models
@@ -297,6 +298,7 @@
lengths_dtype=lengths_dtype,
random_seed=random_seed,
sharding_type=sharding_type,
skip_passing_resharding_fqn=skip_passing_resharding_fqn,
)

