2 changes: 2 additions & 0 deletions torchtitan/experiments/__init__.py
@@ -11,6 +11,8 @@
"vlm",
"compiler_toolkit.deepseek_v3",
"compiler_toolkit.llama3",
"graph_based_training.llama3",
"graph_based_training.deepseek_v3",
"transformers_modeling_backend",
"autoparallel.llama3",
"autoparallel.deepseek_v3",
90 changes: 90 additions & 0 deletions torchtitan/experiments/graph_based_training/README.md
@@ -0,0 +1,90 @@
# Graph-Based Training

A unified experiment that merges [SimpleFSDP](../simple_fsdp/) and [Compiler Toolkit](../compiler_toolkit/) into a single framework with two compilation modes:

- **JIT mode** (`--compile.mode jit`): Uses `torch.compile` with a custom backend. Graph passes are registered to the backend and applied during just-in-time compilation.
- **AOT mode** (`--compile.mode aot`): Captures the joint forward-backward graph ahead of time and applies optimization passes directly to the FX graph modules before execution.

Both modes share the same DTensor-based SimpleFSDP model authoring and the same unified pass registry.
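
Both modes are dispatched through a single entry point, `apply_compilation` in `compilation.py`. A minimal usage sketch (the model, config, and mesh objects are assumed to come from the usual torchtitan trainer setup):

```python
from torchtitan.experiments.graph_based_training.compilation import apply_compilation

# `model` is already parallelized (TP, AC, SimpleFSDP-style DTensor wrapping);
# `job_config` and `parallel_dims` come from the trainer and are assumed here.
model = apply_compilation(
    model,
    job_config,
    parallel_dims,
    mode=job_config.compile.mode,        # "jit" or "aot"
    transformer_block_buckets=None,      # optional model-specific bucket plan
)
```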

## Configuration

Compilation is configured via `--job.custom_config_module=torchtitan.experiments.graph_based_training.job_config`, which adds the following options (a config sketch follows the list):

- `--compile.mode`: `"jit"` or `"aot"` (omit to disable compilation)
- `--compile.passes`: Comma-separated list of pass names
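
A rough sketch of the shape these options take under the `compile` section (field names are inferred from the flags above; the actual definition lives in `job_config.py`):

```python
from dataclasses import dataclass, field


@dataclass
class Compile:
    # Inferred from --compile.mode / --compile.passes; treat this as an
    # illustrative assumption rather than the exact schema in job_config.py.
    mode: str | None = None                           # "jit", "aot", or None to disable
    passes: list[str] = field(default_factory=list)   # e.g., ["auto_bucketing"]
```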

### Available Passes

| Pass | Modes | Description |
|------|-------|-------------|
| `auto_bucketing` | jit, aot | Automatic comm/compute overlap bucketing |
| `transformer_block_bucketing` | jit, aot | Manual per-transformer-block bucketing |
| `regional_inductor` | aot | Regional Inductor compilation |
| `cudagraph` | aot | CUDA graph capture and replay (must be last) |
| `full_inductor_compilation` | aot | Full Inductor code generation (requires `inductor_decomposition`) |
| `inductor_decomposition` | aot | Inductor decompositions on joint graph |

Constraints (a standalone validation sketch follows this list):
- `auto_bucketing` and `transformer_block_bucketing` are mutually exclusive
- `full_inductor_compilation` requires `inductor_decomposition`
- `cudagraph` must be the last pass in the list
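
A hypothetical standalone check mirroring these constraints (the real validation lives in `passes.validate_and_get_passes`; the helper name below is illustrative):

```python
def check_pass_list(pass_names: list[str]) -> None:
    """Raise ValueError if the requested pass list violates the constraints above."""
    if "auto_bucketing" in pass_names and "transformer_block_bucketing" in pass_names:
        raise ValueError("auto_bucketing and transformer_block_bucketing are mutually exclusive")
    if "full_inductor_compilation" in pass_names and "inductor_decomposition" not in pass_names:
        raise ValueError("full_inductor_compilation requires inductor_decomposition")
    if "cudagraph" in pass_names and pass_names[-1] != "cudagraph":
        raise ValueError("cudagraph must be the last pass")


check_pass_list(["transformer_block_bucketing", "regional_inductor", "cudagraph"])  # passes
```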

## Llama3

**JIT mode (no passes)**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode jit
```

**JIT mode + auto-bucketing**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode jit --compile.passes auto_bucketing
```

**JIT mode + transformer-block-bucketing**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode jit --compile.passes transformer_block_bucketing
```

**AOT mode (no passes)**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot
```

**AOT mode + auto-bucketing**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes auto_bucketing
```

**AOT mode + transformer-block-bucketing + regional-inductor**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes transformer_block_bucketing,regional_inductor
```

**AOT mode + transformer-block-bucketing + regional-inductor + cudagraph**
```shell
NCCL_GRAPH_REGISTER=0 NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes transformer_block_bucketing,regional_inductor,cudagraph
```

**AOT mode + full Inductor compilation**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes inductor_decomposition,full_inductor_compilation
```

## DeepSeek v3

**JIT mode (SimpleFSDP + TP + EP)**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode jit
```

**AOT mode (SimpleFSDP + TP + EP)**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot
```

**AOT mode (SimpleFSDP + TP + EP + auto-bucketing)**
```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes auto_bucketing
```
5 changes: 5 additions & 0 deletions torchtitan/experiments/graph_based_training/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
52 changes: 52 additions & 0 deletions torchtitan/experiments/graph_based_training/common_utils.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch
from torch.distributed.tensor import DTensor, Replicate
from torch.utils._pytree import register_pytree_node, tree_map


def parallelize_inputs(parallel_dims, args, kwargs):
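    """Wrap plain tensor positional args as replicated DTensors on the TP mesh.

    Keyword args are passed through unchanged (see the BlockMask TODO below).
    """
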
def to_dtensor(tensor):
if isinstance(tensor, torch.Tensor):
return DTensor.from_local(
tensor, parallel_dims.get_mesh("tp"), [Replicate()]
)
return tensor

dt_args = tree_map(to_dtensor, args)

# TODO: When using flex_attention, BlockMask would show up in kwargs,
# and it's unclear how to convert it to DTensor. If I use to_dtensor,
# it would fail with Dynamo Error: P2011360347
# dt_kwargs = tree_map(to_dtensor, kwargs)

dt_kwargs = kwargs

return dt_args, dt_kwargs


def register_blockmask_pytree_node():
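    """Register flex_attention's BlockMask as a pytree node, if not already registered."""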
from torch.nn.attention.flex_attention import BlockMask

if BlockMask not in torch.utils._pytree.SUPPORTED_NODES:
register_pytree_node(
BlockMask,
BlockMask._flatten,
BlockMask._unflatten,
flatten_with_keys_fn=BlockMask._flatten_with_keys,
serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
)


def end_with_pass(passes: list[Callable], names: list[str]) -> bool:
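    """Return True if the last entry in `passes` is named one of `names`."""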
    if not passes:
        return False
    return getattr(passes[-1], "__name__", None) in names
168 changes: 168 additions & 0 deletions torchtitan/experiments/graph_based_training/compilation.py
@@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Unified compilation dispatcher for graph-based training.

Dispatches to JIT (torch.compile) or AOT (joint graph capture) compilation
based on the configured mode.
"""

import functools

import torch

from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger

from .common_utils import parallelize_inputs
from .graph_utils import CompiledModule, joint_graph_builder, make_compiler_with_passes
from .jit_backend import get_jit_compile_backend
from .passes import (
fsdp_reshard_after_fwd_pass,
inductor_decomposition_pass,
validate_and_get_passes,
validate_flex_attn_annotation_pass,
)


def _get_reshard_policy(job_config: JobConfig, parallel_dims: ParallelDims) -> bool:
"""Determine fsdp_reshard_after_forward policy."""
match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
return True
case "never":
return False
case "default":
# For PP, by default do not reshard after forward to avoid
# per-microbatch all-gathers, which can be expensive and non-overlapped
return not parallel_dims.pp_enabled
case _:
raise ValueError(
f"Invalid fsdp_reshard_after_forward_policy: "
f"{job_config.parallelism.fsdp_reshard_after_forward}."
)


def apply_compilation(
model: torch.nn.Module,
job_config: JobConfig,
parallel_dims: ParallelDims,
mode: str,
transformer_block_buckets: list | None = None,
) -> torch.nn.Module:
"""
Unified entry point for both JIT and AOT compilation.

Args:
model: The parallelized model (after TP, AC, DP)
job_config: Job configuration
parallel_dims: Parallel dimensions
mode: "jit" or "aot"
transformer_block_buckets: Model-specific bucket plans for
transformer_block_bucketing pass

Returns:
The compiled model
"""
pass_names = getattr(job_config.compile, "passes", [])
fsdp_reshard_after_forward = _get_reshard_policy(job_config, parallel_dims)

joint_passes, fwd_bwd_passes = validate_and_get_passes(
pass_names, mode, transformer_block_buckets=transformer_block_buckets
)

if mode == "jit":
return _apply_jit(
model,
job_config,
fsdp_reshard_after_forward,
pass_names,
transformer_block_buckets,
)
elif mode == "aot":
return _apply_aot(
model,
parallel_dims,
job_config,
fsdp_reshard_after_forward,
joint_passes,
fwd_bwd_passes,
)
else:
raise ValueError(f"Unknown compilation mode: {mode}")


def _apply_jit(
model: torch.nn.Module,
job_config: JobConfig,
fsdp_reshard_after_forward: bool,
pass_names: list[str],
transformer_block_buckets: list | None,
) -> torch.nn.Module:
"""Apply JIT compilation via torch.compile with custom backend."""
torch._inductor.config.reorder_for_peak_memory = False

backend = get_jit_compile_backend(
job_config.compile,
fsdp_reshard_after_forward,
pass_names,
transformer_block_buckets,
)
model = torch.compile(model, backend=backend, fullgraph=True)
logger.info("Applied JIT compilation (torch.compile) to the model")
return model


def _apply_aot(
model: torch.nn.Module,
parallel_dims: ParallelDims,
job_config: JobConfig,
fsdp_reshard_after_forward: bool,
joint_passes: list,
fwd_bwd_passes: list,
) -> CompiledModule:
"""Apply AOT compilation via joint graph capture."""
# Build joint custom passes list:
# 1. validate_flex_attn_annotation (always applied)
# 2. user-configured joint passes (excluding inductor_decomposition,
# which is handled specially by joint_graph_builder since it needs
# runtime context like model, joint_with_descriptors, etc.)
# 3. fsdp_reshard_after_fwd (always applied)
joint_custom_passes = [validate_flex_attn_annotation_pass]
joint_custom_passes.extend(
p for p in joint_passes if p is not inductor_decomposition_pass
)
joint_custom_passes.append(
functools.partial(
fsdp_reshard_after_fwd_pass,
reshard_after_forward=fsdp_reshard_after_forward,
)
)

# Build forward/backward compilers with fwd/bwd passes
fw_compiler, bw_compiler = make_compiler_with_passes(
fwd_bwd_passes, dump_folder=job_config.job.dump_folder
)

# Create joint graph builder with configured passes
aot_joint_graph_builder = functools.partial(
joint_graph_builder,
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
job_config=job_config,
)

# TODO: CompiledModule should take sample input as well, so that we can
# compile ahead of time.
model = CompiledModule(
model, parallel_dims, aot_joint_graph_builder, parallelize_inputs
)
logger.info("Applied AOT compilation (joint graph capture) to the model")
return model