17 changes: 15 additions & 2 deletions docs/debugging.md
@@ -70,7 +70,7 @@ When debugging issues with multi-dimensional parallelism (combinations of FSDP,
Set consistent random seeds across all parallelism dimensions:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.seed 42
```

**Seed behavior with parallelism:**
@@ -84,7 +84,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
Enable deterministic algorithms to ensure bit-for-bit reproducibility across runs:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.deterministic
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic
```

**What it does:**
@@ -93,6 +93,19 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
- Sets deterministic workspace configuration for CuBLAS operations
- **Note:** This will significantly reduce training performance but ensures exact reproducibility

Use `--debug.deterministic_warn_only` to only warn about (rather than stop on) kernels that lack a deterministic implementation.
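
For example, to keep deterministic algorithms enabled but only warn when a kernel has no deterministic implementation (combining the two flags above):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic --debug.deterministic_warn_only
```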

### Activation Checkpointing Debugging

The following debug options are available for activation checkpointing (AC); an example command is shown after the list.

`ac_preserve_rng_state` - Set to true if deterministic output compared to non-checkpointed passes is required. Stashes and restores the RNG state during each checkpointed region, which may be slower.

`ac_determinism_check` - A string specifying the determinism check function.

`ac_debug` - Capture AC debug information. Will be slower.

See https://docs.pytorch.org/docs/stable/checkpoint.html for details.
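
For example, assuming these options are exposed through the same `--debug.*` command-line mapping as the flags above (a sketch, not a verified command):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --debug.ac_preserve_rng_state \
  --debug.ac_determinism_check default \
  --debug.ac_debug
```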

### Seed-Checkpoint-based Reproducibility

2 changes: 2 additions & 0 deletions torchtitan/config/__init__.py
@@ -28,6 +28,7 @@
Quantize,
Training,
Validation,
Debug
)
from .manager import ConfigManager

@@ -49,4 +50,5 @@
"Profiling",
"Training",
"Validation",
"Debug"
]
33 changes: 24 additions & 9 deletions torchtitan/config/job_config.py
@@ -233,15 +233,6 @@ class Training:
many temporary files.
"""

seed: int | None = None
"""Choose the base RNG seed used for training"""

deterministic: bool = False
"""Use deterministic algorithms wherever possible, may be slower"""

debug_moe_force_load_balance: bool = False
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""


@dataclass
class Parallelism:
@@ -813,6 +804,29 @@ def __post_init__(self):
), "validation steps must be positive or -1"


@dataclass
class Debug:
deterministic: bool = False
"""Use deterministic algorithms wherever possible, may be slower"""

deterministic_warn_only: bool = False
"""Only warns about ops without deterministic implementations rather than erroring out """

seed: int | None = None
"""Choose the base RNG seed used for training"""

ac_preserve_rng_state: bool = False
"""If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

ac_determinism_check: str = "default"
"""A string specifying the determinism function. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

ac_debug: bool = False
""" Capture ac debug information. Will be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

moe_force_load_balance: bool = False
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""

@dataclass
class JobConfig:
"""
@@ -838,6 +852,7 @@ class JobConfig:
fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
experimental: Experimental = field(default_factory=Experimental)
validation: Validation = field(default_factory=Validation)
debug: Debug = field(default_factory=Debug)

def to_dict(self) -> dict[str, Any]:
return asdict(self)
28 changes: 21 additions & 7 deletions torchtitan/distributed/activation_checkpoint.py
@@ -16,13 +16,14 @@
)

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.config.job_config import Debug as DebugConfig
from torchtitan.tools.logging import logger, warn_once


_layer_sac_count = 0


def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config: DebugConfig) -> nn.Module:
"""Apply layer selective activation checkpointing to the module.

Args:
@@ -37,7 +38,11 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
ac_freq = int(ac_config.selective_ac_option)
if not ac_freq or _layer_sac_count % ac_freq == 0:
return ptd_checkpoint_wrapper(
module, preserve_rng_state=False, early_stop=ac_config.early_stop
module,
preserve_rng_state=debug_config.ac_preserve_rng_state,
determinism_check=debug_config.ac_determinism_check,
early_stop=ac_config.early_stop,
debug=debug_config.ac_debug
)
else:
return module
@@ -122,11 +127,13 @@ def selective_checkpointing_context_fn():
return create_selective_checkpoint_contexts(_get_custom_policy(meta))

return ptd_checkpoint_wrapper(
module,
context_fn=selective_checkpointing_context_fn,
preserve_rng_state=False,
early_stop=ac_config.early_stop,
)
module,
context_fn=selective_checkpointing_context_fn,
preserve_rng_state=dbg_config.ac_preserve_rng_state,
determinism_check=dbg_config.ac_determinism_check,
early_stop=ac_config.early_stop,
debug=dbg_config.ac_debug
)


def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
@@ -142,6 +149,13 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
return ptd_checkpoint_wrapper(
module, preserve_rng_state=False, early_stop=ac_config.early_stop
)
return ptd_checkpoint_wrapper(
module,
preserve_rng_state=dbg_config.ac_preserve_rng_state,
determinism_check=dbg_config.ac_determinism_check,
early_stop=ac_config.early_stop,
debug=dbg_config.ac_debug
)


def _apply_op_sac_to_transformer_block_with_flex(
8 changes: 5 additions & 3 deletions torchtitan/distributed/utils.py
@@ -18,6 +18,7 @@
from torch.distributed.tensor import DTensor

from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
from torchtitan.config import Debug as DebugConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import device_module, device_type
@@ -83,8 +84,7 @@ def dist_mean(
def set_determinism(
world_mesh: DeviceMesh | None,
device: torch.device,
seed: int | None = None,
deterministic: bool = False,
debug_config: DebugConfig,
distinct_seed_mesh_dim: str = "pp",
) -> None:
"""
@@ -97,15 +97,17 @@

Set Determinism flags for increased reproducibility with loss of performance.
"""
if deterministic:
if debug_config.deterministic:
logger.info("Deterministic algorithm enabled (expect perf degradation).")
torch.use_deterministic_algorithms(True)
torch.use_deterministic_algorithms(True, warn_only=debug_config.deterministic_warn_only)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# env var for deterministic CuBLAS
# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

seed = debug_config.seed
if not world_mesh:
if seed is not None:
torch.manual_seed(seed)
4 changes: 2 additions & 2 deletions torchtitan/experiments/flux/train.py
@@ -35,8 +35,8 @@ def __init__(self, job_config: JobConfig):
dist_utils.set_determinism(
self.parallel_dims.world_mesh,
self.device,
job_config.training.seed,
job_config.training.deterministic,
job_config.debug,
distinct_seed_mesh_dim="dp_shard",
)

4 changes: 2 additions & 2 deletions torchtitan/experiments/forge/engine.py
@@ -104,8 +104,8 @@ def __init__(self, job_config: ForgeJobConfig):
dist_utils.set_determinism(
world_mesh,
self.device,
job_config.training.seed,
job_config.training.deterministic,
job_config.debug,
)
self.train_spec = get_train_spec(job_config.model.name)

2 changes: 2 additions & 0 deletions torchtitan/experiments/forge/job_config.py
@@ -20,6 +20,7 @@
Parallelism,
Quantize,
Training,
Debug,
)


@@ -45,6 +46,7 @@ class ForgeJobConfig:
# fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
# experimental: Experimental = field(default_factory=Experimental)
# validation: Validation = field(default_factory=Validation)
debug: Debug = field(default_factory=Debug)

def to_dict(self) -> dict[str, Any]:
return asdict(self)
2 changes: 1 addition & 1 deletion torchtitan/experiments/llama4/model/args.py
@@ -73,7 +73,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)

self.moe_args._debug_force_load_balance = (
job_config.training.debug_moe_force_load_balance
job_config.debug.moe_force_load_balance
)

def get_nparams_and_flops(
10 changes: 8 additions & 2 deletions torchtitan/train.py
@@ -11,6 +11,13 @@
from typing import Any, Generator, Iterable, Optional

import torch

try:
import intel_extension_for_pytorch as ipex
print ( f"IPEX found - hence using IPEX")
except:
print ( f"IPEX not found, hence not using")

Comment on lines +14 to +20

Contributor: Sorry what's this for? Could we remove it?

Contributor: also please rebase to resolve conflict

from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
@@ -126,8 +133,7 @@ def __init__(self, job_config: JobConfig):
dist_utils.set_determinism(
world_mesh,
self.device,
job_config.training.seed,
job_config.training.deterministic,
job_config.debug,
)
self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
