132 changes: 95 additions & 37 deletions torchtitan/distributed/activation_checkpoint.py
@@ -8,22 +8,95 @@
# Technically, this is not part of distributed, but the distributed module is the best place to put it.

import os
from collections import defaultdict
from functools import lru_cache, partial
from typing import Callable

import torch
import torch._functorch.config
import torch.nn as nn
from torch._functorch.partitioners import get_default_op_list
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.utils.checkpoint import CheckpointPolicy

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.tools.logging import logger


_PolicyFn = Callable[..., CheckpointPolicy]

_layer_sac_count = 0


def _sac_policy_fn(
ctx,
op,
*args,
compute_intensive_ops: dict,
communication_intensive_ops: dict,
**kwargs,
) -> CheckpointPolicy:
if op in (compute_intensive_ops | communication_intensive_ops):
return CheckpointPolicy.MUST_SAVE

return CheckpointPolicy.PREFER_RECOMPUTE


@lru_cache()
def default_activation_checkpoint_policy() -> _PolicyFn:
"""Returns a checkpointing policy function that saves results of compute and communicate ops.
"""
aten_op_types = get_default_op_list()
compute_intensive_ops = {
op.default: CheckpointPolicy.MUST_SAVE # pyrefly: ignore [missing-attribute]
for op in aten_op_types.compute_intensive_ops
}

compute_intensive_ops[
torch.ops.aten._scaled_dot_product_cudnn_attention.default
] = CheckpointPolicy.MUST_SAVE
compute_intensive_ops[
torch.ops.aten._scaled_dot_product_attention_math.default
] = CheckpointPolicy.MUST_SAVE
compute_intensive_ops[
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default
] = CheckpointPolicy.MUST_SAVE

compute_intensive_ops[
torch.ops.higher_order.flex_attention
] = CheckpointPolicy.MUST_SAVE
compute_intensive_ops[
torch._higher_order_ops.flex_attention
] = CheckpointPolicy.MUST_SAVE
if hasattr(torch._higher_order_ops, "inductor_compiled_code"):
compute_intensive_ops[
torch._higher_order_ops.inductor_compiled_code
] = CheckpointPolicy.MUST_SAVE

if hasattr(torch.ops, "torch_attn") and hasattr(
torch.ops.torch_attn, "_varlen_attn"
):
compute_intensive_ops[
torch.ops.torch_attn._varlen_attn.default
] = CheckpointPolicy.MUST_SAVE

communication_intensive_ops = {
torch.ops._c10d_functional.reduce_scatter_tensor.default: CheckpointPolicy.MUST_SAVE,
torch.ops._c10d_functional.all_to_all_single.default: CheckpointPolicy.MUST_SAVE,
}

policy_fn = partial(
_sac_policy_fn,
compute_intensive_ops=compute_intensive_ops,
communication_intensive_ops=communication_intensive_ops,
)
# pyrefly: ignore [missing-attribute]
policy_fn.cache_hash = "default_activation_checkpoint_policy"
# pyrefly: ignore [bad-return]
return policy_fn


def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
"""Apply layer selective activation checkpointing to the module.

@@ -54,24 +127,22 @@ def _apply_op_sac(
ac_config: ACConfig,
*,
base_fqn: str | None = None,
op_sac_save_list: set[torch._ops.OpOverload],
) -> nn.Module:
"""Apply selective activation checkpointing to the module.

This function uses the policy-based approach. The policy is obtained from
`default_activation_checkpoint_policy()`, which returns a policy function that decides
which ops to save and which to recompute.

Args:
module (nn.Module): The module to apply selective activation checkpointing to.
ac_config (ACConfig): The activation checkpointing config.
base_fqn (str, optional): The base fqn of the module. Defaults to None.
op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
of recomputing.

Returns:
nn.Module: The module with selective activation checkpointing applied.
"""
from torch.utils.checkpoint import (
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
from torch.utils.checkpoint import create_selective_checkpoint_contexts

mm_recompute_shapes = set()
if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
@@ -95,8 +166,10 @@ def _apply_op_sac(
f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
)

def _get_custom_policy(meta):
def _custom_policy(ctx, func, *args, **kwargs):
base_policy = default_activation_checkpoint_policy()

def _create_wrapped_policy():
def wrapped_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy:
if (
func == torch.ops.aten._to_copy.default
and "cuda" in str(args[0].device)
@@ -105,27 +178,19 @@ def _custom_policy(ctx, func, *args, **kwargs):
):
return CheckpointPolicy.MUST_SAVE

mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
if args[1].shape in mm_recompute_shapes:
return CheckpointPolicy.PREFER_RECOMPUTE
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
to_save = func in op_sac_save_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
)
return (
CheckpointPolicy.MUST_SAVE
if to_save
else CheckpointPolicy.PREFER_RECOMPUTE
)
if (
func == torch.ops.aten.mm.default
and len(args) > 1
and args[1].shape in mm_recompute_shapes
):
return CheckpointPolicy.PREFER_RECOMPUTE

return base_policy(ctx, func, *args, **kwargs)

return _custom_policy
return wrapped_policy

def selective_checkpointing_context_fn():
meta = defaultdict(int)
return create_selective_checkpoint_contexts(_get_custom_policy(meta))
return create_selective_checkpoint_contexts(_create_wrapped_policy())

return ptd_checkpoint_wrapper(
module,
@@ -162,7 +227,6 @@ def _apply_ac_to_transformer_block(
*,
base_fqn: str | None = None,
model_compile_enabled: bool = False,
op_sac_save_list: set[torch._ops.OpOverload] | None = None,
) -> nn.Module:
valid_ac_modes = ("full", "selective")
if ac_config.mode not in valid_ac_modes:
@@ -183,10 +247,7 @@
)

if use_op_sac:
op_sac_save_list = op_sac_save_list or set()
return _apply_op_sac(
module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
)
return _apply_op_sac(module, ac_config, base_fqn=base_fqn)

return _apply_layer_sac(module, ac_config)

@@ -196,7 +257,6 @@ def apply_ac(
ac_config: ACConfig,
*,
model_compile_enabled: bool = False,
op_sac_save_list: set[torch._ops.OpOverload] | None = None,
base_folder: str = "",
) -> None:
"""Apply activation checkpointing to the model.
@@ -205,8 +265,7 @@
model (nn.Module): The model to apply activation checkpointing to.
ac_config (ACConfig): The activation checkpointing config.
model_compile_enabled (bool): Whether torch.compile is enabled for the model.
op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
of recomputing.

Returns:
None
"""
@@ -242,7 +301,6 @@ def apply_ac(
ac_config,
base_fqn=f"layers.{layer_id}",
model_compile_enabled=model_compile_enabled,
op_sac_save_list=op_sac_save_list,
)
layers.register_module(layer_id, transformer_block)

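For context, a minimal sketch of how the new shared policy plugs into PyTorch's selective-checkpointing API outside of torchtitan's ptd_checkpoint_wrapper. The toy forward function and tensor shapes are illustrative only, and the import path assumes the module layout shown in this diff.

import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

from torchtitan.distributed.activation_checkpoint import (
    default_activation_checkpoint_policy,
)


def context_fn():
    # Build fresh SAC contexts from the cached shared policy on every forward.
    return create_selective_checkpoint_contexts(default_activation_checkpoint_policy())


def block(x, w):
    # Stand-in for a transformer block body: one matmul plus a pointwise op.
    return torch.mm(x, w).relu()


x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(32, 8, requires_grad=True)
out = checkpoint(block, x, w, use_reentrant=False, context_fn=context_fn)
out.sum().backward()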
24 changes: 0 additions & 24 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard
@@ -29,24 +28,6 @@
)
from torchtitan.tools.logging import logger

# for selective op activation checkpointing
_op_sac_save_list = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
torch.ops.aten._scaled_dot_product_attention_math.default,
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
torch.ops._c10d_functional.all_to_all_single.default,
# for low precision training, it's useful to always save
# the result of max, since the absolute maximum is
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch._higher_order_ops.inductor_compiled_code,
}


# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
Expand Down Expand Up @@ -114,9 +95,6 @@ def parallelize_deepseekv3(

# Import deepep module to register custom ops before accessing them
import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep

_op_sac_save_list.add(torch.ops.deepep.dispatch.default)
_op_sac_save_list.add(torch.ops.deepep.combine.default)
else:
use_deepep = False

@@ -150,8 +128,6 @@ def parallelize_deepseekv3(
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
# pyrefly: ignore [bad-argument-type]
op_sac_save_list=_op_sac_save_list,
base_folder=job_config.job.dump_folder,
)

21 changes: 0 additions & 21 deletions torchtitan/models/gpt_oss/infra/parallelize.py
@@ -38,25 +38,6 @@
from .expert_parallel import GptossExpertTensorParallel, GptossTensorParallel


# for selective op activation checkpointing
_op_sac_save_list = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
torch.ops.aten._scaled_dot_product_attention_math.default,
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
torch.ops._c10d_functional.all_to_all_single.default,
# for low precision training, it's useful to always save
# the result of max, since the absolute maximum is
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch._higher_order_ops.inductor_compiled_code,
}


# Adapted from llama4/infra/parallelize.py
def parallelize_gptoss(
model: nn.Module,
@@ -117,8 +98,6 @@ def parallelize_gptoss(
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
# pyrefly: ignore [bad-argument-type]
op_sac_save_list=_op_sac_save_list,
)

dp_mesh: DeviceMesh | None = None
21 changes: 0 additions & 21 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -32,25 +32,6 @@
from torchtitan.tools.logging import logger


# for selective op activation checkpointing
_op_sac_save_list = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
torch.ops.aten._scaled_dot_product_attention_math.default,
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
# for low precision training, it's useful to always save
# the result of max, since the absolute maximum is
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch.ops.torch_attn._varlen_attn.default,
torch._higher_order_ops.inductor_compiled_code,
}


def parallelize_llama(
model: nn.Module,
parallel_dims: ParallelDims,
@@ -113,8 +94,6 @@ def parallelize_llama(
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
# pyrefly: ignore [bad-argument-type]
op_sac_save_list=_op_sac_save_list,
base_folder=job_config.job.dump_folder,
)

28 changes: 0 additions & 28 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -47,24 +47,6 @@
from torchtitan.models.moe import moe as moe_module
from torchtitan.tools.logging import logger

# for selective op activation checkpointing
_op_sac_save_list = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
torch.ops.aten._scaled_dot_product_attention_math.default,
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
torch.ops._c10d_functional.all_to_all_single.default,
# for low precision training, it's useful to always save
# the result of max, since the absolute maximum is
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch._higher_order_ops.inductor_compiled_code,
}


def parallelize_llama(
model: nn.Module,
@@ -128,9 +110,6 @@

# Import deepep module to register custom ops before accessing them
import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep

_op_sac_save_list.add(torch.ops.deepep.dispatch.default)
_op_sac_save_list.add(torch.ops.deepep.combine.default)
else:
use_deepep = False

@@ -160,17 +139,10 @@
job_config.compile.enable and "model" in job_config.compile.components
)
if job_config.activation_checkpoint.mode != "none":
if job_config.activation_checkpoint.selective_ac_option == "op":
logger.info(
f"SAC save list contains {len(_op_sac_save_list)} ops: "
f"{sorted([str(op) for op in _op_sac_save_list])}"
)
apply_ac(
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
# pyrefly: ignore [bad-argument-type]
op_sac_save_list=_op_sac_save_list,
base_folder=job_config.job.dump_folder,
)

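With the hand-maintained per-model save lists removed, a quick way to sanity-check what the centralized policy decides is to query it directly with an op overload. A hedged sketch: it relies on the detail that the policy ignores its ctx argument in the implementation above (so passing None works), and on the assumption that relu is not in the partitioner's compute-intensive set; the collective case follows directly from the communication_intensive_ops dict built in default_activation_checkpoint_policy().

import torch
from torch.utils.checkpoint import CheckpointPolicy

from torchtitan.distributed.activation_checkpoint import (
    default_activation_checkpoint_policy,
)

policy = default_activation_checkpoint_policy()

# Collectives listed in communication_intensive_ops are pinned to MUST_SAVE.
assert (
    policy(None, torch.ops._c10d_functional.all_to_all_single.default)
    == CheckpointPolicy.MUST_SAVE
)

# Ops outside both dicts fall through to PREFER_RECOMPUTE.
assert (
    policy(None, torch.ops.aten.relu.default)
    == CheckpointPolicy.PREFER_RECOMPUTE
)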