From 970be3f7303898693b2e2a149adca95e84a417d1 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 9 Feb 2026 13:08:53 -0800 Subject: [PATCH 1/5] Refactor activation checkpointing to use centralized policy-based approach --- .../distributed/activation_checkpoint.py | 231 ++++++++++++++---- .../models/deepseek_v3/infra/parallelize.py | 24 -- .../models/gpt_oss/infra/parallelize.py | 21 -- torchtitan/models/llama3/infra/parallelize.py | 21 -- torchtitan/models/llama4/infra/parallelize.py | 28 --- torchtitan/models/qwen3/infra/parallelize.py | 21 -- 6 files changed, 184 insertions(+), 162 deletions(-) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 152995c1c2..0f2f91c343 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -4,26 +4,144 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# This file provides the util functions to apply activation checkpointing to the model. -# Technically, this is not a part of distributed, but distributed module is the best place to put it. +"""Activation checkpointing module. + +This module provides utilities to apply activation checkpointing to the model. + +Key design patterns: +1. Policy factory functions are decorated with @lru_cache() to avoid dynamo recompilations +2. Ops are categorized into compute_intensive_ops and communication_intensive_ops +3. Policy functions have signature: (ctx, op, *args, **kwargs) -> CheckpointPolicy +4. 
Policy factories have a `cache_hash` attribute for dynamo cache management +""" import os -from collections import defaultdict +from functools import lru_cache, partial +from typing import Callable import torch import torch._functorch.config import torch.nn as nn +from torch._functorch.partitioners import get_default_op_list from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, ) +from torch.utils.checkpoint import CheckpointPolicy from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.tools.logging import logger +# Type alias for policy functions +_PolicyFn = Callable[..., CheckpointPolicy] + _layer_sac_count = 0 +def _sac_policy_fn( + ctx, + op, + *args, + compute_intensive_ops: dict, + communication_intensive_ops: dict, + **kwargs, +) -> CheckpointPolicy: + # Save compute-intensive ops (mm, attention, conv, flex_attention, etc.) + if op in compute_intensive_ops: + return CheckpointPolicy.MUST_SAVE + + # Save communication-intensive ops (reduce_scatter, all_to_all, etc.) + if op in communication_intensive_ops: + return CheckpointPolicy.MUST_SAVE + + # Default: recompute everything else + return CheckpointPolicy.PREFER_RECOMPUTE + + +@lru_cache() +def default_activation_checkpoint_policy() -> _PolicyFn: + """Returns a checkpointing policy function that saves results of compute-intensive ops. + + The policy saves compute-intensive and communication-intensive ops while + recomputing everything else. Uses dicts (not sets) to workaround dynamo + guarding issues (https://github.com/pytorch/pytorch/issues/168163). + + Returns: + A policy function that can be used with checkpoint contexts. + + Note: + This function is cached with @lru_cache() to avoid dynamo recompilations. + The cache_hash attribute is used by dynamo for cache management. 
+ """ + aten_op_types = get_default_op_list() + compute_intensive_ops = { + op.default: CheckpointPolicy.MUST_SAVE # pyrefly: ignore [missing-attribute] + for op in aten_op_types.compute_intensive_ops + } + + compute_intensive_ops[ + torch.ops.aten._scaled_dot_product_cudnn_attention.default + ] = CheckpointPolicy.MUST_SAVE + compute_intensive_ops[ + torch.ops.aten._scaled_dot_product_attention_math.default + ] = CheckpointPolicy.MUST_SAVE + compute_intensive_ops[ + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default + ] = CheckpointPolicy.MUST_SAVE + + compute_intensive_ops[ + torch.ops.higher_order.flex_attention + ] = CheckpointPolicy.MUST_SAVE + compute_intensive_ops[ + torch._higher_order_ops.flex_attention + ] = CheckpointPolicy.MUST_SAVE + if hasattr(torch._higher_order_ops, "inductor_compiled_code"): + compute_intensive_ops[ + torch._higher_order_ops.inductor_compiled_code + ] = CheckpointPolicy.MUST_SAVE + + compute_intensive_ops[torch.ops.aten.max.default] = CheckpointPolicy.MUST_SAVE + + if hasattr(torch.ops, "torch_attn") and hasattr( + torch.ops.torch_attn, "_varlen_attn" + ): + compute_intensive_ops[ + torch.ops.torch_attn._varlen_attn.default + ] = CheckpointPolicy.MUST_SAVE + + communication_intensive_ops = { + torch.ops._c10d_functional.reduce_scatter_tensor.default: CheckpointPolicy.MUST_SAVE, + torch.ops._c10d_functional.all_to_all_single.default: CheckpointPolicy.MUST_SAVE, + } + + # DeepEP ops for MoE expert parallelism + # Try to import deepep module to register custom ops, then check if they exist + try: + import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep + + if hasattr(torch.ops, "deepep"): + if hasattr(torch.ops.deepep, "dispatch"): + communication_intensive_ops[ + torch.ops.deepep.dispatch.default + ] = CheckpointPolicy.MUST_SAVE + if hasattr(torch.ops.deepep, "combine"): + communication_intensive_ops[ + torch.ops.deepep.combine.default + ] = CheckpointPolicy.MUST_SAVE + except ImportError: 
+ pass # DeepEP not available + + policy_fn = partial( + _sac_policy_fn, + compute_intensive_ops=compute_intensive_ops, + communication_intensive_ops=communication_intensive_ops, + ) + # pyrefly: ignore [missing-attribute] + policy_fn.cache_hash = "default_activation_checkpoint_policy" + # pyrefly: ignore [bad-return] + return policy_fn + + def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: """Apply layer selective activation checkpointing to the module. @@ -49,31 +167,22 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: return module -def _apply_op_sac( +def _get_mm_recompute_shapes( module: nn.Module, ac_config: ACConfig, - *, base_fqn: str | None = None, - op_sac_save_list: set[torch._ops.OpOverload], -) -> nn.Module: - """Apply selective activation checkpointing to the module. +) -> set[tuple[int, int]]: + """Extract mm shapes that should be force-recomputed based on FQN matching. Args: - module (nn.Module): The module to apply selective activation checkpointing to. - ac_config (ACConfig): The activation checkpointing config. - base_fqn (str, optional): The base fqn of the module. Defaults to None. - op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead - of recomputing. + module: The module to analyze. + ac_config: The activation checkpointing config. + base_fqn: The base FQN of the module. Returns: - nn.Module: The module with selective activation checkpointing applied. + Set of (in_features, out_features) shapes to force recompute. 
""" - from torch.utils.checkpoint import ( - CheckpointPolicy, - create_selective_checkpoint_contexts, - ) - - mm_recompute_shapes = set() + mm_recompute_shapes: set[tuple[int, int]] = set() if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: for module_fqn, submod in module.named_modules(): fqn = module_fqn @@ -94,38 +203,72 @@ def _apply_op_sac( logger.debug( f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" ) + return mm_recompute_shapes + - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): +def _apply_op_sac( + module: nn.Module, + ac_config: ACConfig, + *, + base_fqn: str | None = None, +) -> nn.Module: + """Apply selective activation checkpointing to the module. + + This function uses the policy-based approach. The policy is obtained from + `default_activation_checkpoint_policy()` which returns a policy function that decides which + ops to save vs recompute. + + Args: + module (nn.Module): The module to apply selective activation checkpointing to. + ac_config (ACConfig): The activation checkpointing config. + base_fqn (str, optional): The base fqn of the module. Defaults to None. + + Returns: + nn.Module: The module with selective activation checkpointing applied. + """ + from torch.utils.checkpoint import create_selective_checkpoint_contexts + + # Get mm shapes to force recompute based on FQN matching + mm_recompute_shapes = _get_mm_recompute_shapes(module, ac_config, base_fqn) + + # Get the policy from default_activation_checkpoint_policy + # This returns a policy function directly (via functools.partial) + base_policy = default_activation_checkpoint_policy() + + def _create_wrapped_policy(): + """Create a policy that wraps the base policy with additional logic. + + This wrapper handles: + 1. Force recompute for specific mm shapes (per_op_sac_force_recompute_mm_shapes_by_fqns) + 2. 
CUDA->CPU tensor copies that must be saved + """ + + def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy: + # Special case: CUDA->CPU tensor copies must be saved + # This prevents issues with CPU offloading during recomputation if ( func == torch.ops.aten._to_copy.default + and len(args) > 0 and "cuda" in str(args[0].device) and "device" in kwargs and str(kwargs["device"]) == "cpu" ): return CheckpointPolicy.MUST_SAVE - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: + # Special case: Force recompute for specific mm shapes + # This is used for things like router gates in MoE models + if func == torch.ops.aten.mm.default and len(args) > 1: if args[1].shape in mm_recompute_shapes: return CheckpointPolicy.PREFER_RECOMPUTE - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in op_sac_save_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return ( - CheckpointPolicy.MUST_SAVE - if to_save - else CheckpointPolicy.PREFER_RECOMPUTE - ) - return _custom_policy + # Delegate to the base policy for all other decisions + return base_policy(ctx, func, *args, **kwargs) + + return wrapped_policy def selective_checkpointing_context_fn(): - meta = defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + """Context function that creates checkpoint contexts with the wrapped policy.""" + return create_selective_checkpoint_contexts(_create_wrapped_policy()) return ptd_checkpoint_wrapper( module, @@ -162,7 +305,6 @@ def _apply_ac_to_transformer_block( *, base_fqn: str | None = None, model_compile_enabled: bool = False, - op_sac_save_list: set[torch._ops.OpOverload] | None = None, ) -> nn.Module: valid_ac_modes = ("full", "selective") if ac_config.mode not in valid_ac_modes: @@ -183,10 +325,7 @@ def _apply_ac_to_transformer_block( ) if use_op_sac: - op_sac_save_list = 
op_sac_save_list or set() - return _apply_op_sac( - module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list - ) + return _apply_op_sac(module, ac_config, base_fqn=base_fqn) return _apply_layer_sac(module, ac_config) @@ -196,7 +335,6 @@ def apply_ac( ac_config: ACConfig, *, model_compile_enabled: bool = False, - op_sac_save_list: set[torch._ops.OpOverload] | None = None, base_folder: str = "", ) -> None: """Apply activation checkpointing to the model. @@ -205,8 +343,8 @@ def apply_ac( model (nn.Module): The model to apply activation checkpointing to. ac_config (ACConfig): The activation checkpointing config. model_compile_enabled (bool): Whether torch.compile is enabled for the model. - op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead - of recomputing. + base_folder (str): The base folder for saving memory budget pareto visualization. + Returns: None """ @@ -242,7 +380,6 @@ def apply_ac( ac_config, base_fqn=f"layers.{layer_id}", model_compile_enabled=model_compile_enabled, - op_sac_save_list=op_sac_save_list, ) layers.register_module(layer_id, transformer_block) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 19d9f946d2..216fe311e3 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import torch import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import Replicate, Shard @@ -29,24 +28,6 @@ ) from torchtitan.tools.logging import logger -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - torch.ops._c10d_functional.all_to_all_single.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch._higher_order_ops.inductor_compiled_code, -} - # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( @@ -114,9 +95,6 @@ def parallelize_deepseekv3( # Import deepep module to register custom ops before accessing them import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep - - _op_sac_save_list.add(torch.ops.deepep.dispatch.default) - _op_sac_save_list.add(torch.ops.deepep.combine.default) else: use_deepep = False @@ -150,8 +128,6 @@ def parallelize_deepseekv3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/gpt_oss/infra/parallelize.py b/torchtitan/models/gpt_oss/infra/parallelize.py index 338092fb7a..08ac4167b0 100644 --- a/torchtitan/models/gpt_oss/infra/parallelize.py +++ b/torchtitan/models/gpt_oss/infra/parallelize.py @@ -38,25 +38,6 @@ from .expert_parallel 
import GptossExpertTensorParallel, GptossTensorParallel -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - torch.ops._c10d_functional.all_to_all_single.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch._higher_order_ops.inductor_compiled_code, -} - - # Adapted from llama4/infra/parallelize.py def parallelize_gptoss( model: nn.Module, @@ -117,8 +98,6 @@ def parallelize_gptoss( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, ) dp_mesh: DeviceMesh | None = None diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 87f4f91ca9..79370909b2 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -32,25 +32,6 @@ from torchtitan.tools.logging import logger -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - 
torch.ops._c10d_functional.reduce_scatter_tensor.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch.ops.torch_attn._varlen_attn.default, - torch._higher_order_ops.inductor_compiled_code, -} - - def parallelize_llama( model: nn.Module, parallel_dims: ParallelDims, @@ -113,8 +94,6 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 5e23b81ab1..56118770e3 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -47,24 +47,6 @@ from torchtitan.models.moe import moe as moe_module from torchtitan.tools.logging import logger -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - torch.ops._c10d_functional.all_to_all_single.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. 
- torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch._higher_order_ops.inductor_compiled_code, -} - def parallelize_llama( model: nn.Module, @@ -128,9 +110,6 @@ def parallelize_llama( # Import deepep module to register custom ops before accessing them import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep - - _op_sac_save_list.add(torch.ops.deepep.dispatch.default) - _op_sac_save_list.add(torch.ops.deepep.combine.default) else: use_deepep = False @@ -160,17 +139,10 @@ def parallelize_llama( job_config.compile.enable and "model" in job_config.compile.components ) if job_config.activation_checkpoint.mode != "none": - if job_config.activation_checkpoint.selective_ac_option == "op": - logger.info( - f"SAC save list contains {len(_op_sac_save_list)} ops: " - f"{sorted([str(op) for op in _op_sac_save_list])}" - ) apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 4837dbc68e..fd81f8d0db 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -35,25 +35,6 @@ from torchtitan.tools.logging import logger -# for selective op activation checkpointing -_op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten._scaled_dot_product_cudnn_attention.default, - torch.ops.aten._scaled_dot_product_attention_math.default, - torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to 
compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, - torch.ops.torch_attn._varlen_attn.default, - torch._higher_order_ops.inductor_compiled_code, -} - - def parallelize_qwen3( model: nn.Module, parallel_dims: ParallelDims, @@ -130,8 +111,6 @@ def parallelize_qwen3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - # pyrefly: ignore [bad-argument-type] - op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) From f53d80d3a692fb02103d6b664ee6592dd0f1d933 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 9 Feb 2026 13:31:10 -0800 Subject: [PATCH 2/5] update --- .../distributed/activation_checkpoint.py | 76 ++++++------------- 1 file changed, 24 insertions(+), 52 deletions(-) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 0f2f91c343..401d082e55 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -4,16 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Activation checkpointing module. - -This module provides utilities to apply activation checkpointing to the model. - -Key design patterns: -1. Policy factory functions are decorated with @lru_cache() to avoid dynamo recompilations -2. Ops are categorized into compute_intensive_ops and communication_intensive_ops -3. Policy functions have signature: (ctx, op, *args, **kwargs) -> CheckpointPolicy -4. Policy factories have a `cache_hash` attribute for dynamo cache management -""" +# This file provides the util functions to apply activation checkpointing to the model. +# Technically, this is not a part of distributed, but distributed module is the best place to put it. 
import os from functools import lru_cache, partial @@ -167,21 +159,29 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: return module -def _get_mm_recompute_shapes( +def _apply_op_sac( module: nn.Module, ac_config: ACConfig, + *, base_fqn: str | None = None, -) -> set[tuple[int, int]]: - """Extract mm shapes that should be force-recomputed based on FQN matching. +) -> nn.Module: + """Apply selective activation checkpointing to the module. + + This function uses the policy-based approach. The policy is obtained from + `default_activation_checkpoint_policy()` which returns a policy function that decides which + ops to save vs recompute. Args: - module: The module to analyze. - ac_config: The activation checkpointing config. - base_fqn: The base FQN of the module. + module (nn.Module): The module to apply selective activation checkpointing to. + ac_config (ACConfig): The activation checkpointing config. + base_fqn (str, optional): The base fqn of the module. Defaults to None. Returns: - Set of (in_features, out_features) shapes to force recompute. + nn.Module: The module with selective activation checkpointing applied. """ + from torch.utils.checkpoint import create_selective_checkpoint_contexts + + # Get mm shapes to force recompute based on FQN matching mm_recompute_shapes: set[tuple[int, int]] = set() if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: for module_fqn, submod in module.named_modules(): @@ -203,36 +203,8 @@ def _get_mm_recompute_shapes( logger.debug( f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" ) - return mm_recompute_shapes - - -def _apply_op_sac( - module: nn.Module, - ac_config: ACConfig, - *, - base_fqn: str | None = None, -) -> nn.Module: - """Apply selective activation checkpointing to the module. - - This function uses the policy-based approach. 
The policy is obtained from - `default_activation_checkpoint_policy()` which returns a policy function that decides which - ops to save vs recompute. - - Args: - module (nn.Module): The module to apply selective activation checkpointing to. - ac_config (ACConfig): The activation checkpointing config. - base_fqn (str, optional): The base fqn of the module. Defaults to None. - - Returns: - nn.Module: The module with selective activation checkpointing applied. - """ - from torch.utils.checkpoint import create_selective_checkpoint_contexts - - # Get mm shapes to force recompute based on FQN matching - mm_recompute_shapes = _get_mm_recompute_shapes(module, ac_config, base_fqn) # Get the policy from default_activation_checkpoint_policy - # This returns a policy function directly (via functools.partial) base_policy = default_activation_checkpoint_policy() def _create_wrapped_policy(): @@ -245,7 +217,6 @@ def _create_wrapped_policy(): def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy: # Special case: CUDA->CPU tensor copies must be saved - # This prevents issues with CPU offloading during recomputation if ( func == torch.ops.aten._to_copy.default and len(args) > 0 @@ -256,18 +227,19 @@ def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy: return CheckpointPolicy.MUST_SAVE # Special case: Force recompute for specific mm shapes - # This is used for things like router gates in MoE models - if func == torch.ops.aten.mm.default and len(args) > 1: - if args[1].shape in mm_recompute_shapes: - return CheckpointPolicy.PREFER_RECOMPUTE + if ( + func == torch.ops.aten.mm.default + and len(args) > 1 + and args[1].shape in mm_recompute_shapes + ): + return CheckpointPolicy.PREFER_RECOMPUTE - # Delegate to the base policy for all other decisions + # Delegate to the base policy return base_policy(ctx, func, *args, **kwargs) return wrapped_policy def selective_checkpointing_context_fn(): - """Context function that creates checkpoint contexts with the wrapped 
policy.""" return create_selective_checkpoint_contexts(_create_wrapped_policy()) return ptd_checkpoint_wrapper( From 448ce31a38ccb8a37f81b48b6cc15353db164ecc Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 9 Feb 2026 13:35:38 -0800 Subject: [PATCH 3/5] update --- torchtitan/distributed/activation_checkpoint.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 401d082e55..6c1e2bbc90 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -181,8 +181,7 @@ def _apply_op_sac( """ from torch.utils.checkpoint import create_selective_checkpoint_contexts - # Get mm shapes to force recompute based on FQN matching - mm_recompute_shapes: set[tuple[int, int]] = set() + mm_recompute_shapes = set() if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: for module_fqn, submod in module.named_modules(): fqn = module_fqn @@ -208,25 +207,15 @@ def _apply_op_sac( base_policy = default_activation_checkpoint_policy() def _create_wrapped_policy(): - """Create a policy that wraps the base policy with additional logic. - - This wrapper handles: - 1. Force recompute for specific mm shapes (per_op_sac_force_recompute_mm_shapes_by_fqns) - 2. 
CUDA->CPU tensor copies that must be saved - """ - def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy: - # Special case: CUDA->CPU tensor copies must be saved if ( func == torch.ops.aten._to_copy.default - and len(args) > 0 and "cuda" in str(args[0].device) and "device" in kwargs and str(kwargs["device"]) == "cpu" ): return CheckpointPolicy.MUST_SAVE - # Special case: Force recompute for specific mm shapes if ( func == torch.ops.aten.mm.default and len(args) > 1 @@ -234,7 +223,6 @@ def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy: ): return CheckpointPolicy.PREFER_RECOMPUTE - # Delegate to the base policy return base_policy(ctx, func, *args, **kwargs) return wrapped_policy @@ -315,7 +303,6 @@ def apply_ac( model (nn.Module): The model to apply activation checkpointing to. ac_config (ACConfig): The activation checkpointing config. model_compile_enabled (bool): Whether torch.compile is enabled for the model. - base_folder (str): The base folder for saving memory budget pareto visualization. Returns: None From d116992e680506c3b0c407d302dbfe8a78619c63 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 9 Feb 2026 13:51:19 -0800 Subject: [PATCH 4/5] update --- .../distributed/activation_checkpoint.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 6c1e2bbc90..a073f6334f 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -24,7 +24,6 @@ from torchtitan.tools.logging import logger -# Type alias for policy functions _PolicyFn = Callable[..., CheckpointPolicy] _layer_sac_count = 0 @@ -38,32 +37,15 @@ def _sac_policy_fn( communication_intensive_ops: dict, **kwargs, ) -> CheckpointPolicy: - # Save compute-intensive ops (mm, attention, conv, flex_attention, etc.) 
- if op in compute_intensive_ops: + if op in (compute_intensive_ops | communication_intensive_ops): return CheckpointPolicy.MUST_SAVE - # Save communication-intensive ops (reduce_scatter, all_to_all, etc.) - if op in communication_intensive_ops: - return CheckpointPolicy.MUST_SAVE - - # Default: recompute everything else return CheckpointPolicy.PREFER_RECOMPUTE @lru_cache() def default_activation_checkpoint_policy() -> _PolicyFn: - """Returns a checkpointing policy function that saves results of compute-intensive ops. - - The policy saves compute-intensive and communication-intensive ops while - recomputing everything else. Uses dicts (not sets) to workaround dynamo - guarding issues (https://github.com/pytorch/pytorch/issues/168163). - - Returns: - A policy function that can be used with checkpoint contexts. - - Note: - This function is cached with @lru_cache() to avoid dynamo recompilations. - The cache_hash attribute is used by dynamo for cache management. + """Returns a checkpointing policy function that saves results of compute and communication ops. 
""" aten_op_types = get_default_op_list() compute_intensive_ops = { @@ -203,7 +185,6 @@ def _apply_op_sac( f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" ) - # Get the policy from default_activation_checkpoint_policy base_policy = default_activation_checkpoint_policy() def _create_wrapped_policy(): From 1aad3a2c0f6299b0ab950831b66801d3d6819c1c Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 9 Feb 2026 14:52:02 -0800 Subject: [PATCH 5/5] update --- .../distributed/activation_checkpoint.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index a073f6334f..b56ee6403d 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -74,8 +74,6 @@ def default_activation_checkpoint_policy() -> _PolicyFn: torch._higher_order_ops.inductor_compiled_code ] = CheckpointPolicy.MUST_SAVE - compute_intensive_ops[torch.ops.aten.max.default] = CheckpointPolicy.MUST_SAVE - if hasattr(torch.ops, "torch_attn") and hasattr( torch.ops.torch_attn, "_varlen_attn" ): @@ -88,23 +86,6 @@ def default_activation_checkpoint_policy() -> _PolicyFn: torch.ops._c10d_functional.all_to_all_single.default: CheckpointPolicy.MUST_SAVE, } - # DeepEP ops for MoE expert parallelism - # Try to import deepep module to register custom ops, then check if they exist - try: - import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep - - if hasattr(torch.ops, "deepep"): - if hasattr(torch.ops.deepep, "dispatch"): - communication_intensive_ops[ - torch.ops.deepep.dispatch.default - ] = CheckpointPolicy.MUST_SAVE - if hasattr(torch.ops.deepep, "combine"): - communication_intensive_ops[ - torch.ops.deepep.combine.default - ] = CheckpointPolicy.MUST_SAVE - except ImportError: - pass # DeepEP not available - policy_fn = partial( _sac_policy_fn, compute_intensive_ops=compute_intensive_ops,