From 17b18ec455964f5c24d49ab06c57370a643fbd63 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Mon, 9 Feb 2026 17:43:55 -0800 Subject: [PATCH] consolidate simple_fsdp and compiler_toolkit --- torchtitan/experiments/__init__.py | 2 + .../graph_based_training/README.md | 90 +++ .../graph_based_training/__init__.py | 5 + .../graph_based_training/common_utils.py | 52 ++ .../graph_based_training/compilation.py | 168 ++++++ .../graph_based_training/cudagraph.py | 169 ++++++ .../deepseek_v3/__init__.py | 31 + .../graph_based_training/deepseek_v3/model.py | 19 + .../deepseek_v3/parallelize.py | 209 +++++++ .../graph_based_training/graph_utils.py | 356 ++++++++++++ .../graph_based_training/jit_backend.py | 162 ++++++ .../graph_based_training/job_config.py | 41 ++ .../graph_based_training/llama3/__init__.py | 31 + .../graph_based_training/llama3/model.py | 18 + .../llama3/parallelize.py | 178 ++++++ .../graph_based_training/passes.py | 389 +++++++++++++ .../reshard_after_forward.py | 90 +++ .../graph_based_training/simple_fsdp.py | 301 ++++++++++ .../graph_based_training/tests/__init__.py | 5 + .../tests/integration_tests.py | 538 ++++++++++++++++++ .../tests/numerics_utils.py | 271 +++++++++ .../tests/test_aot_numerics.py | 90 +++ .../tests/test_numerics.py | 158 +++++ .../experiments/graph_based_training/train.py | 28 + 24 files changed, 3401 insertions(+) create mode 100644 torchtitan/experiments/graph_based_training/README.md create mode 100644 torchtitan/experiments/graph_based_training/__init__.py create mode 100644 torchtitan/experiments/graph_based_training/common_utils.py create mode 100644 torchtitan/experiments/graph_based_training/compilation.py create mode 100644 torchtitan/experiments/graph_based_training/cudagraph.py create mode 100644 torchtitan/experiments/graph_based_training/deepseek_v3/__init__.py create mode 100644 torchtitan/experiments/graph_based_training/deepseek_v3/model.py create mode 100644 torchtitan/experiments/graph_based_training/deepseek_v3/parallelize.py create mode 100644 torchtitan/experiments/graph_based_training/graph_utils.py create mode 100644 torchtitan/experiments/graph_based_training/jit_backend.py create mode 100644 torchtitan/experiments/graph_based_training/job_config.py create mode 100644 torchtitan/experiments/graph_based_training/llama3/__init__.py create mode 100644 torchtitan/experiments/graph_based_training/llama3/model.py create mode 100644 torchtitan/experiments/graph_based_training/llama3/parallelize.py create mode 100644 torchtitan/experiments/graph_based_training/passes.py create mode 100644 torchtitan/experiments/graph_based_training/reshard_after_forward.py create mode 100644 torchtitan/experiments/graph_based_training/simple_fsdp.py create mode 100644 torchtitan/experiments/graph_based_training/tests/__init__.py create mode 100644 torchtitan/experiments/graph_based_training/tests/integration_tests.py create mode 100644 torchtitan/experiments/graph_based_training/tests/numerics_utils.py create mode 100644 torchtitan/experiments/graph_based_training/tests/test_aot_numerics.py create mode 100644 torchtitan/experiments/graph_based_training/tests/test_numerics.py create mode 100644 torchtitan/experiments/graph_based_training/train.py diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 5989025d4f..6b07fd6919 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -11,6 +11,8 @@ "vlm", "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", + 
"graph_based_training.llama3", + "graph_based_training.deepseek_v3", "transformers_modeling_backend", "autoparallel.llama3", "autoparallel.deepseek_v3", diff --git a/torchtitan/experiments/graph_based_training/README.md b/torchtitan/experiments/graph_based_training/README.md new file mode 100644 index 0000000000..18f7c8a39c --- /dev/null +++ b/torchtitan/experiments/graph_based_training/README.md @@ -0,0 +1,90 @@ +# Graph-Based Training + +Unified experiment merging [SimpleFSDP](../simple_fsdp/) and [Compiler Toolkit](../compiler_toolkit/) into a single framework with two compilation modes: + +- **JIT mode** (`--compile.mode jit`): Uses `torch.compile` with a custom backend. Graph passes are registered to the backend and applied during just-in-time compilation. +- **AOT mode** (`--compile.mode aot`): Captures the joint forward-backward graph ahead of time and applies optimization passes directly to the FX graph modules before execution. + +Both modes share the same DTensor-based SimpleFSDP model authoring and the same unified pass registry. + +## Configuration + +Compilation is configured via `--job.custom_config_module=torchtitan.experiments.graph_based_training.job_config`: + +- `--compile.mode`: `"jit"` or `"aot"` (omit to disable compilation) +- `--compile.passes`: Comma-separated list of pass names + +### Available Passes + +| Pass | Modes | Description | +|------|-------|-------------| +| `auto_bucketing` | jit, aot | Automatic comm/compute overlap bucketing | +| `transformer_block_bucketing` | jit, aot | Manual per-transformer-block bucketing | +| `regional_inductor` | aot | Regional Inductor compilation | +| `cudagraph` | aot | CUDA graph capture and replay (must be last) | +| `full_inductor_compilation` | aot | Full Inductor code generation (requires `inductor_decomposition`) | +| `inductor_decomposition` | aot | Inductor decompositions on joint graph | + +Constraints: +- `auto_bucketing` and `transformer_block_bucketing` are mutually exclusive +- `full_inductor_compilation` requires `inductor_decomposition` +- `cudagraph` must be the last pass in the list + +## Llama3 + +**JIT mode (no passes)** +```shell +NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode jit +``` + +**JIT mode + auto-bucketing** +```shell +NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode jit --compile.passes auto_bucketing +``` + +**JIT mode + transformer-block-bucketing** +```shell +NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode jit --compile.passes transformer_block_bucketing +``` + +**AOT mode (no passes)** +```shell +NGPU=4 
TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot +``` + +**AOT mode + auto-bucketing** +```shell +NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes auto_bucketing +``` + +**AOT mode + transformer-block-bucketing + regional-inductor** +```shell +NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes transformer_block_bucketing,regional_inductor +``` + +**AOT mode + transformer-block-bucketing + regional-inductor + cudagraph** +```shell +NCCL_GRAPH_REGISTER=0 NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes transformer_block_bucketing,regional_inductor,cudagraph +``` + +**AOT mode + full Inductor compilation** +```shell +NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes inductor_decomposition,full_inductor_compilation +``` + +## DeepSeek v3 + +**JIT mode (SimpleFSDP + TP + EP)** +```shell +NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode jit +``` + +**AOT mode (SimpleFSDP + TP + EP)** +```shell +NGPU=4 TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot +``` + +**AOT mode (SimpleFSDP + TP + EP + auto-bucketing)** +```shell +NGPU=4 
TRAIN_FILE=torchtitan.experiments.graph_based_training.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name graph_based_training.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --job.custom_config_module=torchtitan.experiments.graph_based_training.job_config --compile.mode aot --compile.passes auto_bucketing +``` diff --git a/torchtitan/experiments/graph_based_training/__init__.py b/torchtitan/experiments/graph_based_training/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtitan/experiments/graph_based_training/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/graph_based_training/common_utils.py b/torchtitan/experiments/graph_based_training/common_utils.py new file mode 100644 index 0000000000..99c045da7c --- /dev/null +++ b/torchtitan/experiments/graph_based_training/common_utils.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch +from torch.distributed.tensor import DTensor, Replicate +from torch.utils._pytree import register_pytree_node, tree_map + + +def parallelize_inputs(parallel_dims, args, kwargs): + def to_dtensor(tensor): + if isinstance(tensor, torch.Tensor): + return DTensor.from_local( + tensor, parallel_dims.get_mesh("tp"), [Replicate()] + ) + return tensor + + dt_args = tree_map(to_dtensor, args) + + # TODO: When using flex_attention, BlockMask would show up in kwargs, + # and it's unclear how to convert it to DTensor. If I use to_dtensor, + # it would fail with Dynamo Error: P2011360347 + # dt_kwargs = tree_map(to_dtensor, kwargs) + + dt_kwargs = kwargs + + return dt_args, dt_kwargs + + +def register_blockmask_pytree_node(): + from torch.nn.attention.flex_attention import BlockMask + + if BlockMask not in torch.utils._pytree.SUPPORTED_NODES: + register_pytree_node( + BlockMask, + BlockMask._flatten, + BlockMask._unflatten, + flatten_with_keys_fn=BlockMask._flatten_with_keys, + serialized_type_name="torch.nn.attention.flex_attention.BlockMask", + ) + + +def end_with_pass(passes: list[Callable], names: list[str]) -> bool: + return ( + len(passes) > 0 + and (last_pass_name := getattr(passes[-1], "__name__", None)) + and (last_pass_name in names) + ) diff --git a/torchtitan/experiments/graph_based_training/compilation.py b/torchtitan/experiments/graph_based_training/compilation.py new file mode 100644 index 0000000000..5e5c75c738 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/compilation.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unified compilation dispatcher for graph-based training. + +Dispatches to JIT (torch.compile) or AOT (joint graph capture) compilation +based on the configured mode. 
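+
+In JIT mode the model is wrapped with torch.compile using the custom backend
+from jit_backend.py; in AOT mode the joint forward-backward graph is captured
+ahead of time and the model is wrapped in a CompiledModule (see graph_utils.py).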
+""" + +import functools + +import torch + +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.tools.logging import logger + +from .common_utils import parallelize_inputs +from .graph_utils import CompiledModule, joint_graph_builder, make_compiler_with_passes +from .jit_backend import get_jit_compile_backend +from .passes import ( + fsdp_reshard_after_fwd_pass, + inductor_decomposition_pass, + validate_and_get_passes, + validate_flex_attn_annotation_pass, +) + + +def _get_reshard_policy(job_config: JobConfig, parallel_dims: ParallelDims) -> bool: + """Determine fsdp_reshard_after_forward policy.""" + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + return True + case "never": + return False + case "default": + # For PP, by default do not reshard after forward to avoid + # per-microbatch all-gathers, which can be expensive and non-overlapped + return not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid fsdp_reshard_after_forward_policy: " + f"{job_config.parallelism.fsdp_reshard_after_forward}." + ) + + +def apply_compilation( + model: torch.nn.Module, + job_config: JobConfig, + parallel_dims: ParallelDims, + mode: str, + transformer_block_buckets: list | None = None, +) -> torch.nn.Module: + """ + Unified entry point for both JIT and AOT compilation. + + Args: + model: The parallelized model (after TP, AC, DP) + job_config: Job configuration + parallel_dims: Parallel dimensions + mode: "jit" or "aot" + transformer_block_buckets: Model-specific bucket plans for + transformer_block_bucketing pass + + Returns: + The compiled model + """ + pass_names = getattr(job_config.compile, "passes", []) + fsdp_reshard_after_forward = _get_reshard_policy(job_config, parallel_dims) + + joint_passes, fwd_bwd_passes = validate_and_get_passes( + pass_names, mode, transformer_block_buckets=transformer_block_buckets + ) + + if mode == "jit": + return _apply_jit( + model, + job_config, + fsdp_reshard_after_forward, + pass_names, + transformer_block_buckets, + ) + elif mode == "aot": + return _apply_aot( + model, + parallel_dims, + job_config, + fsdp_reshard_after_forward, + joint_passes, + fwd_bwd_passes, + ) + else: + raise ValueError(f"Unknown compilation mode: {mode}") + + +def _apply_jit( + model: torch.nn.Module, + job_config: JobConfig, + fsdp_reshard_after_forward: bool, + pass_names: list[str], + transformer_block_buckets: list | None, +) -> torch.nn.Module: + """Apply JIT compilation via torch.compile with custom backend.""" + torch._inductor.config.reorder_for_peak_memory = False + + backend = get_jit_compile_backend( + job_config.compile, + fsdp_reshard_after_forward, + pass_names, + transformer_block_buckets, + ) + model = torch.compile(model, backend=backend, fullgraph=True) + logger.info("Applied JIT compilation (torch.compile) to the model") + return model + + +def _apply_aot( + model: torch.nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, + fsdp_reshard_after_forward: bool, + joint_passes: list, + fwd_bwd_passes: list, +) -> CompiledModule: + """Apply AOT compilation via joint graph capture.""" + # Build joint custom passes list: + # 1. validate_flex_attn_annotation (always applied) + # 2. user-configured joint passes (excluding inductor_decomposition, + # which is handled specially by joint_graph_builder since it needs + # runtime context like model, joint_with_descriptors, etc.) + # 3. 
fsdp_reshard_after_fwd (always applied) + joint_custom_passes = [validate_flex_attn_annotation_pass] + joint_custom_passes.extend( + p for p in joint_passes if p is not inductor_decomposition_pass + ) + joint_custom_passes.append( + functools.partial( + fsdp_reshard_after_fwd_pass, + reshard_after_forward=fsdp_reshard_after_forward, + ) + ) + + # Build forward/backward compilers with fwd/bwd passes + fw_compiler, bw_compiler = make_compiler_with_passes( + fwd_bwd_passes, dump_folder=job_config.job.dump_folder + ) + + # Create joint graph builder with configured passes + aot_joint_graph_builder = functools.partial( + joint_graph_builder, + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + joint_custom_passes=joint_custom_passes, + dump_folder=job_config.job.dump_folder, + job_config=job_config, + ) + + # TODO: CompiledModule should take sample input as well, so that we can + # compile ahead of time. + model = CompiledModule( + model, parallel_dims, aot_joint_graph_builder, parallelize_inputs + ) + logger.info("Applied AOT compilation (joint graph capture) to the model") + return model diff --git a/torchtitan/experiments/graph_based_training/cudagraph.py b/torchtitan/experiments/graph_based_training/cudagraph.py new file mode 100644 index 0000000000..c72d658cc5 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/cudagraph.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +CUDAGraph pass for graph-based training. + +This module provides a cudagraph pass that can be applied to graph modules +during compilation. +""" + +import warnings +from typing import Any, Callable, Optional, Sequence + +import torch +from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager +from torch.utils._ordered_set import OrderedSet + + +def init_global_graph_pool() -> tuple[ + torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream +]: + dummy_graph = torch.cuda.CUDAGraph() + + # create a global cudagraph memory pool to allow memory reuse across cudagraphs. + graph_pool = torch.cuda.graph_pool_handle() + + # create a global cuda stream for graph capture. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be used. 
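+    # every CUDAGraphWrapper created later captures on this stream with this pool,
+    # so allocations made during capture come from one shared pool and can be
+    # reused across the forward and backward cudagraphs.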
+ graph_capture_stream = torch.cuda.Stream() + + # use a dummy graph to keep the global graph pool alive + with ( + # suppress an empty cudagraph warning, since we intentionally create + # an empty cudagraph here + warnings.catch_warnings(record=True), + torch.cuda.graph( + dummy_graph, + pool=graph_pool, + stream=graph_capture_stream, + capture_error_mode="thread_local", + ), + ): + pass + + return dummy_graph, graph_pool, graph_capture_stream + + +( + _global_dummy_graph, + _global_graph_pool, + _global_graph_capture_stream, +) = init_global_graph_pool() + + +class CUDAGraphWrapper: + def __init__( + self, + runnable: Callable, + example_inputs: Sequence[Any], + static_input_indices: Optional[tuple[int]] = None, + should_check_address: bool = False, + ): + self.runnable = runnable + self.graph_pool = _global_graph_pool + self.stream = _global_graph_capture_stream + self.static_input_indices = OrderedSet( + static_input_indices if static_input_indices is not None else [] + ) + self.input_indices_to_copy = [ + i + for i, inp in enumerate(example_inputs) + if isinstance(inp, torch.Tensor) and i not in self.static_input_indices + ] + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self.has_warmup = False + + self.args = None + self.output = None + + # (debug only) whether check static input tensor addresses during runtime + self.should_check_address = should_check_address + + def copy_non_static_inputs(self, *args): + for i in self.input_indices_to_copy: + self.args[i].copy_(args[i]) + + def check_input_types(self, inputs) -> None: + for inp in inputs: + assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), ( + "args must be tensor, integer (for dynamic shapes), " + "or Generator (for random number generator), " + f"but found {type(inp)}" + ) + + def check_static_inputs_address(self) -> None: + for i in self.static_input_indices: + actual = self.args[i].data_ptr() + expected = self.input_addresses[i] + assert expected == actual, ( + "Expected the same static tensor address but found " + f"{expected} != {actual}" + ) + + def __call__(self, *args): + if not self.has_warmup: + self.has_warmup = True + device = torch.cuda.current_device() + + # warmup in cudagraph memory pool to avoid fragmentation + # across eager memory pool and cudagraph memory pool. + with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream): + out = self.runnable(*args) + return out + + if self.cudagraph is None: + self.check_input_types(args) + self.args = args + self.input_addresses = [ + x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args + ] + + self.cudagraph = torch.cuda.CUDAGraph() + + with torch.cuda.graph( + self.cudagraph, pool=self.graph_pool, stream=self.stream + ): + # `output` is managed by pytorch's cudagraph pool + self.output = self.runnable(*args) + + if self.should_check_address: + self.check_static_inputs_address() + + self.copy_non_static_inputs(*args) + self.cudagraph.replay() + return self.output + + +def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]: + """ + Get indices of gm inputs that are static input tensors whose tensor addresses do not + change across runs. Example of static input tensors include weights, buffers, and + outputs of previous cudagraph wrapped functions. 
+ """ + from torch._inductor.utils import count_tangents + + static_input_indices = [] + if ( + is_forward + and (tracing_context := torch._guards.TracingContext.try_get()) + and hasattr(tracing_context, "fw_metadata") + ): + # for forward, we rely on graph capture (i.e., dynamo or export) to provide + # the correct static input indices stored in tracing context. Typical examples + # include weights and buffers. + static_input_indices = tracing_context.fw_metadata.static_input_indices + + elif not is_forward: + # for backward, we identify saved tensors as static inputs, since saved tensors + # are outputs of cudagraph-wrapped forward run. In PT2-generated backward gm, + # saved tensors are always the leading args. So we can get the number of saved + # tensors and generate static input indices. + fixed = count_tangents(gm) + static_input_indices = list(range(fixed)) + + return static_input_indices diff --git a/torchtitan/experiments/graph_based_training/deepseek_v3/__init__.py b/torchtitan/experiments/graph_based_training/deepseek_v3/__init__.py new file mode 100644 index 0000000000..a6b2da0252 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/deepseek_v3/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.models.deepseek_v3 import deepseekv3_args +from torchtitan.protocols.train_spec import TrainSpec + +from .model import SimpleFSDPDeepSeekV3Model +from .parallelize import parallelize_deepseekv3 + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=SimpleFSDPDeepSeekV3Model, + model_args=deepseekv3_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/graph_based_training/deepseek_v3/model.py b/torchtitan/experiments/graph_based_training/deepseek_v3/model.py new file mode 100644 index 0000000000..83c9fde561 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/deepseek_v3/model.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torchtitan.models.deepseek_v3 import DeepSeekV3Model, DeepSeekV3ModelArgs + +from ..simple_fsdp import disable_active_parametrization + + +class SimpleFSDPDeepSeekV3Model(DeepSeekV3Model): + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__(model_args) + self.init_weights() + + def init_weights(self, *args, **kwargs): + with disable_active_parametrization(): + super().init_weights(*args, **kwargs) diff --git a/torchtitan/experiments/graph_based_training/deepseek_v3/parallelize.py b/torchtitan/experiments/graph_based_training/deepseek_v3/parallelize.py new file mode 100644 index 0000000000..2e50b30aec --- /dev/null +++ b/torchtitan/experiments/graph_based_training/deepseek_v3/parallelize.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.fx.traceback import annotate_fn + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.models.deepseek_v3.infra.parallelize import ( + apply_moe_ep_tp, + apply_non_moe_tp, +) +from torchtitan.tools.logging import logger + +from ..common_utils import register_blockmask_pytree_node +from ..compilation import apply_compilation +from ..simple_fsdp import data_parallel, MixedPrecisionPolicy + + +def get_transformer_block_buckets(model) -> list[list[str] | str]: + module_list = [ + model.tok_embeddings, + [model.norm, model.output], + ] + for layer_id, transformer_block in model.layers.items(): + # [TODO](ruisizhang123) add EP support for transformer block bucketing + module_list.append(transformer_block) + + def convert_modules_to_fqns(modules, module_to_fqn_mapping): + """Convert a (possibly nested) list of modules to FQN strings.""" + result = [] + for m in modules: + if isinstance(m, list): + result.append(convert_modules_to_fqns(m, module_to_fqn_mapping)) + else: + result.append(module_to_fqn_mapping.get(m, None)) + return result + + module_to_name = {m: n for n, m in model.named_modules()} + module_fqns = convert_modules_to_fqns(module_list, module_to_name) + return module_fqns + + +def annotate_deepseekv3() -> None: + from torchtitan.distributed.expert_parallel import ExpertParallel + from torchtitan.models.attention import FlexAttentionWrapper + from torchtitan.models.moe.moe import MoE + + # annotate the MoE with dispatch, compute and combine + ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})( + ExpertParallel._token_dispatch + ) + ExpertParallel._token_combine = annotate_fn({"EP": "combine"})( + ExpertParallel._token_combine + ) + MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward) + + # annotate flex_attention with compile_with_inductor + FlexAttentionWrapper.forward = annotate_fn( + {"compile_with_inductor": "flex_attention"} + )(FlexAttentionWrapper.forward) + + +# Adapted from llama4/infra/parallelize.py +def parallelize_deepseekv3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + compile_mode = getattr(job_config.compile, "mode", None) + + # Annotations for AOT mode (must happen before tracing) + if compile_mode == "aot": + annotate_deepseekv3() + 
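+        # registering BlockMask as a pytree node lets graph capture flatten and
+        # unflatten flex_attention block masks (see common_utils for details).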
register_blockmask_pytree_node() + + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}), i.e. {parallel_dims.seq_len_divisor}. + """ + + if ( + job_config.parallelism.context_parallel_degree > 1 + and model.model_args.attn_type != "sdpa" + ): + raise NotImplementedError("CP support is only supported for SDPA.") + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + if enable_float8_tensorwise_tp: + # TODO(jianiw): This branch needs to be tested and enabled + raise NotImplementedError( + "Currently, float8 tensorwise TP is not tested for deepseekv3" + ) + + apply_non_moe_tp( + model, + parallel_dims.get_mesh("tp"), + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=False, + cp_enabled=parallel_dims.cp_enabled, + ) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) + + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + mp_policy = MixedPrecisionPolicy( + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + ) + + # apply data parallel + dp_mesh: DeviceMesh | None = None + if ( + parallel_dims.fsdp_enabled + or parallel_dims.ep_enabled + or parallel_dims.dp_replicate_enabled + ): + if parallel_dims.dp_replicate_enabled: + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + dp_mesh_dim_names = ["dp_replicate", "fsdp"] + dp_mode = "hybrid_shard" + else: + dp_mesh_dim_names = ["dp_replicate"] + dp_mode = "replicate" + else: + dp_mesh_dim_names = ["fsdp"] + dp_mode = "fully_shard" + + dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) + + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) + + for _, transformer_block in model.layers.items(): + if transformer_block.moe_enabled and parallel_dims.ep_enabled: + experts_shard_dim = 0 + assert edp_mesh is not None + assert hasattr(transformer_block, "moe") + if ( + edp_mesh["efsdp"].size() * parallel_dims.ep + > transformer_block.moe.experts.num_experts + ): + experts_shard_dim = 1 + + transformer_block.moe.experts = data_parallel( + transformer_block.moe.experts, + edp_mesh, + dp_mode, + mp_policy=mp_policy, + shard_dim=experts_shard_dim, + ) + + model = data_parallel( + model, + dp_mesh, + dp_mode, + mp_policy=mp_policy, + ) + + logger.info( + "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode + ) + + # apply compilation + if 
compile_mode: + torch._dynamo.config.capture_scalar_outputs = True + model = apply_compilation( + model, + job_config, + parallel_dims, + compile_mode, + transformer_block_buckets=get_transformer_block_buckets(model), + ) + + return model diff --git a/torchtitan/experiments/graph_based_training/graph_utils.py b/torchtitan/experiments/graph_based_training/graph_utils.py new file mode 100644 index 0000000000..9f0f478f23 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/graph_utils.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import functools +from pathlib import Path +from typing import Any, Callable, List, Optional + +import torch +from torch._dynamo.functional_export import dynamo_graph_capture_for_export +from torch._functorch.aot_autograd import ( + aot_compile_joint_with_descriptors, + aot_export_joint_with_descriptors, + JointWithDescriptors, +) +from torch._guards import tracing, TracingContext +from torch.distributed.tensor import DTensor + +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.experiments.graph_based_training.common_utils import end_with_pass +from torchtitan.tools.logging import logger + + +def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None: + # TODO: make the dump rank configurable + if not dump_folder or torch.distributed.get_rank() != 0: + return + + output_path = Path(dump_folder) / "compiler" / f"{name}.txt" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) + + +def export_joint( + model, args, kwargs=None, dump_folder: str | None = None +) -> tuple[JointWithDescriptors, TracingContext]: + """ + Export joint forward-backward graph with AOT Autograd. + + Args: + model: The model to export + args: Tuple of input arguments + kwargs: Dict of keyword arguments for the model + dump_folder: Optional folder to dump the graph to + """ + if kwargs is None: + kwargs = {} + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + with ( + # TODO Investigate error on MOE model with use_grouped_mm=False. + # For repro, see: https://gist.github.com/zhxchen17/d794ff58236243d9faddf713b9fc6a61 + torch._dynamo.config.patch(fake_tensor_cache_enabled=False), + torch.fx.traceback.preserve_node_meta(), + ): + gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) + logger.debug("Dynamo gm:") + logger.debug( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) + _dump_gm(dump_folder, gm, "dynamo_gm") + + tracing_context = gm.meta["tracing_context"] + + with tracing(tracing_context): + return ( + aot_export_joint_with_descriptors_alone(gm, args, kwargs), + tracing_context, + ) + + +def aot_export_joint_with_descriptors_alone(model, args, kwargs=None): + """ + Export joint forward-backward graph with AOT Autograd. 
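+
+    Unlike export_joint, this helper does not run dynamo graph capture itself; it
+    only wraps aot_export_joint_with_descriptors in an ExitStack and returns the
+    resulting JointWithDescriptors.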
+ + Args: + model: The model to export + args: Tuple of input arguments + kwargs: Dict of keyword arguments for the model + """ + if kwargs is None: + kwargs = {} + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + with contextlib.ExitStack() as stack: + joint_with_descriptors = aot_export_joint_with_descriptors( + stack, + model, + args, + kwargs, + ) + + return joint_with_descriptors + + +def joint_graph_builder( + model: torch.nn.Module, + model_args: tuple, + model_kwargs: dict, + fw_compiler: Optional[Callable] = None, + bw_compiler: Optional[Callable] = None, + joint_custom_passes: Optional[List[Callable]] = None, + dump_folder: str | None = None, + job_config: Optional["JobConfig"] = None, +): + """ + Build a joint forward-backward graph for the model with optional custom compilers. + + Args: + model: The model to compile + model_args: Tuple of model input arguments (should be DTensors) + model_kwargs: Dict of model input keyword arguments + fw_compiler: Optional custom forward compiler function + bw_compiler: Optional custom backward compiler function + joint_custom_passes: list of custom passes to run on the joint graph + dump_folder: Optional folder to dump the graph to + job_config: Job configuration + """ + assert isinstance(model_args, tuple) + for idx, arg in enumerate(model_args): + assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}" + + # get joint graph + (joint_with_descriptors, tracing_context,) = export_joint( + model, + model_args, + model_kwargs, + dump_folder=dump_folder, + ) + + # Check if inductor_decomposition is configured and create the pass with proper context + if job_config is not None: + pass_names = getattr(job_config.compile, "passes", []) + if "inductor_decomposition" in pass_names: + from torchtitan.experiments.graph_based_training.passes import ( + inductor_decomposition_pass, + ) + + # Create the decomposition pass with context + decomp_pass = functools.partial( + inductor_decomposition_pass, + model=model, + joint_with_descriptors=joint_with_descriptors, + forward_inputs=model_args, + tracing_context=tracing_context, + ) + + # Prepend to joint_custom_passes + if joint_custom_passes is None: + joint_custom_passes = [] + joint_custom_passes = [decomp_pass] + joint_custom_passes + + # run custom passes on joint-graph before partitioner + if joint_custom_passes is not None: + for joint_custom_pass in joint_custom_passes: + joint_with_descriptors.graph_module = joint_custom_pass( + joint_with_descriptors.graph_module + ) + + with tracing(tracing_context): + fn = aot_compile_joint_with_descriptors( + joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler + ) + + def wrapper_fn(args, kwargs): + inputs = [ + *model.parameters(), + *model.buffers(), + *args, + ] + return fn(*inputs, **kwargs) + + return wrapper_fn + + +class CompiledModule(torch.nn.Module): + def __init__( + self, + inner: torch.nn.Module, + parallel_dims: ParallelDims, + joint_graph_builder: Callable, + parallelize_inputs: Callable, + **overrides, + ) -> None: + super().__init__() + self.inner = inner # register as submodule + self.parallel_dims = parallel_dims + + self.joint_graph_builder = joint_graph_builder + self.joint_graph_module = None + + self.parallelize_inputs = parallelize_inputs + + self._overrides = overrides # for custom hooks + + def __getattr__(self, name: str): + # check overrides + if "_overrides" in self.__dict__ and name in self._overrides: + return self._overrides[name] + try: + # let nn.Module handle registered 
stuff + return super().__getattr__(name) + except AttributeError: + # fallback to inner model + return getattr(self.inner, name) + + def __setattr__(self, name: str, value) -> None: + if "_overrides" in self.__dict__ and name in self._overrides: + self._overrides[name] = value + else: + super().__setattr__(name, value) + + def __delattr__(self, name: str) -> None: + if "_overrides" in self.__dict__ and name in self._overrides: + del self._overrides[name] + else: + super().__delattr__(name) + + def state_dict(self, *args, **kwargs) -> Any: + return self.inner.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs) -> Any: + return self.inner.load_state_dict(*args, **kwargs) + + def named_parameters(self, *args, **kwargs) -> Any: + return self.inner.named_parameters(*args, **kwargs) + + def parameters(self, *args, **kwargs) -> Any: + return self.inner.parameters(*args, **kwargs) + + def forward(self, *args, **kwargs): + assert "forward" not in self._overrides, "forward cannot be overridden" + + dt_args, dt_kwargs = self.parallelize_inputs(self.parallel_dims, args, kwargs) + + if self.joint_graph_module is None: + self.joint_graph_module = self.joint_graph_builder( + self.inner, dt_args, dt_kwargs + ) + + # calling the line below returns control to torchtitan's runner + # letting it call the backward, and optimizer. + return self.joint_graph_module(args, kwargs) + + +# Default compiler pass configuration - no passes by default +DEFAULT_COMPILER_PASSES = [] + + +def compiler( + name: str, + gm: torch.fx.GraphModule, + example_inputs, + passes: List[Callable] = None, + dump_folder: str | None = None, + is_forward: bool = True, +): + """ + Compile a graph module by applying a sequence of compiler passes. + + Args: + name: Name for logging purposes + gm: The graph module to compile + example_inputs: Example inputs for the graph module + passes: List of compiler pass functions to apply. Each function should take + (gm, example_inputs) and return a transformed gm. If None, uses + DEFAULT_COMPILER_PASSES. + dump_folder: Optional folder to dump the graph to + """ + if passes is None: + passes = DEFAULT_COMPILER_PASSES + + logger.debug(f"{name} before compiler:") + logger.debug( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) + _dump_gm(dump_folder, gm, f"{name}_before_compiler") + + if end_with_pass(passes, ["cudagraph_pass"]): + # cudagraph pass is always the last pass if it is applied + cg_pass = passes[-1] + + # to identify static input indices, cudagraph passes behaves differently for + # forward and backward pass. so we explicitly pass the info. 
+ _cg_pass = functools.partial(cg_pass, is_forward=is_forward) + + # keep the function name for debug log + passes[-1] = functools.wraps(cg_pass)(_cg_pass) + + for pass_fn in passes: + pass_name = ( + pass_fn.func.__name__ + if isinstance(pass_fn, functools.partial) + else pass_fn.__name__ + ) + logger.info(f"Applying pass: {pass_name}") + gm = pass_fn(gm, example_inputs) + + # Only try to print/dump if gm is still a GraphModule + # (compile_fx_inner returns a CompiledFxGraph which doesn't have print_readable) + if hasattr(gm, "print_readable"): + logger.debug(f"{name} after compiler:") + logger.debug( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) + _dump_gm(dump_folder, gm, f"{name}_after_compiler") + + return gm + + +def make_compiler_with_passes( + passes: List[Callable] = None, + dump_folder: str | None = None, +): + """ + Create forward and backward compilers with specified passes. + + Args: + passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES. + dump_folder: Optional folder to dump graphs + + Returns: + Tuple of (fw_compiler, bw_compiler) functions + """ + + def fw_compiler(gm: torch.fx.GraphModule, example_inputs): + return compiler( + "fwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=True, + ) + + def bw_compiler(gm: torch.fx.GraphModule, example_inputs): + return compiler( + "bwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=False, + ) + + return fw_compiler, bw_compiler diff --git a/torchtitan/experiments/graph_based_training/jit_backend.py b/torchtitan/experiments/graph_based_training/jit_backend.py new file mode 100644 index 0000000000..d62de3fe67 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/jit_backend.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +JIT compilation backend for graph-based training. + +This module provides a custom torch.compile backend that integrates +graph-level optimization passes (bucketing, overlapping, activation +checkpointing annotation) with Simple FSDP. +""" + +from typing import Any + +import torch +import torch._functorch.config as functorch_config + +from torchtitan.tools.logging import logger + +from .reshard_after_forward import annotate_fsdp_all_gather + + +def get_jit_compile_backend( + compile_config, + fsdp_reshard_after_forward: bool, + post_partition_pass_names: list[str], + fsdp_manual_buckets: list[list[str] | str] | None, +) -> callable: + """ + Build a torch.compile backend with the given passes for JIT mode. 
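+
+    The base backend is looked up from compile_config.backend (aot_eager or
+    inductor); at most one bucketing pass is layered on top, and a joint-graph
+    pass annotating FSDP all-gathers (reshard-after-forward) is always installed
+    via functorch's joint_custom_pass hook.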
+ + Args: + compile_config: The compile section of job config (has .backend field) + fsdp_reshard_after_forward: Whether to reshard after forward + post_partition_pass_names: List of post-partition pass names to apply + fsdp_manual_buckets: Bucket plans for transformer_block_bucketing + + Returns: + A callable backend for torch.compile + """ + backend = torch._dynamo.lookup_backend(compile_config.backend) + + # Determine which bucketing pass to apply (at most one) + bucketing_pass = None + for name in post_partition_pass_names: + if name in ("auto_bucketing", "transformer_block_bucketing"): + bucketing_pass = name + + if bucketing_pass == "auto_bucketing": + from torch._inductor.config import aten_distributed_optimizations as dist_opts + from torch._inductor.fx_passes.overlap_scheduling import ( + schedule_overlap_bucketing, + ) + + dist_opts.collective_bucketing = True + torch._inductor.config.allow_buffer_reuse = False + + if compile_config.backend == "aot_eager": + from torch._dynamo.backends.common import ( + aot_autograd as aot_autograd_backend, + ) + + def aot_eager_autobucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + schedule_overlap_bucketing(gm) + gm.recompile() + return gm + + dist_opts.insert_overlap_deps = False + backend = aot_autograd_backend( + fw_compiler=aot_eager_autobucketing_reordering_pass, + bw_compiler=aot_eager_autobucketing_reordering_pass, + keep_inference_input_mutations=True, + ) + elif compile_config.backend == "inductor": + + def inductor_autobucketing_reordering_pass( + gm: torch.fx.Graph, + ) -> torch.fx.GraphModule: + return schedule_overlap_bucketing(gm.owning_module) + + dist_opts.insert_overlap_deps = True + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.post_grad_custom_post_pass = ( + inductor_autobucketing_reordering_pass + ) + else: + raise ValueError( + f"Unsupported backend {compile_config.backend} for auto_bucketing pass" + ) + logger.info("Auto bucketing pass is applied") + + elif bucketing_pass == "transformer_block_bucketing": + from functools import partial + + from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend + from torch._inductor.fx_passes.overlap_manual_scheduling import ( + manual_overlap_bucketing, + ) + + torch._inductor.config.allow_buffer_reuse = False + manual_overlap_bucketing = partial( + manual_overlap_bucketing, + module_bucket_plans=fsdp_manual_buckets, + ) + + if compile_config.backend == "aot_eager": + + def aot_eager_transformer_block_bucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + manual_overlap_bucketing(gm, insert_overlap_deps=False) + return gm + + backend = aot_autograd_backend( + fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass, + bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass, + keep_inference_input_mutations=True, + ) + elif compile_config.backend == "inductor": + + def inductor_transformer_block_bucketing_reordering_pass( + gm: torch.fx.Graph, + ) -> torch.fx.GraphModule: + return manual_overlap_bucketing( + gm.owning_module, insert_overlap_deps=True + ) + + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.post_grad_custom_post_pass = ( + inductor_transformer_block_bucketing_reordering_pass + ) + else: + raise ValueError( + f"Unsupported backend 
{compile_config.backend} for transformer_block_bucketing pass" + ) + logger.info("Transformer block bucketing pass is applied") + + else: + logger.info("No bucketing or overlapping pass is applied") + + # Apply activation checkpointing on joint graph before partitioner + def joint_ac_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + # this pass implements simplefsdp's fsdp_reshard_after_forward behavior + gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward) + gm.recompile() + return gm + + def simple_fsdp_custom_pass(*args, **kwargs): + # the ac pass has to operate in a joint graph before partitioner for ac + # annotation to take into effect. + with functorch_config.patch("joint_custom_pass", joint_ac_pass): + return backend(*args, **kwargs) + + return simple_fsdp_custom_pass diff --git a/torchtitan/experiments/graph_based_training/job_config.py b/torchtitan/experiments/graph_based_training/job_config.py new file mode 100644 index 0000000000..9e9dc610ca --- /dev/null +++ b/torchtitan/experiments/graph_based_training/job_config.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Literal + + +@dataclass +class Compile: + """ + Compiler configuration for graph-based training. + + - mode: Compilation mode. "jit" uses torch.compile with a custom backend. + "aot" uses AOT joint graph capture with a configurable pass pipeline. + None disables compilation entirely. + + - passes: List of compiler pass names to apply. Passes are automatically + classified as pre-partition or post-partition. Some passes are only + supported in certain modes; an error is raised for unsupported combinations. + + Available passes: + - auto_bucketing: Automatic comm/compute overlap bucketing (jit, aot) + - transformer_block_bucketing: Manual per-block bucketing (jit, aot) + - regional_inductor: Regional Inductor compilation (aot only) + - cudagraph: CUDA graph capture and replay (aot only) + - full_inductor_compilation: Full Inductor code generation (aot only) + - inductor_decomposition: Inductor decompositions on joint graph (aot only) + + Example: --compile.passes auto_bucketing,cudagraph + """ + + mode: Literal["jit", "aot"] | None = None + passes: list[str] = field(default_factory=list) + + +@dataclass +class JobConfig: + compile: Compile = field(default_factory=Compile) diff --git a/torchtitan/experiments/graph_based_training/llama3/__init__.py b/torchtitan/experiments/graph_based_training/llama3/__init__.py new file mode 100644 index 0000000000..94ea27f406 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/llama3/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.models.llama3 import llama3_args +from torchtitan.protocols.train_spec import TrainSpec + +from .model import SimpleFSDPTransformer +from .parallelize import parallelize_llama + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=SimpleFSDPTransformer, + model_args=llama3_args, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/graph_based_training/llama3/model.py b/torchtitan/experiments/graph_based_training/llama3/model.py new file mode 100644 index 0000000000..b0c11f9a44 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/llama3/model.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.models.llama3 import Transformer, TransformerModelArgs + +from ..simple_fsdp import disable_active_parametrization + + +class SimpleFSDPTransformer(Transformer): + def __init__(self, model_args: TransformerModelArgs): + super().__init__(model_args) + + def init_weights(self, *args, **kwargs): + with disable_active_parametrization(): + super().init_weights(*args, **kwargs) diff --git a/torchtitan/experiments/graph_based_training/llama3/parallelize.py b/torchtitan/experiments/graph_based_training/llama3/parallelize.py new file mode 100644 index 0000000000..73c9ae7ca3 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/llama3/parallelize.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +from torch.fx.traceback import annotate_fn + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.models.llama3.infra.parallelize import apply_tp +from torchtitan.tools.logging import logger + +from ..common_utils import register_blockmask_pytree_node +from ..compilation import apply_compilation +from ..simple_fsdp import data_parallel, MixedPrecisionPolicy + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn, + torch._higher_order_ops.inductor_compiled_code, +} + + +def get_transformer_block_buckets(model) -> list[list[str] | str]: + module_list = [ + model.tok_embeddings, + [model.norm, model.output], + ] + for layer_id, transformer_block in model.layers.items(): + module_list.append(transformer_block) + + def convert_modules_to_fqns(modules, module_to_fqn_mapping): + """Convert a (possibly nested) list of modules to FQN strings.""" + result = [] + for m in modules: + if isinstance(m, list): + if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping): + result.append(fqn_list) + else: + if fqn := module_to_fqn_mapping.get(m): + result.append(fqn) + return result + + module_to_name = {m: n for n, m in model.named_modules()} + module_fqns = convert_modules_to_fqns(module_list, module_to_name) + return module_fqns + + +def annotate_llama() -> None: + from torchtitan.models.attention import FlexAttentionWrapper + + # annotate flex_attention with compile_with_inductor + FlexAttentionWrapper.forward = annotate_fn( + {"compile_with_inductor": "flex_attention"} + )(FlexAttentionWrapper.forward) + + +def parallelize_llama( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, data parallelism, + and compilation to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + compile_mode = getattr(job_config.compile, "mode", None) + + # Annotations for AOT mode (must happen before tracing) + if compile_mode == "aot": + annotate_llama() + register_blockmask_pytree_node() + + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
+ """ + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + tp_mesh = parallel_dims.get_mesh("tp") + apply_tp( + model, + tp_mesh, + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + ) + maybe_enable_async_tp(job_config, tp_mesh) + + if job_config.activation_checkpoint.mode != "none": + model_compile_enabled = compile_mode is not None + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, + ) + + # apply data parallel + if ( + parallel_dims.dp_replicate_enabled + or parallel_dims.dp_shard_enabled + or parallel_dims.cp_enabled + ): + if parallel_dims.dp_replicate_enabled: + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + dp_mesh_dim_names = ["dp_replicate", "fsdp"] + dp_mode = "hybrid_shard" + else: + dp_mesh_dim_names = ["dp_replicate"] + dp_mode = "replicate" + else: + dp_mesh_dim_names = ["fsdp"] + dp_mode = "fully_shard" + + mp_policy = MixedPrecisionPolicy( + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + ) + + model = data_parallel( + model, + parallel_dims.get_mesh(dp_mesh_dim_names), + mode=dp_mode, + mp_policy=mp_policy, + ) + logger.info( + "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode + ) + + # apply compilation + if compile_mode: + model = apply_compilation( + model, + job_config, + parallel_dims, + compile_mode, + transformer_block_buckets=get_transformer_block_buckets(model), + ) + + return model diff --git a/torchtitan/experiments/graph_based_training/passes.py b/torchtitan/experiments/graph_based_training/passes.py new file mode 100644 index 0000000000..4037972e6e --- /dev/null +++ b/torchtitan/experiments/graph_based_training/passes.py @@ -0,0 +1,389 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unified compiler pass registry for graph-based training. + +This module provides all compiler passes and a registry with metadata +(phase, supported modes) for each pass. Passes can be selected and +configured via job config. 
+ +Pass Types: +- Pre-partition passes: Applied to the joint forward-backward graph before partitioning +- Post-partition passes: Applied to the partitioned forward/backward graphs +""" + +import functools +from dataclasses import dataclass, field +from typing import Any, Callable, Sequence + +import torch +from torch._functorch.aot_autograd import JointWithDescriptors +from torch._guards import TracingContext +from torch._inductor.compile_fx import compile_fx_inner +from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing +from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing +from torch.fx.passes.regional_inductor import regional_inductor + +from torchtitan.tools.logging import logger + +from .cudagraph import CUDAGraphWrapper, get_static_input_indices +from .reshard_after_forward import annotate_fsdp_all_gather + + +# --------------------------------------------------------------------------- +# Post-partition passes +# --------------------------------------------------------------------------- + + +def autobucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs=None +) -> torch.fx.GraphModule: + """ + Apply autobucketing and reordering optimization. + + This pass applies schedule_overlap_bucketing with collective_bucketing enabled + to optimize comm/compute overlap patterns in the graph. + """ + schedule_overlap_bucketing(gm, collective_bucketing=True) + gm.recompile() + return gm + + +def transformer_block_bucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs, fsdp_manual_buckets +) -> torch.fx.GraphModule: + """ + Apply aten-level manual bucketing and reordering optimization. + """ + manual_overlap_bucketing( + gm, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=False + ) + gm.recompile() + return gm + + +def regional_inductor_pass( + gm: torch.fx.GraphModule, example_inputs +) -> torch.fx.GraphModule: + """ + Apply regional inductor compilation based on user annotation. + """ + return regional_inductor(gm, example_inputs) + + +def cudagraph_pass( + gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool +) -> torch.fx.GraphModule: + """ + Apply cudagraph. + + This pass wraps the forward function with cudagraph during compilation and does + not record cudagraph until runtime. + - For the first run, it will warm up operators such as nccl. + - For the second run, it will record cudagraph and replay cudagraph. + - For the following runs, it will replay cudagraph. + """ + static_input_indices = get_static_input_indices(gm, is_forward) + gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices) + return gm + + +def full_inductor_compilation_pass( + gm: torch.fx.GraphModule, example_inputs +) -> torch.fx.GraphModule: + """ + Apply full Inductor compilation with code generation. + + This pass uses compile_fx_inner to generate optimized code for the graph. 
+ """ + return compile_fx_inner(gm, example_inputs) + + +# --------------------------------------------------------------------------- +# Pre-partition (joint) passes +# --------------------------------------------------------------------------- + + +def validate_flex_attn_annotation_pass( + gm: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """Verify user annotations show up in the graph.""" + for node in gm.graph.nodes: + if node.target in { + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + }: + assert "compile_with_inductor" in node.meta.get("custom", {}) + return gm + + +def fsdp_reshard_after_fwd_pass( + gm: torch.fx.GraphModule, reshard_after_forward: bool +) -> torch.fx.GraphModule: + """ + Annotate FSDP all-gather operations for reshard-after-forward behavior. + + When reshard_after_forward is True, annotates all-gather results as + MUST_RECOMPUTE (freed after forward, recomputed in backward). + When False, annotates as MUST_SAVE (kept in memory). + """ + gm = annotate_fsdp_all_gather(gm, reshard_after_forward) + gm.recompile() + return gm + + +def inductor_decomposition_pass( + gm: torch.fx.GraphModule, + model: torch.nn.Module, + joint_with_descriptors: JointWithDescriptors, + forward_inputs: tuple, + tracing_context: TracingContext, +) -> torch.fx.GraphModule: + """ + Apply Inductor decompositions to the joint graph. + + This pass applies decompositions to the joint forward-backward graph using make_fx. + It unwraps tensor subclasses (like DTensor) and retraces the graph with decompositions + applied, while preserving metadata required by the partitioner. + """ + from torch._functorch._aot_autograd.descriptors import DummyAOTInput + from torch._functorch._aot_autograd.subclass_utils import unwrap_tensor_subclasses + from torch._inductor.decomposition import select_decomp_table + from torch.fx.experimental.proxy_tensor import make_fx + + logger.info("Applying decompositions to joint graph") + + decomp_table = select_decomp_table() + + # Get traced tangents metadata + traced_tangents = joint_with_descriptors._aot_state.fw_metadata.traced_tangents + + # Collect all inputs: params, buffers, forward inputs, tangents + param_inputs = list(model.parameters()) + buffer_inputs = list(model.buffers()) + primals = param_inputs + buffer_inputs + list(forward_inputs) + tangents = list(traced_tangents) + + # Create dummy descriptors for unwrapping + primals_descs = [DummyAOTInput(i) for i in range(len(primals))] + tangents_descs = [DummyAOTInput(i + len(primals)) for i in range(len(tangents))] + + # Unwrap tensor subclasses (DTensor -> _local_tensor) + primals_unwrapped, _ = unwrap_tensor_subclasses( + primals, primals_descs, append_symints=False + ) + tangents_unwrapped, _ = unwrap_tensor_subclasses( + tangents, tangents_descs, append_symints=False + ) + + # Verify unwrapped tensor shapes match joint graph placeholders + all_inputs = primals_unwrapped + tangents_unwrapped + placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + + if len(all_inputs) != len(placeholders): + raise RuntimeError( + f"Input count mismatch: {len(all_inputs)} inputs vs {len(placeholders)} placeholders" + ) + + shape_mismatches = [] + for i, (inp, ph) in enumerate(zip(all_inputs, placeholders)): + if hasattr(inp, "shape") and "val" in ph.meta: + expected_shape = ph.meta["val"].shape + actual_shape = inp.shape + if expected_shape != actual_shape: + shape_mismatches.append( + f" {ph.target}: expected {expected_shape}, got {actual_shape}" + ) + + if 
shape_mismatches: + logger.error(f"Shape mismatches found ({len(shape_mismatches)}):") + for msg in shape_mismatches: + logger.error(msg) + raise RuntimeError( + "Unwrapped tensor shapes don't match joint graph placeholders." + ) + + # Get the FakeTensorMode from the original joint graph + fake_mode = None + for node in gm.graph.nodes: + if node.op == "placeholder" and "val" in node.meta: + val = node.meta["val"] + if hasattr(val, "fake_mode"): + fake_mode = val.fake_mode + break + + if fake_mode is None: + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(primals_unwrapped) + + # Use make_fx with the original fake mode to retrace with decompositions + with fake_mode: + decomposed_gm = make_fx( + gm, + decomposition_table=decomp_table, + _allow_non_fake_inputs=False, + )(primals_unwrapped, tangents_unwrapped) + + # Copy metadata from original placeholders to decomposed placeholders + orig_placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + decomp_placeholders = [ + n for n in decomposed_gm.graph.nodes if n.op == "placeholder" + ] + + if len(orig_placeholders) != len(decomp_placeholders): + raise RuntimeError( + f"Placeholder count mismatch: {len(orig_placeholders)} vs {len(decomp_placeholders)}" + ) + + for orig, decomp in zip(orig_placeholders, decomp_placeholders): + # Copy all metadata from original to decomposed + for key, value in orig.meta.items(): + if key not in decomp.meta: + decomp.meta[key] = value + + # Rename decomposed placeholder to match original name + decomp.target = orig.target + decomp.name = orig.name + + decomposed_gm.recompile() + logger.info("Decompositions applied successfully to joint graph") + + return decomposed_gm + + +# --------------------------------------------------------------------------- +# Pass registry +# --------------------------------------------------------------------------- + + +@dataclass +class PassInfo: + """Metadata for a compiler pass.""" + + fn: Callable + is_joint: bool # True if applied on joint graph, False if on fwd/bwd graphs + supported_modes: set = field(default_factory=set) + + +PASS_REGISTRY: dict[str, PassInfo] = { + "auto_bucketing": PassInfo( + fn=autobucketing_reordering_pass, + is_joint=False, + supported_modes={"jit", "aot"}, + ), + "transformer_block_bucketing": PassInfo( + fn=transformer_block_bucketing_reordering_pass, + is_joint=False, + supported_modes={"jit", "aot"}, + ), + "regional_inductor": PassInfo( + fn=regional_inductor_pass, + is_joint=False, + supported_modes={"aot"}, + ), + "cudagraph": PassInfo( + fn=cudagraph_pass, + is_joint=False, + supported_modes={"aot"}, + ), + "full_inductor_compilation": PassInfo( + fn=full_inductor_compilation_pass, + is_joint=False, + supported_modes={"aot"}, + ), + "inductor_decomposition": PassInfo( + fn=inductor_decomposition_pass, + is_joint=True, + supported_modes={"aot"}, + ), +} + + +def _validate_pass_constraints(pass_names: list[str]) -> None: + """Validate pass ordering and mutual exclusion constraints.""" + if "cudagraph" in pass_names: + if pass_names[-1] != "cudagraph": + raise ValueError("cudagraph must be the last pass in the list") + + if "auto_bucketing" in pass_names and "transformer_block_bucketing" in pass_names: + raise ValueError( + "Cannot apply auto_bucketing and transformer_block_bucketing at the same time" + ) + + if "full_inductor_compilation" in pass_names: + if "inductor_decomposition" not in pass_names: + raise ValueError( + "full_inductor_compilation requires inductor_decomposition. 
" + "Please add inductor_decomposition to compile.passes" + ) + + +def validate_and_get_passes( + pass_names: list[str], + mode: str, + transformer_block_buckets: list | None = None, +) -> tuple[list[Callable], list[Callable]]: + """ + Validate and split passes into joint and fwd/bwd lists. + + Args: + pass_names: List of pass names from config + mode: Compilation mode ("jit" or "aot") + transformer_block_buckets: Bucket plans for transformer_block_bucketing pass + + Returns: + (joint_passes, fwd_bwd_passes) + + Raises: + ValueError: If a pass is unknown, unsupported in the given mode, + or constraints are violated + """ + _validate_pass_constraints(pass_names) + + joint_passes = [] + fwd_bwd_passes = [] + + for name in pass_names: + if name not in PASS_REGISTRY: + raise ValueError( + f"Unknown pass: {name}. " + f"Available passes: {list(PASS_REGISTRY.keys())}" + ) + + info = PASS_REGISTRY[name] + if mode not in info.supported_modes: + raise ValueError( + f"Pass '{name}' is not supported in '{mode}' mode. " + f"Supported modes: {info.supported_modes}" + ) + + fn = info.fn + + # Apply model-specific configuration + if name == "transformer_block_bucketing": + if transformer_block_buckets is None: + raise ValueError( + "transformer_block_bucketing requires transformer_block_buckets" + ) + fn = functools.partial(fn, fsdp_manual_buckets=transformer_block_buckets) + + if info.is_joint: + joint_passes.append(fn) + else: + fwd_bwd_passes.append(fn) + + if pass_names: + logger.info(f"Using compiler passes: {pass_names}") + if "full_inductor_compilation" in pass_names: + logger.warning( + "Full Inductor compilation is enabled. Note that Inductor may change " + "numerics and does not guarantee bitwise equivalent results compared " + "to eager mode." + ) + + return joint_passes, fwd_bwd_passes diff --git a/torchtitan/experiments/graph_based_training/reshard_after_forward.py b/torchtitan/experiments/graph_based_training/reshard_after_forward.py new file mode 100644 index 0000000000..dac010bfcd --- /dev/null +++ b/torchtitan/experiments/graph_based_training/reshard_after_forward.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.utils.checkpoint import CheckpointPolicy + + +def is_graph_input(node: torch.fx.Node) -> bool: + return node.op == "placeholder" + + +def is_wait_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.wait_tensor.default + ) + + +def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default + ) + + +def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool: + """ + Returns True if the node is a wait_tensor node that is the result of an all_gather + that can be arbitrarily prefetched, i.e., if all its recursive inputs are + single-input operators that leads to a graph input. 
+    """
+    if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
+        n: torch.fx.Node = node.all_input_nodes[0]
+        while len(n.all_input_nodes) == 1:
+            if is_graph_input(n.all_input_nodes[0]):
+                return True
+            n = n.all_input_nodes[0]
+    return False
+
+
+def annotate_fsdp_all_gather(
+    gm: torch.fx.GraphModule, reshard_after_forward: bool
+) -> torch.fx.GraphModule:
+    """
+    Force recomputation of the all_gather nodes inserted by SimpleFSDP; when
+    reshard_after_forward is False, mark them to be saved instead.
+    This pass should be added in torch._inductor.config.joint_custom_post_pass.
+    """
+    graph = gm.graph
+
+    def force_recompute_node(node):
+        if reshard_after_forward:
+            node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
+        else:
+            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
+        # ac_graph_id is used in the partitioner to decide
+        # if two nodes which have AC applied come from different
+        # AC regions. This is needed because nodes in the boundary
+        # of two AC regions are marked as MUST_SAVE. In our case
+        # we just add a large value of ac_graph_id so that
+        # all nodes we tag for recomputation do indeed get recomputed
+        # and are not influenced by other nodes in the graph with
+        # nearby ac_graph_id values.
+        node.meta["ac_graph_id"] = 100000
+
+    # Make all-gather nodes (and related nodes) recomputable, to circumvent
+    # https://github.com/pytorch/pytorch/issues/136433
+    for node in graph.nodes:
+        if is_wait_tensor_from_fsdp(node):
+            ag_node = node.args[0]
+            force_recompute_node(ag_node)  # all_gather
+            force_recompute_node(node)  # wait_tensor
+            # Force-recompute slice that comes after wait
+            for user in node.users:
+                if (
+                    user.op == "call_function"
+                    and user.target == torch.ops.aten.slice.Tensor
+                ):
+                    force_recompute_node(user)
+            # Force-recompute potential dtype casts from all_gather
+            if (
+                ag_node.all_input_nodes[0].op == "call_function"
+                and ag_node.args[0].target
+                == torch.ops.prims.convert_element_type.default
+            ):
+                force_recompute_node(ag_node.all_input_nodes[0])
+
+    return gm
diff --git a/torchtitan/experiments/graph_based_training/simple_fsdp.py b/torchtitan/experiments/graph_based_training/simple_fsdp.py
new file mode 100644
index 0000000000..8ea108f9f4
--- /dev/null
+++ b/torchtitan/experiments/graph_based_training/simple_fsdp.py
@@ -0,0 +1,301 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
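+
+"""
+SimpleFSDP: data parallelism expressed through DTensor parametrizations.
+
+Parameters are stored as sharded (or replicated) DTensors. A per-module
+parametrization (`ReplicateComputation`) redistributes them to `Replicate()`
+at use time (the FSDP all-gather) and marks gradients as
+`Partial(reduce_op="sum")`, so the backward pass produces the matching
+reduce-scatter / all-reduce.
+
+Typical usage, as a sketch (the mesh and dtype choices below are
+placeholders, not defaults):
+
+    model = data_parallel(
+        model,
+        device_mesh,
+        mode="fully_shard",
+        mp_policy=MixedPrecisionPolicy(torch.bfloat16, torch.float32),
+    )
+"""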
+ +from collections.abc import Generator, Sequence +from contextlib import contextmanager +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from torch.distributed._tensor import ( + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor.placement_types import _StridedShard, Placement + +_active_parametrization = True + + +@contextmanager +def disable_active_parametrization() -> Generator[None, None, None]: + global _active_parametrization + try: + _active_parametrization = False + yield + finally: + _active_parametrization = True + + +@dataclass(frozen=True) +class MixedPrecisionPolicy: + param_dtype: torch.dtype | None = None + reduce_dtype: torch.dtype | None = None + + +def _distribute_dtensor( + tensor: DTensor, + device_mesh: DeviceMesh, + dp_placements: Sequence[Placement], +) -> DTensor: + """ + Below are experimental enhancements to distribute a DTensor. + This helps enable Simple FSDP + TP/EP, in which + inner spec/mesh is TP/EP spec/mesh + outer spec/mesh is FSDP/DDP/HSDP spec/mesh + The logic follows + https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fsdp/_fsdp_param.py#L261 + """ + inner_spec = tensor._spec + outer_mesh, inner_mesh = device_mesh, inner_spec.mesh + spanned_mesh = DeviceMesh._concatenate([outer_mesh, inner_mesh]) + + if len(dp_placements) == 1: + assert dp_placements[0].is_replicate() or dp_placements[0].is_shard() + if dp_placements[0].is_shard(): + # For FSDP + EP/TP/EP+TP + assert len(inner_spec.placements) == 2 or len(inner_spec.placements) == 1 + shard_dim = dp_placements[0].dim + split_factor = inner_spec.num_shards_map[shard_dim] + tensor_placement = ( + ( + _StridedShard(shard_dim, split_factor=split_factor) + if split_factor > 1 + else dp_placements[0] + ), + ) + inner_spec.placements + else: + # For DDP + TP/EP + assert len(inner_spec.placements) == 1 + tensor_placement = (dp_placements[0], inner_spec.placements[0]) + elif len(dp_placements) == 2: + assert dp_placements[0].is_replicate() and dp_placements[1].is_shard() + # For HSDP + EP/TP/EP+TP + assert len(inner_spec.placements) == 2 or len(inner_spec.placements) == 1 + shard_dim = dp_placements[1].dim + split_factor = inner_spec.num_shards_map[shard_dim] + tensor_placement = ( + dp_placements[0], + ( + _StridedShard(shard_dim, split_factor=split_factor) + if split_factor > 1 + else dp_placements[1] + ), + ) + inner_spec.placements + else: + raise ValueError( + f"Unsupported placement {dp_placements} for distributing DTensor {tensor}" + ) + + # HSDP case needs 2 placements for 2D outer_mesh + current_placements = (Replicate(),) * len(dp_placements) + target_placements = tuple(dp_placements) + + current_spec = DTensorSpec( + mesh=outer_mesh, + placements=current_placements, + tensor_meta=inner_spec.tensor_meta, + ) + target_spec = DTensorSpec( + mesh=outer_mesh, + placements=target_placements, + tensor_meta=inner_spec.tensor_meta, + ) + result_tensor = redistribute_local_tensor( + tensor._local_tensor, + current_spec=current_spec, + target_spec=target_spec, + ) + return DTensor( + result_tensor.requires_grad_(tensor.requires_grad), + DTensorSpec( + mesh=spanned_mesh, + placements=tensor_placement, + tensor_meta=inner_spec.tensor_meta, + ), + requires_grad=tensor.requires_grad, + ) + + +def _register_parametrization( 
+ module: nn.Module, param_names: list[str], parametrization: nn.Module +) -> None: + """ + It works with state_dict without incurring parametrization calls because + state_dict accesses parameters directly from self._parameters, not from getters + https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L2141 + TODO: In checkpoint saving/loading, avoid parametrization calls when calling + get_model_state_dict func in torchtitan's torchtitan/components/checkpoint.py. + """ + param_name_to_property = { + param_name: property( + lambda self, pn=param_name: parametrization(self._parameters[pn]) + ) + for param_name in param_names + } + module_cls = type( + f"SimpleFSDP{module.__class__.__name__}", + (module.__class__,), + param_name_to_property, + ) + module.__class__ = module_cls + + +class ReplicateComputation(torch.nn.Module): + def __init__( + self, + device_mesh: DeviceMesh, + param_sharding: tuple[Placement, ...], + mode: str, + mp_policy: MixedPrecisionPolicy | None, + full_dtensor: bool = False, + ) -> None: + super().__init__() + self.device_mesh = device_mesh + self.param_sharding = param_sharding + self.mode = mode + self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim + self.grad_placements: list[Placement] = [ + Partial(reduce_op="sum") + ] * self.device_mesh.ndim + mp_policy = mp_policy or MixedPrecisionPolicy() + self.param_dtype: torch.dtype | None = mp_policy.param_dtype + self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype + self.full_dtensor = full_dtensor + + def replicate_compute(self, x: DTensor) -> torch.Tensor: + # data parallel runtime replicate parameters and do local compute + # the gradients are partial tensors that needs to perform reduction + # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both) + # support FSDP/DDP/HSDP + EP + TP (assuming TP shards the inner-most dim) + non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim + assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported" + if non_dp_mesh_dims > 0: + if self.full_dtensor: + raise NotImplementedError( + "full_dtensor not implemented for nD parallelisms" + ) + dp_mesh = self.device_mesh + # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather + sharded_local_tensor = x.to_local() + sharded_dtensor = DTensor.from_local( + sharded_local_tensor, dp_mesh, self.param_sharding + ) + + # the actual FSDP's fwd all-gather & bwd reduce-scatter + # DDP's bwd all-reduce on dp_mesh + replicated_dtensor = sharded_dtensor.redistribute( + placements=self.compute_placements, + forward_dtype=self.param_dtype, + backward_dtype=self.reduce_dtype, + ) + + # re-wrap all-gathered DTensor on dp_mesh to be on non_dp_mesh + # TODO: DTensor should support this mesh collapsing operation + replicated_local_tensor = replicated_dtensor.to_local( + grad_placements=self.grad_placements + ) + + non_dp_placements = tuple(x._spec.placements[-non_dp_mesh_dims:]) + non_dp_mesh_dim_names = tuple( + x._spec.mesh.mesh_dim_names[-non_dp_mesh_dims:] + ) + non_dp_mesh = x._spec.mesh[non_dp_mesh_dim_names] + + output = DTensor.from_local( + replicated_local_tensor, non_dp_mesh, non_dp_placements + ) + elif non_dp_mesh_dims == 0: + output = x.redistribute( + placements=self.compute_placements, + forward_dtype=self.param_dtype, + backward_dtype=self.reduce_dtype, + ) + + if not self.full_dtensor: + output = output.to_local(grad_placements=self.grad_placements) + else: + raise AssertionError( + f"Unsupported replicate compute on placement 
{x._spec.placements} for DTensor {x}" + ) + + return output + + def forward(self, x: DTensor) -> torch.Tensor: + global _active_parametrization + # This should never be set to true during forward, only outside for model + # inspection / debugging / initialization + # model initialization can be done now through + # with disable_active_parametrization(): + # model.init_weights() + if not _active_parametrization: + return x + + output = self.replicate_compute(x) + return output + + +def data_parallel( + model: nn.Module, + device_mesh: DeviceMesh, + mode: str = "replicate", + mp_policy: MixedPrecisionPolicy | None = None, + shard_dim: int = 0, + full_dtensor: bool = False, +) -> nn.Module: + param_sharding: tuple[Placement, ...] + if mode == "replicate": + param_sharding = (Replicate(),) + elif mode == "fully_shard": + param_sharding = (Shard(shard_dim),) + elif mode == "hybrid_shard": + # replicate inter-host, fully shard intra-host + param_sharding = (Replicate(), Shard(shard_dim)) + assert ( + device_mesh.ndim == 2 + ), "hybrid sharded data parallel requires 2D DeviceMesh" + else: + raise ValueError(f"Unsupported mode {mode}") + + modules = list(model.modules()) + + for mod in modules: + params_dict = dict(mod.named_parameters(recurse=False)) + # we shouldn't apply data parallel to the modules that are already + # sharded by data parallel + if "SimpleFSDP" in mod.__class__.__name__: + continue + + for p_name, p in params_dict.items(): + if p is not None and p.numel() > 0: + distribute_tensor_func = ( + _distribute_dtensor if isinstance(p, DTensor) else distribute_tensor + ) + mod.register_parameter( + p_name, + nn.Parameter( + distribute_tensor_func(p, device_mesh, param_sharding) + ), + ) + + _register_parametrization( + mod, + list(params_dict.keys()), + ReplicateComputation( + device_mesh, + param_sharding, + mode, + mp_policy=mp_policy, + full_dtensor=full_dtensor, + ), + ) + return model diff --git a/torchtitan/experiments/graph_based_training/tests/__init__.py b/torchtitan/experiments/graph_based_training/tests/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtitan/experiments/graph_based_training/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/graph_based_training/tests/integration_tests.py b/torchtitan/experiments/graph_based_training/tests/integration_tests.py new file mode 100644 index 0000000000..1397bcd8e9 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/tests/integration_tests.py @@ -0,0 +1,538 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +from tests.integration_tests import OverrideDefinitions +from tests.integration_tests.run_tests import run_tests + + +_CUSTOM_CONFIG = ( + "--job.custom_config_module=torchtitan.experiments.graph_based_training.job_config" +) + + +def build_graph_based_training_test_list() -> list[OverrideDefinitions]: + """ + Merged integration tests from simple_fsdp and compiler_toolkit. + All tests use graph_based_training.llama3 or graph_based_training.deepseek_v3, + with compile.mode selecting JIT or AOT compilation. 
+ """ + integration_tests_flavors = [ + # ===================================================================== + # JIT mode tests (from simple_fsdp) + # ===================================================================== + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + ], + ], + "JIT 1D", + "jit_1d", + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--compile.backend aot_eager", + "--compile.passes auto_bucketing", + ], + ], + "JIT 1D+autobucketing", + "jit_1d_autobucketing", + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--compile.backend aot_eager", + "--compile.passes transformer_block_bucketing", + ], + ], + "JIT 1D+transformer_block_bucketing", + "jit_1d_transformer_block_bucketing", + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--activation_checkpoint.mode selective", + "--activation_checkpoint.selective_ac_option op", + ], + ], + "JIT 1D with selective op AC", + "jit_1d_sac_op", + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--activation_checkpoint.mode full", + ], + ], + "JIT 1D with full AC", + "jit_1d_full_ac", + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--parallelism.tensor_parallel_degree 2", + ], + ], + "JIT 2D", + "jit_2d", + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--checkpoint.enable", + ], + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--checkpoint.enable", + "--training.steps 20", + ], + ], + "JIT Checkpoint Integration Test - Save Load Full Checkpoint", + "jit_full_checkpoint", + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--checkpoint.enable", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + ], + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--training.steps 20", + "--checkpoint.enable", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + ], + ], + "JIT PP+DP+TP 3D test with save/load resume ckpt", + "jit_pp_dp_tp", + ngpu=8, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--parallelism.data_parallel_shard_degree 1", + "--parallelism.data_parallel_replicate_degree 4", + ] + ], + "JIT DDP", + "jit_ddp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.data_parallel_replicate_degree 2", + ] + ], + "JIT HSDP", + "jit_hsdp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.data_parallel_replicate_degree 2", + "--parallelism.tensor_parallel_degree 2", + ] + ], + "JIT HSDP+TP", + "jit_hsdp+tp", + ngpu=8, + 
), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--parallelism.data_parallel_replicate_degree 2", + "--parallelism.tensor_parallel_degree 2", + ] + ], + "JIT DDP+TP", + "jit_ddp+tp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.data_parallel_replicate_degree 2", + "--parallelism.context_parallel_degree 2", + ] + ], + "JIT HSDP+CP (with dp_shard)", + "jit_hsdp+cp_with_dp_shard", + ngpu=8, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.context_parallel_degree 2", + ] + ], + "JIT FSDP+TP+CP", + "jit_fsdp+tp+cp", + ngpu=8, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--checkpoint.enable", + "--training.steps 10", + ], + # Save at [dp:4] and load at [dp:2, tp:2]. + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--checkpoint.enable", + "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer", + "--parallelism.tensor_parallel_degree 2", + "--training.steps 20", + ], + # load at [tp:4]. + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode jit", + "--checkpoint.enable", + "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer", + "--parallelism.tensor_parallel_degree 4", + "--training.steps 30", + ], + ], + "JIT Optional checkpoint", + "jit_optional_checkpoint", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.deepseek_v3", + _CUSTOM_CONFIG, + "--parallelism.data_parallel_shard_degree 4", + "--parallelism.expert_parallel_degree 2", + ], + ], + "JIT FSDP+EP", + "jit_fsdp+ep", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.deepseek_v3", + _CUSTOM_CONFIG, + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + ], + ], + "JIT FSDP+TP+EP", + "jit_fsdp+tp+ep", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.deepseek_v3", + _CUSTOM_CONFIG, + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 2", + "--parallelism.expert_tensor_parallel_degree 2", + ], + ], + "JIT FSDP+TP+EP+ETP", + "jit_fsdp+tp+ep+etp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.deepseek_v3", + _CUSTOM_CONFIG, + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.context_parallel_degree 2", + ], + ], + "JIT FSDP+CP", + "jit_fsdp+cp", + ngpu=4, + ), + # ===================================================================== + # AOT mode tests (from compiler_toolkit) + # ===================================================================== + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + ], + ], + "AOT llama3 FSDP+TP", + "aot_llama3_fsdp_tp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name 
graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--compile.passes auto_bucketing", + ], + ], + "AOT llama3 FSDP+TP autobucketing", + "aot_llama3_fsdp_tp_autobucketing", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--compile.passes transformer_block_bucketing", + ], + ], + "AOT llama3 FSDP+TP manualbucketing", + "aot_llama3_fsdp_tp_manualbucketing", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--compile.passes cudagraph", + ], + ], + "AOT llama3 FSDP+TP+cudagraph", + "aot_llama3_fsdp_tp_cudagraph", + ngpu=4, + skip_rocm_test=True, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--model.flavor debugmodel_flex_attn", + ], + ], + "AOT llama3 FSDP+TP+FlexAttn", + "aot_llama3_fsdp_tp_flexattn", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--model.flavor debugmodel_flex_attn", + "--compile.passes auto_bucketing,regional_inductor", + ], + ], + "AOT llama3 FSDP+TP+FlexAttn autobucketing regional_inductor", + "aot_llama3_fsdp_tp_flexattn_autobucketing_regional_inductor", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--compile.passes inductor_decomposition,full_inductor_compilation", + ], + ], + "AOT llama3 full_inductor_compilation", + "aot_llama3_full_inductor_compilation", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--compile.passes transformer_block_bucketing,regional_inductor", + ], + ], + "AOT llama3 FSDP+TP manualbucketing regional_inductor", + "aot_llama3_fsdp_tp_manualbucketing_regional_inductor", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.deepseek_v3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + "--activation_checkpoint.mode none", + ], + ], + "AOT deepseek_v3 FSDP+TP+EP", + "aot_deepseekv3_fsdp_tp_ep", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.deepseek_v3", + _CUSTOM_CONFIG, + "--compile.mode aot", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + "--activation_checkpoint.mode none", + "--model.flavor debugmodel_flex_attn", + ], + ], + "AOT deepseek_v3 FSDP+TP+EP+FlexAttention", + 
"aot_deepseekv3_fsdp_tp_ep_flexattention", + ngpu=4, + ), + # ===================================================================== + # No compilation tests (parallelization only) + # ===================================================================== + OverrideDefinitions( + [ + [ + "--model.name graph_based_training.llama3", + _CUSTOM_CONFIG, + ], + ], + "No compile 1D", + "nocompile_1d", + ), + ] + return integration_tests_flavors + + +_TEST_SUITES_FUNCTION = { + "graph_based_training": build_graph_based_training_test_list, +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument( + "--config_path", + default="./tests/integration_tests/base_config.toml", + help="Base config path for integration tests.", + ) + parser.add_argument( + "--gpu_arch_type", + default="cuda", + choices=["cuda", "rocm"], + help="GPU architecture type.", + ) + parser.add_argument( + "--test_name", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) + parser.add_argument("--ngpu", default=8, type=int) + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.listdir(args.output_dir): + raise RuntimeError("Please provide an empty output directory.") + + test_list = _TEST_SUITES_FUNCTION["graph_based_training"]() + run_tests(args, test_list) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/graph_based_training/tests/numerics_utils.py b/torchtitan/experiments/graph_based_training/tests/numerics_utils.py new file mode 100644 index 0000000000..e1fdf98fb2 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/tests/numerics_utils.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Shared utilities for numerics testing between JIT and AOT modes.""" + +import glob +import os +import subprocess + +import torch +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + + +def load_metrics(event_path, metric_names): + """Load metrics from tensorboard event files.""" + event_acc = EventAccumulator(event_path) + event_acc.Reload() + + metrics = {} + for metric_name in metric_names: + try: + scalars = event_acc.Scalars(metric_name) + metrics[metric_name] = {scalar.step: scalar.value for scalar in scalars} + except KeyError: + print(f"Warning: Metric {metric_name!r} not found in event file") + metrics[metric_name] = {} + + return metrics + + +def compare_metrics(metrics1, metrics2, label1="Eager", label2="Compiled"): + """Compare two sets of metrics and verify bitwise equivalence using torch.equal().""" + + all_metrics = set(metrics1.keys()) | set(metrics2.keys()) + all_match = True + + for metric_name in sorted(all_metrics): + + steps1 = set(metrics1[metric_name].keys()) + steps2 = set(metrics2[metric_name].keys()) + + if steps1 != steps2: + print(" ERROR: Step mismatch!") + print(f" {label1} steps: {sorted(steps1)}") + print(f" {label2} steps: {sorted(steps2)}") + all_match = False + continue + + # Convert values to tensors for each step and compare + values1 = [metrics1[metric_name][step] for step in sorted(steps1)] + values2 = [metrics2[metric_name][step] for step in sorted(steps2)] + + tensor1 = torch.tensor(values1) + tensor2 = torch.tensor(values2) + + if torch.equal(tensor1, tensor2): + print(f" PASS: All {len(steps1)} steps match exactly (bitwise equivalent)") + else: + # Find and report mismatches + mismatches = [] + for idx, step in enumerate(sorted(steps1)): + val1 = values1[idx] + val2 = values2[idx] + if val1 != val2: + mismatches.append((step, val1, val2, abs(val1 - val2))) + + print( + f" ERROR: Found {len(mismatches)} mismatches out of {len(steps1)} steps" + ) + + return all_match + + +def find_latest_event_dir(base_path): + """Find the latest timestamped directory in the base path.""" + if not os.path.exists(base_path): + raise ValueError(f"Path does not exist: {base_path}") + + subdirs = [d for d in glob.glob(os.path.join(base_path, "*")) if os.path.isdir(d)] + if not subdirs: + return base_path + + latest = max(subdirs, key=os.path.getmtime) + return latest + + +def run_training( + ngpu, + config_file, + model_name, + dp_shard_degree, + tp_degree, + cp_degree, + ep_degree, + ac_mode, + steps, + seed, + deterministic, + tb_folder, + compile_mode=None, + passes=None, +): + """Run a training job with the specified configuration.""" + print(f"\nStarting training: {model_name} (mode={compile_mode})") + + env = os.environ.copy() + env["NGPU"] = str(ngpu) + env["CONFIG_FILE"] = config_file + env["TRAIN_FILE"] = "torchtitan.experiments.graph_based_training.train" + + cmd = [ + "./run_train.sh", + "--model.name", + model_name, + "--parallelism.data_parallel_shard_degree", + str(dp_shard_degree), + "--parallelism.tensor_parallel_degree", + str(tp_degree), + ] + + if cp_degree > 1: + cmd.extend(["--parallelism.context_parallel_degree", str(cp_degree)]) + if ep_degree > 1: + cmd.extend(["--parallelism.expert_parallel_degree", str(ep_degree)]) + + cmd.extend( + [ + "--activation_checkpoint.mode", + ac_mode, + "--training.steps", + str(steps), + "--debug.seed", + str(seed), + "--debug.deterministic", + "--metrics.enable_tensorboard", + "--metrics.save_tb_folder", + tb_folder, + ] + ) + + if compile_mode or passes: + cmd.extend( + [ + 
"--job.custom_config_module", + "torchtitan.experiments.graph_based_training.job_config", + ] + ) + if compile_mode: + cmd.extend(["--compile.mode", compile_mode]) + if passes: + cmd.extend(["--compile.passes", passes]) + + print(f"Environment: NGPU={env['NGPU']}, CONFIG_FILE={env['CONFIG_FILE']}") + print(f"Running command: {' '.join(cmd)}") + + try: + result = subprocess.run( + cmd, + env=env, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + print(f"Training completed: {model_name}") + return True + except subprocess.CalledProcessError as e: + print(f"Training failed: {model_name}") + print(f"Error output:\n{e.stdout}") + return False + + +def determine_model_names(config_file): + """Determine model names based on config file.""" + if "deepseek" in config_file: + model_name = "deepseek_v3" + elif "llama3" in config_file: + model_name = "llama3" + else: + raise ValueError( + f"Unable to determine model names from config file: {config_file}" + ) + + # Both eager and compiled use graph_based_training experiment + return f"graph_based_training.{model_name}" + + +def run_numerics_test( + ngpu, + config_file, + dp_shard_degree, + tp_degree, + cp_degree, + ep_degree, + ac_mode, + steps, + seed, + eager_tb_folder, + compiled_tb_folder, + metrics, + passes=None, +): + """ + Run numerics test by training both JIT and AOT modes and comparing metrics. + + Returns: + bool: True if all metrics match, False otherwise. + """ + model_name = determine_model_names(config_file) + + # Run JIT (eager baseline) training + eager_success = run_training( + ngpu=ngpu, + config_file=config_file, + model_name=model_name, + dp_shard_degree=dp_shard_degree, + tp_degree=tp_degree, + cp_degree=cp_degree, + ep_degree=ep_degree, + ac_mode=ac_mode, + steps=steps, + seed=seed, + deterministic=True, + tb_folder=eager_tb_folder, + compile_mode="jit", + ) + + if not eager_success: + print("JIT training failed") + return False + + # Run AOT (compiled) training + compiled_success = run_training( + ngpu=ngpu, + config_file=config_file, + model_name=model_name, + dp_shard_degree=dp_shard_degree, + tp_degree=tp_degree, + cp_degree=cp_degree, + ep_degree=ep_degree, + ac_mode=ac_mode, + steps=steps, + seed=seed, + deterministic=True, + tb_folder=compiled_tb_folder, + compile_mode="aot", + passes=passes, + ) + + if not compiled_success: + print("AOT training failed") + return False + + # Compare metrics + eager_path = find_latest_event_dir(f"./outputs/{eager_tb_folder}") + compiled_path = find_latest_event_dir(f"./outputs/{compiled_tb_folder}") + + eager_metrics = load_metrics(eager_path, metrics) + compiled_metrics = load_metrics(compiled_path, metrics) + + all_match = compare_metrics(eager_metrics, compiled_metrics, "JIT", "AOT") + + if all_match: + print("SUCCESS: All metrics are bitwise equivalent") + else: + print("FAILURE: Metrics differ between runs") + + return all_match diff --git a/torchtitan/experiments/graph_based_training/tests/test_aot_numerics.py b/torchtitan/experiments/graph_based_training/tests/test_aot_numerics.py new file mode 100644 index 0000000000..c5095b854f --- /dev/null +++ b/torchtitan/experiments/graph_based_training/tests/test_aot_numerics.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import unittest + +from .numerics_utils import run_numerics_test + + +class TestAOTNumerics(unittest.TestCase): + """Test numerics equivalence between JIT and AOT compilation modes.""" + + def test_llama3_fsdp_tp(self): + """Test Llama3 with FSDP + TP configuration.""" + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_jit", + compiled_tb_folder="tb/test_llama3_fsdp_tp_aot", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + ) + self.assertTrue(result, "Llama3 FSDP+TP numerics test failed") + + def test_llama3_fsdp_tp_autobucketing(self): + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_autobucketing_jit", + compiled_tb_folder="tb/test_llama3_fsdp_tp_autobucketing_aot", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + passes="auto_bucketing", + ) + self.assertTrue(result, "Llama3 FSDP+TP+autobucketing numerics test failed") + + def test_llama3_fsdp_tp_manualbucketing(self): + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_manualbucketing_jit", + compiled_tb_folder="tb/test_llama3_fsdp_tp_manualbucketing_aot", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + passes="transformer_block_bucketing", + ) + self.assertTrue(result, "Llama3 FSDP+TP+manualbucketing numerics test failed") + + def test_deepseek_v3_fsdp_tp_ep(self): + """Test DeepSeek V3 with FSDP + TP + EP configuration.""" + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=4, + ac_mode="none", + steps=10, + seed=42, + eager_tb_folder="tb/test_deepseek_v3_fsdp_tp_ep_jit", + compiled_tb_folder="tb/test_deepseek_v3_fsdp_tp_ep_aot", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + ) + self.assertTrue(result, "DeepSeek V3 FSDP+TP+EP numerics test failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/experiments/graph_based_training/tests/test_numerics.py b/torchtitan/experiments/graph_based_training/tests/test_numerics.py new file mode 100644 index 0000000000..2ab25f3c05 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/tests/test_numerics.py @@ -0,0 +1,158 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
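+
+"""
+SimpleFSDP convergence unit tests (multi-process, via FSDPTest).
+
+For each data-parallel mode (replicate, fully_shard, hybrid_shard) a small
+nn.Linear model is trained three ways: with FSDP2 (fully_shard), with
+SimpleFSDP in eager mode, and with SimpleFSDP compiled using the aot_eager
+backend. The per-step losses of all three runs must match bitwise
+(torch.equal).
+"""
+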
+import copy + +import torch +from torch.distributed._composable.fsdp import fully_shard +from torch.testing._internal.common_fsdp import FSDPTest + +from torchtitan.components.loss import cross_entropy_loss +from torchtitan.distributed import ParallelDims +from torchtitan.experiments.graph_based_training.simple_fsdp import data_parallel + + +class TestSimpleFSDP(FSDPTest): + def init_test(self): + self.optimizer = torch.optim.Adam + self.loss_fn = cross_entropy_loss + data_parallel_shard_degree = -1 + if self.mode == "replicate": + self.dp_mesh_dim_names = ["dp_replicate"] + data_parallel_replicate_degree = self.world_size + elif self.mode == "fully_shard": + self.dp_mesh_dim_names = ["fsdp"] + data_parallel_replicate_degree = 1 + elif self.mode == "hybrid_shard": + self.dp_mesh_dim_names = ["dp_replicate", "fsdp"] + data_parallel_replicate_degree = self.world_size // 2 + else: + raise ValueError(f"Unsupported mode {self.mode}") + + self.parallel_dims = ParallelDims( + dp_shard=data_parallel_shard_degree, + dp_replicate=data_parallel_replicate_degree, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=self.world_size, + ) + + def get_input(self): + inputs = torch.randn(8, 8).cuda() + labels = torch.randn(8, 8).cuda() + model = torch.nn.Linear(8, 8) + return model, inputs, labels + + def run_fsdp2(self, model, inputs, labels, epoch=20): + fully_shard(model, mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names)) + optim = self.optimizer(model.parameters(), lr=1e-4) + losses = [] + for _ in range(epoch): + optim.zero_grad() + out = model(inputs) + loss = self.loss_fn(out, labels) + loss.backward() + optim.step() + losses.append(loss) + return losses + + def run_simple_fsdp(self, model, inputs, labels, epoch=20): + model = data_parallel( + model, + device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names), + mode=self.mode, + ) + optim = self.optimizer(model.parameters(), lr=1e-4) + losses = [] + for _ in range(epoch): + optim.zero_grad() + out = model(inputs) + loss = self.loss_fn(out, labels) + loss.backward() + optim.step() + losses.append(loss) + return losses + + def run_simple_fsdp_compiled_aot_eager(self, model, inputs, labels, epoch=20): + model = data_parallel( + model, + device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names), + mode=self.mode, + ) + # TODO: Add "inductor" backend when it's numerical issues are fixed + model = torch.compile(model, backend="aot_eager", fullgraph=True) + optim = self.optimizer(model.parameters(), lr=1e-4) + losses = [] + for _ in range(epoch): + optim.zero_grad() + out = model(inputs) + loss = self.loss_fn(out, labels) + loss.backward() + optim.step() + losses.append(loss) + return losses + + def test_replicate_convergence(self): + # unit test for replicate mode + self.mode = "replicate" + self.init_test() + model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_losses = self.run_simple_fsdp(copy.deepcopy(model), inputs, labels) + simple_fsdp_compiled_aot_eager_losses = self.run_simple_fsdp_compiled_aot_eager( + copy.deepcopy(model), inputs, labels + ) + + for (fsdp2_loss, simple_fsdp_loss, simple_fsdp_compiled_aot_eager_loss,) in zip( + fsdp2_losses, + simple_fsdp_losses, + simple_fsdp_compiled_aot_eager_losses, + ): + assert torch.equal(fsdp2_loss, simple_fsdp_loss) + assert torch.equal(fsdp2_loss, simple_fsdp_compiled_aot_eager_loss) + + def test_fullyshard_convergence(self): + # unit test for fully_shard mode + self.mode = "fully_shard" + self.init_test() + 
model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_losses = self.run_simple_fsdp(copy.deepcopy(model), inputs, labels) + simple_fsdp_compiled_aot_eager_losses = self.run_simple_fsdp_compiled_aot_eager( + copy.deepcopy(model), inputs, labels + ) + + for (fsdp2_loss, simple_fsdp_loss, simple_fsdp_compiled_aot_eager_loss,) in zip( + fsdp2_losses, + simple_fsdp_losses, + simple_fsdp_compiled_aot_eager_losses, + ): + assert torch.equal(fsdp2_loss, simple_fsdp_loss) + assert torch.equal(fsdp2_loss, simple_fsdp_compiled_aot_eager_loss) + + def test_hybridshard_convergence(self): + # unit test for hybrid_shard mode + self.mode = "hybrid_shard" + self.init_test() + model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_losses = self.run_simple_fsdp(copy.deepcopy(model), inputs, labels) + simple_fsdp_compiled_aot_eager_losses = self.run_simple_fsdp_compiled_aot_eager( + copy.deepcopy(model), inputs, labels + ) + + for (fsdp2_loss, simple_fsdp_loss, simple_fsdp_compiled_aot_eager_loss,) in zip( + fsdp2_losses, + simple_fsdp_losses, + simple_fsdp_compiled_aot_eager_losses, + ): + assert torch.equal(fsdp2_loss, simple_fsdp_loss) + assert torch.equal(fsdp2_loss, simple_fsdp_compiled_aot_eager_loss) diff --git a/torchtitan/experiments/graph_based_training/train.py b/torchtitan/experiments/graph_based_training/train.py new file mode 100644 index 0000000000..7749428598 --- /dev/null +++ b/torchtitan/experiments/graph_based_training/train.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import gc + +from torchtitan.train import main, Trainer + + +class GraphBasedTrainingTrainer(Trainer): + def close(self) -> None: + super().close() + + # Note [explicit cudagraph close] + # cudagraph holds reference to nccl which prevents destroy nccl + # group. so we need to explicitly delete cudagraph which is held + # in joint_graph_module. An explicit gc.collect() is necessary + # to clean up reference cycles. + for part in self.model_parts: + if hasattr(part, "joint_graph_module"): + part.joint_graph_module = None + gc.collect() + + +if __name__ == "__main__": + main(GraphBasedTrainingTrainer)