diff --git a/.github/workflows/integration_test_8gpu_compiler_toolkit.yaml b/.github/workflows/integration_test_8gpu_compiler_toolkit.yaml
index 78a7572849..e0a500ead9 100644
--- a/.github/workflows/integration_test_8gpu_compiler_toolkit.yaml
+++ b/.github/workflows/integration_test_8gpu_compiler_toolkit.yaml
@@ -7,10 +7,12 @@ on:
       - ciflow/8gpu/*
     paths:
       - 'torchtitan/experiments/compiler_toolkit/**'
+      - 'torchtitan/experiments/simple_fsdp/**'
       - '.github/workflows/integration_test_8gpu_compiler_toolkit.yaml'
   pull_request:
     paths:
       - 'torchtitan/experiments/compiler_toolkit/**'
+      - 'torchtitan/experiments/simple_fsdp/**'
       - '.github/workflows/integration_test_8gpu_compiler_toolkit.yaml'
   schedule:
     # Runs every 12 hours
diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
index 5e62bd4891..b49ebb7edb 100644
--- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -149,20 +149,12 @@ def parallelize_deepseekv3(
                 ):
                     experts_shard_dim = 1

-                # when EP is enable, the routed experts' gradient reduction is done over
-                # edp_mesh instead of whole dp_mesh.
-                # we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh
-                # to be consistent with data.
-                # TODO (ruisizhang123): update the logic following the link below instead
-                # of using a reduction_divide_factor
-                # https://github.com/pytorch/torchtitan/pull/1803#discussion_r2415190883
                 transformer_block.moe.experts = data_parallel(
                     transformer_block.moe.experts,
                     edp_mesh,
                     dp_mode,
                     mp_policy=mp_policy,
                     shard_dim=experts_shard_dim,
-                    reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
                 )

         model = data_parallel(
diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
index 7fff28faab..791dc9bc58 100644
--- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py
+++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -42,37 +42,6 @@ class MixedPrecisionPolicy:
     reduce_dtype: torch.dtype | None = None


-class _ScaledPartial(Partial):
-    # A subclass of Partial placement that allows user to perform reduction with a custom
-    # factor (reduction_divide_factor) other than the default world size.
-    def __init__(
-        self,
-        reduction_divide_factor: float,
-    ):
-        self.reduction_divide_factor = reduction_divide_factor
-        super().__init__(reduce_op="sum")
-
-    def _reduce_value(
-        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
-    ) -> torch.Tensor:
-        # for all_reduce in DDP
-        tensor.div_(self.reduction_divide_factor)
-        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
-        return reduced
-
-    def _reduce_shard_value(
-        self,
-        tensor: torch.Tensor,
-        mesh: DeviceMesh,
-        mesh_dim: int,
-        shard_spec: Placement,
-    ) -> torch.Tensor:
-        # for reduce_scatter in FSDP
-        tensor.div_(self.reduction_divide_factor)
-        reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
-        return reduced
-
-
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
@@ -188,7 +157,6 @@ def __init__(
         param_sharding: tuple[Placement, ...],
         mode: str,
         mp_policy: MixedPrecisionPolicy | None,
-        reduction_divide_factor: float | None,
         full_dtensor: bool = False,
     ) -> None:
         super().__init__()
@@ -197,11 +165,7 @@ def __init__(
         self.mode = mode
         self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim
         self.grad_placements: list[Placement] = [
-            _ScaledPartial(
-                reduction_divide_factor=reduction_divide_factor,
-            )
-            if reduction_divide_factor is not None
-            else Partial(reduce_op="avg")
+            Partial(reduce_op="sum")
         ] * self.device_mesh.ndim
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype: torch.dtype | None = mp_policy.param_dtype
@@ -286,7 +250,6 @@ def data_parallel(
     mode: str = "replicate",
     mp_policy: MixedPrecisionPolicy | None = None,
     shard_dim: int = 0,
-    reduction_divide_factor: float | None = None,
     full_dtensor: bool = False,
 ) -> nn.Module:
     param_sharding: tuple[Placement, ...]
@@ -346,7 +309,6 @@ def data_parallel(
             param_sharding,
             mode,
             mp_policy=mp_policy,
-            reduction_divide_factor=reduction_divide_factor,
             full_dtensor=full_dtensor,
         ),
     )
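
Note on the scaling behavior removed above: `_ScaledPartial` divided each local gradient by `reduction_divide_factor` before the sum reduction (all_reduce in DDP, reduce_scatter in FSDP), whereas the new `Partial(reduce_op="sum")` gradient placements only sum, leaving any normalization to be applied outside the placement. The snippet below is a minimal single-process sketch of that equivalence; the tensor shape, the factor of 8, and the variable names are illustrative assumptions, not code from this patch.

    import torch

    # Stand-ins for per-rank local gradients and the divide factor
    # (illustrative values, not taken from the patch).
    local_grads = [torch.randn(4) for _ in range(8)]
    factor = 8.0

    # What _ScaledPartial did: divide each local tensor by the factor,
    # then let the collective sum the results.
    scaled_then_summed = sum(g / factor for g in local_grads)

    # Equivalent "sum first, divide once" form that a caller can apply
    # after a plain Partial(reduce_op="sum") reduction.
    summed_then_scaled = sum(local_grads) / factor

    torch.testing.assert_close(scaled_then_summed, summed_then_scaled)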