From bf326d27d4feff9396a2b927990d119da54360b6 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 4 Feb 2026 22:28:20 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- torchtitan/config/job_config.py | 6 + .../distributed/sharding_info_template.html | 434 ++++++++++ torchtitan/distributed/sharding_utils.py | 773 ++++++++++++++++++ torchtitan/distributed/utils.py | 29 +- torchtitan/train.py | 7 +- 5 files changed, 1247 insertions(+), 2 deletions(-) create mode 100644 torchtitan/distributed/sharding_info_template.html create mode 100644 torchtitan/distributed/sharding_utils.py diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index b3a24c7847..fc7537c850 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -987,6 +987,12 @@ class Debug: moe_force_load_balance: bool = False """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" + log_sharding_info: bool = False + """If True, logs DTensor sharding/mesh info for module inputs, params, outputs during one fwd/bwd pass.""" + + collapse_identical_layers: bool = True + """If True, collapse repeated layer modules with identical sharding patterns in the output.""" + @dataclass class JobConfig: diff --git a/torchtitan/distributed/sharding_info_template.html b/torchtitan/distributed/sharding_info_template.html new file mode 100644 index 0000000000..30198b8a3f --- /dev/null +++ b/torchtitan/distributed/sharding_info_template.html @@ -0,0 +1,434 @@ + + + + + + + Sharding Info - Rank $RANK + + + +

[HTML template body not reproduced: the markup, CSS, and JavaScript of the 434-line
 sharding_info_template.html were lost in extraction. The recoverable content is:

   Title: "Sharding Info - Rank $RANK"

   Summary cards: Modules Traced ($NUM_MODULES_TRACED), Forward Entries
   ($FORWARD_ENTRIES), Recompute Entries ($RECOMPUTE_ENTRIES), Backward Entries
   ($BACKWARD_ENTRIES)

   Legend:
     R           = Replicated across all devices
     S(n)        = Sharded along dimension n
     P(op)       = Partial with reduction operation (e.g., P(sum))
     Local       = Local tensor (not distributed)
     [recompute] = Module recomputed during backward (activation checkpointing)
     [0-N]       = Layers 0 through N with identical sharding pattern

 The Python side (_generate_html in sharding_utils.py below) also substitutes
 $MODULE_VIEWS_JSON and $COLLAPSE_LAYERS_CHECKED into this template.]
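Not part of the patch — an illustrative sketch of the substitution step this template relies on. The toy template string below stands in for the real HTML file; _generate_html() in sharding_utils.py (the next file in this patch) performs the same string.Template substitution against the packaged template, additionally passing $MODULE_VIEWS_JSON and $COLLAPSE_LAYERS_CHECKED for the interactive views. The numeric values here are made up for illustration.

from string import Template

# Toy stand-in for sharding_info_template.html; the real file wraps these
# placeholders in a full HTML page with CSS/JS.
toy_template = Template(
    "Sharding Info - Rank $RANK\n"
    "Modules traced: $NUM_MODULES_TRACED\n"
    "Forward/Recompute/Backward entries: "
    "$FORWARD_ENTRIES / $RECOMPUTE_ENTRIES / $BACKWARD_ENTRIES\n"
)

print(
    toy_template.substitute(
        RANK=0,
        NUM_MODULES_TRACED=42,
        FORWARD_ENTRIES=120,
        RECOMPUTE_ENTRIES=30,
        BACKWARD_ENTRIES=150,
    )
)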
+ + + + diff --git a/torchtitan/distributed/sharding_utils.py b/torchtitan/distributed/sharding_utils.py new file mode 100644 index 0000000000..91ad5838ed --- /dev/null +++ b/torchtitan/distributed/sharding_utils.py @@ -0,0 +1,773 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Sharding debug utilities for capturing and visualizing DTensor sharding info.""" + +import json +import os +import re +from collections import defaultdict +from dataclasses import dataclass +from typing import Callable, TYPE_CHECKING + +import torch +import torch.nn as nn +from torch import distributed as dist +from torch.distributed.tensor import DTensor +from torch.utils._pytree import tree_flatten +from torch.utils.hooks import RemovableHandle + +from torchtitan.tools.logging import logger + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + +__all__ = ["ShardingDebugContext", "ShardingEntry"] + +# Width for ASCII diagram separators +_LINE_WIDTH = 80 + + +@dataclass +class ShardingEntry: + """Structured data for a single sharding info entry. + + Attributes: + call_order: Sequential order in which this entry was recorded. + phase: One of "forward", "backward", or "recompute". + module_name: Fully qualified name of the module (e.g., "model.layers.0.attn"). + tensor_type: Type of tensor (e.g., "input0", "weight", "output0", "grad_input0"). + placements: Tuple of placement strings (e.g., ("Shard(0)",) or ("Replicate",)). + mesh: Formatted device mesh string, or None for local (non-distributed) tensors. + shape: Tensor shape as a tuple of integers. + """ + + call_order: int + phase: str + module_name: str + tensor_type: str + placements: tuple[str, ...] + mesh: str | None + shape: tuple[int, ...] + + +class ShardingDebugContext: + """Context manager that logs DTensor sharding info for all module operations. + + Registers forward and backward hooks on all modules to capture sharding + information for inputs, outputs, parameters, and gradients during + one forward/backward pass. + """ + + def __init__( + self, + model_parts: list[nn.Module], + dump_folder: str, + collapse_layers: bool = True, + ): + self.model_parts = model_parts + self.collapse_layers = collapse_layers + self.hooks: list[RemovableHandle] = [] + + # Once root module's forward completes, subsequent forward hook calls are + # from AC. 
+ self.initial_forward_complete: dict[int, bool] = {} + + self.num_modules_traced = 0 + self.forward_entries = 0 + self.recompute_entries = 0 + self.backward_entries = 0 + self.placements_seen: set[str] = set() + self.call_counter = 0 + self.entries: list[ShardingEntry] = [] + self.rank = dist.get_rank() if dist.is_initialized() else 0 + self.output_dir = os.path.join(dump_folder, "sharding_info") + + def __enter__(self) -> "ShardingDebugContext": + os.makedirs(self.output_dir, exist_ok=True) + self._register_hooks() + return self + + def __exit__(self, exception_type, exception_value, traceback) -> None: + self._remove_hooks() + + ascii_content = self._generate_ascii_diagram() + txt_path = os.path.join(self.output_dir, f"rank{self.rank}_sharding_info.txt") + with open(txt_path, "w") as f: + f.write(ascii_content) + + html_content = self._generate_html() + html_path = os.path.join(self.output_dir, f"rank{self.rank}_sharding_info.html") + with open(html_path, "w") as f: + f.write(html_content) + + logger.info(f"Sharding info written to {self.output_dir}") + + def _register_hooks(self) -> None: + """Register forward and backward hooks on all modules.""" + if not self.model_parts: + return + + for idx, model in enumerate(self.model_parts): + self.initial_forward_complete[idx] = False + + for fqn, module in model.named_modules(): + self.num_modules_traced += 1 + module._sharding_debug_fqn = fqn # type: ignore[attr-defined] + module._sharding_debug_idx = idx # type: ignore[attr-defined] + part_prefix = f"part{idx}." if len(self.model_parts) > 1 else "" + + forward_hook = self._create_forward_hook(fqn, idx, part_prefix) + handle = module.register_forward_hook(forward_hook) + self.hooks.append(handle) + + backward_hook = self._create_backward_hook(fqn, part_prefix) + handle = module.register_full_backward_hook(backward_hook) + self.hooks.append(handle) + + def _remove_hooks(self) -> None: + """Remove all registered hooks and clean up module attributes.""" + for handle in self.hooks: + handle.remove() + self.hooks.clear() + + for model in self.model_parts: + for module in model.modules(): + for attr in ("_sharding_debug_fqn", "_sharding_debug_idx"): + if hasattr(module, attr): + delattr(module, attr) + + @staticmethod + def _clean_fqn(fqn: str) -> str: + """Remove _checkpoint_wrapped_module from FQN.""" + parts = fqn.split(".") + cleaned = [p for p in parts if p != "_checkpoint_wrapped_module"] + return ".".join(cleaned) + + def _create_forward_hook(self, fqn: str, idx: int, part_prefix: str) -> Callable: + """Create a forward hook that captures input/output sharding info.""" + clean_fqn = self._clean_fqn(fqn) + is_root = fqn == "" + + def hook(module, inputs, outputs): + is_recompute = self.initial_forward_complete.get(idx, False) + phase = "recompute" if is_recompute else "forward" + module_name = part_prefix + ("model." 
+ clean_fqn if clean_fqn else "model") + + flat_inputs, _ = tree_flatten(inputs) + for i, tensor in enumerate(flat_inputs): + if isinstance(tensor, torch.Tensor): + self._add_entry(phase, module_name, f"input{i}", tensor) + + for param_name, param in module.named_parameters(recurse=False): + if isinstance(param, torch.Tensor): + self._add_entry(phase, module_name, param_name, param) + + flat_outputs, _ = tree_flatten(outputs) + for i, tensor in enumerate(flat_outputs): + if isinstance(tensor, torch.Tensor): + self._add_entry(phase, module_name, f"output{i}", tensor) + + if is_root and not self.initial_forward_complete.get(idx, False): + self.initial_forward_complete[idx] = True + + return hook + + def _create_backward_hook(self, fqn: str, part_prefix: str) -> Callable: + """Create a backward hook that captures gradient sharding info.""" + clean_fqn = self._clean_fqn(fqn) + + def hook(module, grad_input, grad_output): + module_name = part_prefix + ("model." + clean_fqn if clean_fqn else "model") + + flat_grad_output, _ = tree_flatten(grad_output) + for i, tensor in enumerate(flat_grad_output): + if isinstance(tensor, torch.Tensor): + self._add_entry("backward", module_name, f"grad_output{i}", tensor) + + for param_name, param in module.named_parameters(recurse=False): + if param.grad is not None and isinstance(param.grad, torch.Tensor): + self._add_entry( + "backward", module_name, f"{param_name}.grad", param.grad + ) + + flat_grad_input, _ = tree_flatten(grad_input) + for i, tensor in enumerate(flat_grad_input): + if isinstance(tensor, torch.Tensor): + self._add_entry("backward", module_name, f"grad_input{i}", tensor) + + return hook + + def _add_entry( + self, + phase: str, + module_name: str, + tensor_type: str, + tensor: torch.Tensor, + ) -> None: + """Record a sharding entry for a tensor.""" + self.call_counter += 1 + + if isinstance(tensor, DTensor): + placements = tuple(str(p) for p in tensor.placements) + mesh = self._format_mesh(tensor.device_mesh) + else: + placements = ("Local",) + mesh = None + + entry = ShardingEntry( + call_order=self.call_counter, + phase=phase, + module_name=module_name, + tensor_type=tensor_type, + placements=placements, + mesh=mesh, + shape=tuple(tensor.shape), + ) + self.entries.append(entry) + + if phase == "forward": + self.forward_entries += 1 + elif phase == "recompute": + self.recompute_entries += 1 + elif phase == "backward": + self.backward_entries += 1 + + for p in placements: + self.placements_seen.add(p) + + # Below methods are output generation related methods (ASCII and HTML) + # All are generated by Claude and can skip if you just want + # to know how to track the sharding. + + @staticmethod + def _extract_layer_pattern(module_name: str) -> tuple[str, int | None]: + """Extract layer pattern and index from module name. + + Returns: + A tuple of (pattern_or_name, index). If the module name matches a layer + pattern (e.g., 'model.layers.5.attn'), returns the pattern with {idx} + placeholder and the layer index: ('model.layers.{idx}.attn', 5). + If no layer pattern is found, returns the original module name and None: + ('model.norm', None). 
+ """ + match = re.match( + r"^((?:model\.)?)?(layers?|blocks?)\.(\d+)(?:\.(.*))?$", module_name + ) + if not match: + return module_name, None + model_prefix, layer_prefix, idx, suffix = match.groups() + model_prefix = model_prefix or "" + base = f"{model_prefix}{layer_prefix}.{{idx}}" + return (f"{base}.{suffix}", int(idx)) if suffix else (base, int(idx)) + + def _get_module_signature( + self, + module_name: str, + entries_by_module: dict[str, list[ShardingEntry]], + ) -> tuple[tuple[str, str, str | None, tuple[int, ...]], ...]: + """Get signature for comparing modules: (tensor_type, placement, mesh, shape). + + Deduplicates by tensor_type to handle activation checkpointing where + hooks may fire multiple times for the same tensor. + + Args: + module_name: The module name to get the signature for. + entries_by_module: Pre-built index mapping module names to their entries + for efficient lookup (avoids O(n²) iteration over all entries). + """ + seen: dict[str, tuple[str, str, str | None, tuple[int, ...]]] = {} + for e in entries_by_module.get(module_name, []): + if e.tensor_type not in seen: + seen[e.tensor_type] = ( + e.tensor_type, + self._format_placement(e.placements), + e.mesh, + e.shape, + ) + return tuple(sorted(seen.values(), key=lambda x: x[0])) + + @staticmethod + def _finalize_group( + pattern: str, group: list[tuple[int, str]] + ) -> tuple[str, list[str]]: + """Convert a group of (idx, module_name) to (display_name, module_names).""" + if len(group) == 1: + return group[0][1], [group[0][1]] + indices = [i for i, _ in group] + display = pattern.replace("{idx}", f"[{min(indices)}-{max(indices)}]") + return display, [n for _, n in group] + + def _collapse_repeated_modules( + self, module_names: list[str], phase: str + ) -> list[tuple[str, list[str], bool]]: + """Collapse layers with identical signatures into ranges like [0-31].""" + if not self.collapse_layers: + return [(name, [name], False) for name in module_names] + + # Build index of entries by module name for efficient signature lookup (O(n) vs O(n²)) + entries_by_module: dict[str, list[ShardingEntry]] = defaultdict(list) + for e in self.entries: + if e.phase == phase: + entries_by_module[e.module_name].append(e) + + # Group modules by pattern + pattern_to_modules: dict[str, list[tuple[int, str]]] = defaultdict(list) + for module_name in module_names: + pattern, idx = self._extract_layer_pattern(module_name) + if idx is not None: + pattern_to_modules[pattern].append((idx, module_name)) + + # Collapse consecutive modules with identical signatures + pattern_to_collapsed: dict[str, list[tuple[str, list[str]]]] = {} + for pattern, idx_modules in pattern_to_modules.items(): + idx_modules.sort(key=lambda x: x[0]) + groups: list[tuple[str, list[str]]] = [] + current_group: list[tuple[int, str]] = [] + current_sig: tuple | None = None + + for idx, module_name in idx_modules: + sig = self._get_module_signature(module_name, entries_by_module) + if current_sig is None or sig == current_sig: + current_group.append((idx, module_name)) + current_sig = sig + else: + groups.append(self._finalize_group(pattern, current_group)) + current_group = [(idx, module_name)] + current_sig = sig + + if current_group: + groups.append(self._finalize_group(pattern, current_group)) + pattern_to_collapsed[pattern] = groups + + # Build result preserving original order + result: list[tuple[str, list[str], bool]] = [] + seen_patterns: set[str] = set() + output_modules: set[str] = set() + + for module_name in module_names: + if module_name in output_modules: + 
continue + pattern, idx = self._extract_layer_pattern(module_name) + if idx is None: + result.append((module_name, [module_name], False)) + output_modules.add(module_name) + continue + if pattern not in pattern_to_collapsed: + continue + for display_name, group_modules in pattern_to_collapsed[pattern]: + if module_name not in group_modules: + continue + group_key = f"{pattern}:{display_name}" + if group_key not in seen_patterns: + seen_patterns.add(group_key) + result.append((display_name, group_modules, len(group_modules) > 1)) + output_modules.update(group_modules) + break + + return result + + @staticmethod + def _format_mesh(device_mesh: "DeviceMesh | None") -> str | None: + """Format device mesh as '(dp=2, tp=4)'. + + Args: + device_mesh: The device mesh to format. + + Returns: + Formatted string like '(tp=8)' or '(dp=2, tp=4)', or None if no mesh. + """ + if device_mesh is None: + return None + try: + names = device_mesh.mesh_dim_names + shape = device_mesh.shape + if not names: + return f"(shape={shape})" + return "(" + ", ".join(f"{n}={s}" for n, s in zip(names, shape)) + ")" + except AttributeError: + # Fall back if mesh doesn't have expected attributes + return str(device_mesh) + + @staticmethod + def _format_placement(placements: tuple[str, ...]) -> str: + """Format placements tuple as a single string.""" + return placements[0] if len(placements) == 1 else ", ".join(placements) + + @staticmethod + def _format_shape(shape: tuple[int, ...]) -> str: + """Format shape tuple as a string like '(1, 8192, 4096)'.""" + return f"({', '.join(str(s) for s in shape)})" + + @staticmethod + def _tensor_sort_key(t: str) -> tuple[int, str]: + """Sort tensors: inputs/grad_outputs first, params middle, outputs last.""" + if t.startswith(("input", "grad_output")): + return (0, t) + if t.startswith(("output", "grad_input")): + return (2, t) + return (1, t) + + def _compute_global_max_lengths(self) -> tuple[int, int, int, int]: + """Compute max lengths for tensor_type, placement, mesh, shape globally.""" + max_type = max_place = max_mesh = max_shape = 0 + for e in self.entries: + max_type = max(max_type, len(e.tensor_type)) + max_place = max(max_place, len(self._format_placement(e.placements))) + mesh_str = "None" if e.mesh is None else e.mesh + max_mesh = max(max_mesh, len(mesh_str)) + max_shape = max(max_shape, len(self._format_shape(e.shape))) + return max_type, max_place, max_mesh, max_shape + + def _render_module_tensors( + self, + lines: list[str], + tensors: dict[str, list[ShardingEntry]], + max_lens: tuple[int, int, int, int], + indent: str = " ", + ) -> None: + """Render tensor entries as ASCII tree lines.""" + sorted_tensors = sorted(tensors.keys(), key=self._tensor_sort_key) + if not sorted_tensors: + return + + max_type_len, max_place_len, max_mesh_len, max_shape_len = max_lens + + for idx, tensor_type in enumerate(sorted_tensors): + entry = tensors[tensor_type][0] + placement = self._format_placement(entry.placements) + mesh = "None" if entry.mesh is None else entry.mesh + shape = self._format_shape(entry.shape) + + branch = "└─" if idx == len(sorted_tensors) - 1 else "├─" + line = ( + f"{indent}{branch} {tensor_type:<{max_type_len}}: " + f"{placement:<{max_place_len}} mesh={mesh:<{max_mesh_len}} " + f"shape={shape}" + ) + lines.append(line) + + def _build_module_tree( + self, phase: str + ) -> dict[str, dict[str, list[ShardingEntry]]]: + """Build dict: module_name -> tensor_type -> entries.""" + tree: dict[str, dict[str, list[ShardingEntry]]] = defaultdict( + lambda: defaultdict(list) + 
) + for entry in self.entries: + if entry.phase == phase: + tree[entry.module_name][entry.tensor_type].append(entry) + return tree + + def _generate_ascii_diagram(self) -> str: + """Generate ASCII text representation of sharding info.""" + lines: list[str] = [] + max_lens = self._compute_global_max_lengths() + + lines.append(f"Sharding Info - Rank {self.rank}") + lines.append("=" * _LINE_WIDTH) + lines.append("") + + lines.append("SUMMARY") + lines.append("-" * 40) + lines.append(f"Total modules traced: {self.num_modules_traced}") + lines.append(f"Forward entries: {self.forward_entries}") + lines.append(f"Recompute entries: {self.recompute_entries}") + lines.append(f"Backward entries: {self.backward_entries}") + lines.append(f"Unique placements: {', '.join(sorted(self.placements_seen))}") + lines.append("") + + lines.append("LEGEND") + lines.append("-" * 40) + lines.append(" R = Replicated across all devices") + lines.append(" S(n) = Sharded along dimension n") + lines.append(" P(op) = Partial with reduction operation (e.g., P(sum))") + lines.append(" Local = Local tensor (not distributed)") + lines.append( + " [recompute] = Module recomputed during backward (activation checkpointing)" + ) + lines.append(" [0-N] = Layers 0 through N with identical sharding pattern") + lines.append("") + + forward_entries = [e for e in self.entries if e.phase == "forward"] + if forward_entries: + lines.append("=" * _LINE_WIDTH) + lines.append("FORWARD PASS") + lines.append("=" * _LINE_WIDTH) + lines.append("") + + module_tree = self._build_module_tree("forward") + + def get_first_call_order(module_name: str) -> int: + return min( + ( + e.call_order + for t in module_tree[module_name].values() + for e in t + ), + default=0, + ) + + sorted_modules = sorted(module_tree.keys(), key=get_first_call_order) + collapsed = self._collapse_repeated_modules(sorted_modules, "forward") + + for display_name, original_names, _ in collapsed: + lines.append(display_name) + self._render_module_tensors( + lines, module_tree[original_names[0]], max_lens + ) + lines.append("") + + backward_entries = [ + e for e in self.entries if e.phase in ("backward", "recompute") + ] + if backward_entries: + lines.append("=" * _LINE_WIDTH) + lines.append("BACKWARD PASS") + lines.append("=" * _LINE_WIDTH) + lines.append("") + + trees = { + "recompute": self._build_module_tree("recompute"), + "backward": self._build_module_tree("backward"), + } + + # Build (module_name, phase) list preserving interleaved execution order + module_phase_order: list[tuple[str, str]] = [] + seen: set[tuple[str, str]] = set() + for entry in sorted(backward_entries, key=lambda e: e.call_order): + key = (entry.module_name, entry.phase) + if key not in seen: + seen.add(key) + module_phase_order.append(key) + + # Build collapsed maps for each phase + def build_collapsed_map( + phase: str, + ) -> dict[str, tuple[str, list[str], bool]]: + names = [m for m, p in module_phase_order if p == phase] + collapsed = self._collapse_repeated_modules(names, phase) + result: dict[str, tuple[str, list[str], bool]] = {} + for display, orig_names, is_col in collapsed: + for name in orig_names: + result[name] = (display, orig_names, is_col) + return result + + collapsed_maps = { + "recompute": build_collapsed_map("recompute"), + "backward": build_collapsed_map("backward"), + } + output_groups: dict[str, set[str]] = {"recompute": set(), "backward": set()} + + for module_name, phase in module_phase_order: + if module_name not in collapsed_maps[phase]: + continue + display_name, 
original_names, _ = collapsed_maps[phase][module_name] + if display_name in output_groups[phase]: + continue + output_groups[phase].add(display_name) + prefix = "[recompute] " if phase == "recompute" else "" + lines.append(f"{prefix}{display_name}") + self._render_module_tensors( + lines, trees[phase][original_names[0]], max_lens + ) + lines.append("") + + return "\n".join(lines) + + @staticmethod + def _load_html_template() -> str: + """Load the HTML template from the package directory.""" + import pathlib + + template_path = pathlib.Path(__file__).parent / "sharding_info_template.html" + return template_path.read_text() + + def _get_module_tensors( + self, module_name: str, phase: str + ) -> list[dict[str, str | None]]: + """Get deduplicated, sorted tensor info for a module. + + Args: + module_name: The module name to get tensors for. + phase: The phase ("forward", "backward", or "recompute"). + + Returns: + List of tensor info dicts with keys: tensorType, placement, mesh, shape. + """ + seen: dict[str, dict[str, str | None]] = {} + for e in self.entries: + if e.module_name == module_name and e.phase == phase: + if e.tensor_type not in seen: + seen[e.tensor_type] = { + "tensorType": e.tensor_type, + "placement": self._format_placement(e.placements), + "mesh": e.mesh, + "shape": self._format_shape(e.shape), + } + # Sort: inputs/grad_outputs first, params middle, outputs/grad_inputs last + return sorted( + seen.values(), key=lambda t: self._tensor_sort_key(str(t["tensorType"])) + ) + + def _build_html_module_views( + self, + ) -> dict[str, dict[str, list[dict]]]: + """Pre-compute collapsed and uncollapsed module views for HTML. + + Returns a dict with structure: + { + "forward": {"collapsed": [...], "uncollapsed": [...]}, + "backward": {"collapsed": [...], "uncollapsed": [...]} + } + + Each module entry contains: + - displayName: The name to show + - originalNames: List of original module names (for search) + - isRecompute: Whether this is a recompute module + - tensors: List of tensor info dicts + """ + result: dict[str, dict[str, list[dict]]] = { + "forward": {"collapsed": [], "uncollapsed": []}, + "backward": {"collapsed": [], "uncollapsed": []}, + } + + # Build module trees: phase -> module_name -> tensor_type -> entries + trees: dict[str, dict[str, dict[str, list[ShardingEntry]]]] = { + "forward": self._build_module_tree("forward"), + "recompute": self._build_module_tree("recompute"), + "backward": self._build_module_tree("backward"), + } + + # Helper to get first call order for sorting + def get_first_call_order( + module_name: str, tree: dict[str, dict[str, list[ShardingEntry]]] + ) -> int: + return min( + (e.call_order for t in tree[module_name].values() for e in t), + default=0, + ) + + # Process forward pass + forward_tree = trees["forward"] + if forward_tree: + sorted_modules = sorted( + forward_tree.keys(), + key=lambda m: get_first_call_order(m, forward_tree), + ) + + # Uncollapsed view + for module_name in sorted_modules: + result["forward"]["uncollapsed"].append( + { + "displayName": module_name, + "originalNames": [module_name], + "isRecompute": False, + "tensors": self._get_module_tensors(module_name, "forward"), + } + ) + + # Collapsed view + collapsed = self._collapse_repeated_modules(sorted_modules, "forward") + for display_name, original_names, _ in collapsed: + result["forward"]["collapsed"].append( + { + "displayName": display_name, + "originalNames": original_names, + "isRecompute": False, + "tensors": self._get_module_tensors( + original_names[0], "forward" + ), + 
} + ) + + # Process backward pass (interleaved recompute + backward) + backward_entries = [ + e for e in self.entries if e.phase in ("backward", "recompute") + ] + if backward_entries: + # Build (module_name, phase) list preserving execution order + module_phase_order: list[tuple[str, str]] = [] + seen: set[tuple[str, str]] = set() + for entry in sorted(backward_entries, key=lambda e: e.call_order): + key = (entry.module_name, entry.phase) + if key not in seen: + seen.add(key) + module_phase_order.append(key) + + # Uncollapsed view - preserve execution order + for module_name, phase in module_phase_order: + result["backward"]["uncollapsed"].append( + { + "displayName": module_name, + "originalNames": [module_name], + "isRecompute": phase == "recompute", + "tensors": self._get_module_tensors(module_name, phase), + } + ) + + # Collapsed view - collapse each phase separately, maintain order + recompute_names = [m for m, p in module_phase_order if p == "recompute"] + backward_names = [m for m, p in module_phase_order if p == "backward"] + + collapsed_recompute = self._collapse_repeated_modules( + recompute_names, "recompute" + ) + collapsed_backward = self._collapse_repeated_modules( + backward_names, "backward" + ) + + # Build lookup maps + recompute_map: dict[str, tuple[str, list[str]]] = {} + for display, orig_names, _ in collapsed_recompute: + for name in orig_names: + recompute_map[name] = (display, orig_names) + + backward_map: dict[str, tuple[str, list[str]]] = {} + for display, orig_names, _ in collapsed_backward: + for name in orig_names: + backward_map[name] = (display, orig_names) + + # Output in execution order, deduplicating groups + output_groups: dict[str, set[str]] = {"recompute": set(), "backward": set()} + for module_name, phase in module_phase_order: + lookup = recompute_map if phase == "recompute" else backward_map + if module_name not in lookup: + continue + display_name, orig_names = lookup[module_name] + if display_name in output_groups[phase]: + continue + output_groups[phase].add(display_name) + result["backward"]["collapsed"].append( + { + "displayName": display_name, + "originalNames": orig_names, + "isRecompute": phase == "recompute", + "tensors": self._get_module_tensors(orig_names[0], phase), + } + ) + + return result + + def _generate_html(self) -> str: + """Generate interactive HTML visualization of sharding info.""" + from string import Template + + # Pre-compute both collapsed and uncollapsed views + # This moves all collapsing logic to Python, simplifying the JavaScript + module_views = self._build_html_module_views() + module_views_json = json.dumps(module_views, indent=None) + + template = Template(self._load_html_template()) + return template.substitute( + RANK=self.rank, + NUM_MODULES_TRACED=self.num_modules_traced, + FORWARD_ENTRIES=self.forward_entries, + RECOMPUTE_ENTRIES=self.recompute_entries, + BACKWARD_ENTRIES=self.backward_entries, + MODULE_VIEWS_JSON=module_views_json, + COLLAPSE_LAYERS_CHECKED="checked" if self.collapse_layers else "", + ) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 4fe3f01a66..d08cc7e300 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -23,6 +23,7 @@ from torchtitan.config import Comm as CommConfig, Debug as DebugConfig, TORCH_DTYPE_MAP from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.distributed.sharding_utils import ShardingDebugContext from torchtitan.tools.logging import logger from torchtitan.tools.utils import 
device_module, device_type
@@ -228,13 +229,39 @@ def __call__(self) -> contextlib.AbstractContextManager[None]:
         pass
 
 
-def get_train_context(enable_loss_parallel: bool) -> TrainContext:
+def get_train_context(
+    enable_loss_parallel: bool,
+    model_parts: list[torch.nn.Module] | None = None,
+    debug_config: DebugConfig | None = None,
+    dump_folder: str = "./outputs",
+) -> TrainContext:
+    # Track whether sharding info has been recorded for this training run.
+    # Using a list instead of a bool allows the nested closure to mutate
+    # the value.
+    step_recorded = [False]
+
     @contextlib.contextmanager
     def context():
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
 
+            # Enter sharding info context only on first step
+            if (
+                debug_config is not None
+                and debug_config.log_sharding_info
+                and model_parts is not None
+                and not step_recorded[0]
+            ):
+                stack.enter_context(
+                    ShardingDebugContext(
+                        model_parts=model_parts,
+                        dump_folder=dump_folder,
+                        collapse_layers=debug_config.collapse_identical_layers,
+                    )
+                )
+                step_recorded[0] = True
+
             yield
 
     return context
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9378d742e3..7a62b8eee8 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -319,7 +319,12 @@ def __init__(self, job_config: JobConfig):
             parallel_dims.tp_enabled
             and not job_config.parallelism.disable_loss_parallel
         )
-        self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
+        self.train_context = dist_utils.get_train_context(
+            enable_loss_parallel=loss_parallel_enabled,
+            model_parts=self.model_parts,
+            debug_config=job_config.debug,
+            dump_folder=job_config.job.dump_folder,
+        )
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
             parallel_dims,
             job_config.training.mixed_precision_param,

From efefd75a81a3baaa756ba768707a9910537fdf82 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 5 Feb 2026 16:34:54 -0800
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 torchtitan/config/job_config.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index fc7537c850..bdabaf662f 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -988,10 +988,18 @@ class Debug:
     """If True, we force each experts to get the same amount of tokens via round-robin.
     This option is for debugging usage only."""
 
     log_sharding_info: bool = False
-    """If True, logs DTensor sharding/mesh info for module inputs, params, outputs during one fwd/bwd pass."""
+    """
+    If True, logs DTensor sharding/mesh info for module inputs, params, and
+    outputs during one fwd/bwd pass. Only the first step is recorded. This
+    flag is intended for debugging only and should not be enabled during
+    real training runs.
+    """
 
-    collapse_identical_layers: bool = True
-    """If True, collapse repeated layer modules with identical sharding patterns in the output."""
+    collapse_identical_layers: bool = False
+    """
+    If True and log_sharding_info is True, collapse repeated layer modules
+    with identical sharding patterns in the sharding log.
+    """
 
 
 @dataclass
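Not part of the patch — a minimal sketch of what the new debug path does, using ShardingDebugContext directly around a single forward/backward pass of a toy, non-distributed model. The Linear stack, shapes, and loss below are illustrative assumptions; in a torchtitan run the context is entered automatically by get_train_context() on the first training step when debug.log_sharding_info is enabled in the job config (e.g. a [debug] section of the TOML, assuming the Debug dataclass is exposed there like the other config sections).

import torch
import torch.nn as nn

from torchtitan.distributed.sharding_utils import ShardingDebugContext

# Toy model; real usage passes the trainer's model_parts.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

with ShardingDebugContext(
    model_parts=[model],
    dump_folder="./outputs",   # report is written to ./outputs/sharding_info/
    collapse_layers=True,      # merge layers with identical sharding into [0-N] ranges
):
    out = model(torch.randn(4, 16))
    out.sum().backward()

# On exit, rank0_sharding_info.txt and rank0_sharding_info.html appear under
# ./outputs/sharding_info/. Plain tensors are reported as "Local"; DTensors
# report their placements (Replicate, Shard(n), Partial) and the device mesh,
# e.g. (dp=2, tp=4).

The hooks only record sharding metadata for one forward/backward pass; the flag is meant purely for inspecting how the configured parallelisms place parameters, activations, and gradients.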