diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 152995c1c2..b56ee6403d 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -8,22 +8,95 @@ # Technically, this is not a part of distributed, but distributed module is the best place to put it. import os -from collections import defaultdict +from functools import lru_cache, partial +from typing import Callable import torch import torch._functorch.config import torch.nn as nn +from torch._functorch.partitioners import get_default_op_list from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, ) +from torch.utils.checkpoint import CheckpointPolicy from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.tools.logging import logger +_PolicyFn = Callable[..., CheckpointPolicy] + _layer_sac_count = 0 +def _sac_policy_fn( + ctx, + op, + *args, + compute_intensive_ops: dict, + communication_intensive_ops: dict, + **kwargs, +) -> CheckpointPolicy: + if op in (compute_intensive_ops | communication_intensive_ops): + return CheckpointPolicy.MUST_SAVE + + return CheckpointPolicy.PREFER_RECOMPUTE + + +@lru_cache() +def default_activation_checkpoint_policy() -> _PolicyFn: + """Returns a checkpointing policy function that saves results of compute and communicate ops. + """ + aten_op_types = get_default_op_list() + compute_intensive_ops = { + op.default: CheckpointPolicy.MUST_SAVE # pyrefly: ignore [missing-attribute] + for op in aten_op_types.compute_intensive_ops + } + + compute_intensive_ops[ + torch.ops.aten._scaled_dot_product_cudnn_attention.default + ] = CheckpointPolicy.MUST_SAVE + compute_intensive_ops[ + torch.ops.aten._scaled_dot_product_attention_math.default + ] = CheckpointPolicy.MUST_SAVE + compute_intensive_ops[ + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default + ] = CheckpointPolicy.MUST_SAVE + + compute_intensive_ops[ + torch.ops.higher_order.flex_attention + ] = CheckpointPolicy.MUST_SAVE + compute_intensive_ops[ + torch._higher_order_ops.flex_attention + ] = CheckpointPolicy.MUST_SAVE + if hasattr(torch._higher_order_ops, "inductor_compiled_code"): + compute_intensive_ops[ + torch._higher_order_ops.inductor_compiled_code + ] = CheckpointPolicy.MUST_SAVE + + if hasattr(torch.ops, "torch_attn") and hasattr( + torch.ops.torch_attn, "_varlen_attn" + ): + compute_intensive_ops[ + torch.ops.torch_attn._varlen_attn.default + ] = CheckpointPolicy.MUST_SAVE + + communication_intensive_ops = { + torch.ops._c10d_functional.reduce_scatter_tensor.default: CheckpointPolicy.MUST_SAVE, + torch.ops._c10d_functional.all_to_all_single.default: CheckpointPolicy.MUST_SAVE, + } + + policy_fn = partial( + _sac_policy_fn, + compute_intensive_ops=compute_intensive_ops, + communication_intensive_ops=communication_intensive_ops, + ) + # pyrefly: ignore [missing-attribute] + policy_fn.cache_hash = "default_activation_checkpoint_policy" + # pyrefly: ignore [bad-return] + return policy_fn + + def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: """Apply layer selective activation checkpointing to the module. @@ -54,24 +127,22 @@ def _apply_op_sac( ac_config: ACConfig, *, base_fqn: str | None = None, - op_sac_save_list: set[torch._ops.OpOverload], ) -> nn.Module: """Apply selective activation checkpointing to the module. + This function uses the policy-based approach. 
The policy is obtained from + `default_activation_checkpoint_policy()` which returns a policy function that decides which + ops to save vs recompute. + Args: module (nn.Module): The module to apply selective activation checkpointing to. ac_config (ACConfig): The activation checkpointing config. base_fqn (str, optional): The base fqn of the module. Defaults to None. - op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead - of recomputing. Returns: nn.Module: The module with selective activation checkpointing applied. """ - from torch.utils.checkpoint import ( - CheckpointPolicy, - create_selective_checkpoint_contexts, - ) + from torch.utils.checkpoint import create_selective_checkpoint_contexts mm_recompute_shapes = set() if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: @@ -95,8 +166,10 @@ def _apply_op_sac( f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" ) - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): + base_policy = default_activation_checkpoint_policy() + + def _create_wrapped_policy(): + def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy: if ( func == torch.ops.aten._to_copy.default and "cuda" in str(args[0].device) @@ -105,27 +178,19 @@ def _custom_policy(ctx, func, *args, **kwargs): ): return CheckpointPolicy.MUST_SAVE - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - if args[1].shape in mm_recompute_shapes: - return CheckpointPolicy.PREFER_RECOMPUTE - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in op_sac_save_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return ( - CheckpointPolicy.MUST_SAVE - if to_save - else CheckpointPolicy.PREFER_RECOMPUTE - ) + if ( + func == torch.ops.aten.mm.default + and len(args) > 1 + and args[1].shape in mm_recompute_shapes + ): + return CheckpointPolicy.PREFER_RECOMPUTE + + return base_policy(ctx, func, *args, **kwargs) - return _custom_policy + return wrapped_policy def selective_checkpointing_context_fn(): - meta = defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + return create_selective_checkpoint_contexts(_create_wrapped_policy()) return ptd_checkpoint_wrapper( module, @@ -162,7 +227,6 @@ def _apply_ac_to_transformer_block( *, base_fqn: str | None = None, model_compile_enabled: bool = False, - op_sac_save_list: set[torch._ops.OpOverload] | None = None, ) -> nn.Module: valid_ac_modes = ("full", "selective") if ac_config.mode not in valid_ac_modes: @@ -183,10 +247,7 @@ def _apply_ac_to_transformer_block( ) if use_op_sac: - op_sac_save_list = op_sac_save_list or set() - return _apply_op_sac( - module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list - ) + return _apply_op_sac(module, ac_config, base_fqn=base_fqn) return _apply_layer_sac(module, ac_config) @@ -196,7 +257,6 @@ def apply_ac( ac_config: ACConfig, *, model_compile_enabled: bool = False, - op_sac_save_list: set[torch._ops.OpOverload] | None = None, base_folder: str = "", ) -> None: """Apply activation checkpointing to the model. @@ -205,8 +265,7 @@ def apply_ac( model (nn.Module): The model to apply activation checkpointing to. ac_config (ACConfig): The activation checkpointing config. model_compile_enabled (bool): Whether torch.compile is enabled for the model. 
- op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead - of recomputing. + Returns: None """ @@ -242,7 +301,6 @@ def apply_ac( ac_config, base_fqn=f"layers.{layer_id}", model_compile_enabled=model_compile_enabled, - op_sac_save_list=op_sac_save_list, ) layers.register_module(layer_id, transformer_block) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 19d9f946d2..216fe311e3 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import torch import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import Replicate, Shard @@ -29,24 +28,6 @@ ) from torchtitan.tools.logging import logger -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - torch.ops._c10d_functional.all_to_all_single.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch._higher_order_ops.inductor_compiled_code, -} - # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( @@ -114,9 +95,6 @@ def parallelize_deepseekv3( # Import deepep module to register custom ops before accessing them import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep - - _op_sac_save_list.add(torch.ops.deepep.dispatch.default) - _op_sac_save_list.add(torch.ops.deepep.combine.default) else: use_deepep = False @@ -150,8 +128,6 @@ def parallelize_deepseekv3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/gpt_oss/infra/parallelize.py b/torchtitan/models/gpt_oss/infra/parallelize.py index 338092fb7a..08ac4167b0 100644 --- a/torchtitan/models/gpt_oss/infra/parallelize.py +++ b/torchtitan/models/gpt_oss/infra/parallelize.py @@ -38,25 +38,6 @@ from .expert_parallel import GptossExpertTensorParallel, GptossTensorParallel -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - torch.ops._c10d_functional.all_to_all_single.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. 
- torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch._higher_order_ops.inductor_compiled_code, -} - - # Adapted from llama4/infra/parallelize.py def parallelize_gptoss( model: nn.Module, @@ -117,8 +98,6 @@ def parallelize_gptoss( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, ) dp_mesh: DeviceMesh | None = None diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 87f4f91ca9..79370909b2 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -32,25 +32,6 @@ from torchtitan.tools.logging import logger -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch.ops.torch_attn._varlen_attn.default, - torch._higher_order_ops.inductor_compiled_code, -} - - def parallelize_llama( model: nn.Module, parallel_dims: ParallelDims, @@ -113,8 +94,6 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 5e23b81ab1..56118770e3 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -47,24 +47,6 @@ from torchtitan.models.moe import moe as moe_module from torchtitan.tools.logging import logger -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - torch.ops._c10d_functional.all_to_all_single.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. 
- torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch._higher_order_ops.inductor_compiled_code, -} - def parallelize_llama( model: nn.Module, @@ -128,9 +110,6 @@ def parallelize_llama( # Import deepep module to register custom ops before accessing them import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep - - _op_sac_save_list.add(torch.ops.deepep.dispatch.default) - _op_sac_save_list.add(torch.ops.deepep.combine.default) else: use_deepep = False @@ -160,17 +139,10 @@ def parallelize_llama( job_config.compile.enable and "model" in job_config.compile.components ) if job_config.activation_checkpoint.mode != "none": - if job_config.activation_checkpoint.selective_ac_option == "op": - logger.info( - f"SAC save list contains {len(_op_sac_save_list)} ops: " - f"{sorted([str(op) for op in _op_sac_save_list])}" - ) apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 4837dbc68e..fd81f8d0db 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -35,25 +35,6 @@ from torchtitan.tools.logging import logger -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch.ops.torch_attn._varlen_attn.default, - torch._higher_order_ops.inductor_compiled_code, -} - - def parallelize_qwen3( model: nn.Module, parallel_dims: ParallelDims, @@ -130,8 +111,6 @@ def parallelize_qwen3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, )
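
Note (illustrative, not part of the patch): the sketch below shows how the new centralized policy is consumed once this change lands. It assumes a PyTorch build that provides torch.utils.checkpoint.create_selective_checkpoint_contexts and that default_activation_checkpoint_policy is importable from torchtitan.distributed.activation_checkpoint, where the patch defines it. TinyBlock and context_fn are made-up names for demonstration; which attention/matmul ops end up MUST_SAVE in a given run depends on how SDPA dispatches and on what get_default_op_list reports on that build.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

from torchtitan.distributed.activation_checkpoint import (
    default_activation_checkpoint_policy,
)


class TinyBlock(nn.Module):
    """Toy block with an attention call so the policy has something to save."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute-intensive ops (matmuls, SDPA variants, per get_default_op_list
        # plus the extra entries added in the patch) are candidates for
        # MUST_SAVE; everything else defaults to PREFER_RECOMPUTE.
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return self.proj(torch.relu(attn))


def context_fn():
    # Same pattern _apply_op_sac now uses internally, minus the
    # per-FQN mm-shape force-recompute override it layers on top.
    return create_selective_checkpoint_contexts(default_activation_checkpoint_policy())


block = TinyBlock()
x = torch.randn(2, 16, 64, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()

The design choice this illustrates: instead of each model's parallelize module maintaining its own _op_sac_save_list, the save set is derived once (and lru_cache'd) from get_default_op_list plus a fixed set of attention and collective ops, and model-specific behavior such as per_op_sac_force_recompute_mm_shapes_by_fqns is still applied by wrapping that base policy inside _apply_op_sac.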