diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index d9b6d29a09..a687b17c1a 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
-
 import math
 import os
 from typing import Callable
@@ -13,7 +12,6 @@
 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
-
 from torch.distributed.pipelining.schedules import (
     _PipelineSchedule,
     _PipelineScheduleRuntime,
@@ -24,7 +22,6 @@
     ScheduleDualPipeV,
     ScheduleZBVZeroBubble,
 )
-
 from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
@@ -34,6 +31,7 @@
 __all__ = [
     "pipeline_llm",
+    "get_pipeline_metadata",
     "build_pipeline_schedule",
     "generate_llm_fqn_per_model_part",
     "pipeline_module_split",
@@ -51,6 +49,68 @@ def pipeline_llm(
 ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
     pp_mesh = parallel_dims.get_mesh("pp")
+    num_virtual_stages, num_layers, input_weight, output_weight = get_pipeline_metadata(
+        parallel_dims, job_config, model_args
+    )
+
+    module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
+    if module_names_per_stage is None:
+        module_names_per_stage = generate_llm_fqn_per_model_part(
+            num_virtual_stages, num_layers, input_weight, output_weight
+        )
+    for i, stage_ms in enumerate(module_names_per_stage):
+        logger.debug(f"Stage {i}: {stage_ms}")
+
+    stages, model_parts = pipeline_module_split(
+        model,
+        pp_mesh,
+        job_config.parallelism.pipeline_parallel_schedule,
+        device,
+        module_names_per_stage,
+    )
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(m, parallel_dims, job_config)
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
+
+
+def get_pipeline_metadata(
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    model_args: BaseModelArgs,
+) -> tuple[int, int, int, int]:
+    """
+    Determine the number of virtual stages, the number of model layers, and the
+    input/output weights used when assigning modules to stages.
+
+    Args:
+        parallel_dims (ParallelDims): Parallel dimensions.
+        job_config (JobConfig): Job configuration.
+        model_args (BaseModelArgs): Model arguments.
+
+    Returns:
+        tuple: A tuple containing the number of virtual stages, the number of layers in the model,
+            the input weight, and the output weight.
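+
+    Example (illustrative values only): with pp=4 and a multi-stage schedule
+    such as Interleaved1F1B (2 virtual stages per rank), a 16-layer model
+    yields (8, 16, input_weight, output_weight).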
+ """ # Determine the number of virtual stages based on schedule type schedule_class = get_schedule_class( job_config.parallelism.pipeline_parallel_schedule @@ -113,46 +173,7 @@ def pipeline_llm( # For single-stage schedules, default is 1 virtual stage per rank stages_per_rank = 1 if is_single_stage_schedule else 2 num_virtual_stages = parallel_dims.pp * stages_per_rank - - module_names_per_stage = job_config.parallelism.module_fqns_per_model_part - if module_names_per_stage is None: - module_names_per_stage = generate_llm_fqn_per_model_part( - num_virtual_stages, num_layers, input_weight, output_weight - ) - for i, stage_ms in enumerate(module_names_per_stage): - logger.debug(f"Stage {i}: {stage_ms}") - - stages, model_parts = pipeline_module_split( - model, - pp_mesh, - job_config.parallelism.pipeline_parallel_schedule, - device, - module_names_per_stage, - ) - - # For PP with looped schedules, each item in model_parts is one stage-model-chunk. - # We need to iterate through model_parts to apply SPMD parallelisms, compilation, - # optimizer, and checkpointing - for i, m in enumerate(model_parts): - # apply SPMD-style PT-D techniques - m = parallelize_fn(m, parallel_dims, job_config) - model_parts[i] = m - # NOTE: this is to update the model in the stage - # in case the model is modified e.g. by torch.compile - stages[i].submod = m - - pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) - - # This is used in the train loop to determine whether to pass in the input_ids and labels - has_first_stage = False - has_last_stage = False - for stage in stages: - if stage.is_first: - has_first_stage = True - if stage.is_last: - has_last_stage = True - - return pp_schedule, model_parts, has_first_stage, has_last_stage + return num_virtual_stages, num_layers, input_weight, output_weight def build_pipeline_schedule( @@ -344,6 +365,106 @@ def generate_llm_fqn_per_model_part( return module_names_per_stage +def split_module( + whole_model: nn.Module, + module_names: list[str], +) -> nn.Module: + """ + Splits a whole model into a module based on the specified module names. 
+
+    Args:
+        whole_model: The complete model to be split
+        module_names: List of module names to include in the split
+
+    Returns:
+        The model chunk containing only the requested modules
+
+    Example usage:
+        module_names = ["tok_embeddings", "layers.0", "layers.1", "norm", "output"]
+        split_module(whole_model, module_names)
+    """
+    model = copy.deepcopy(whole_model)
+    # Create a set of modules to keep for faster lookup
+    modules_to_keep = set(module_names)
+    for module_name, module_value in model.named_children():
+        # Handle layer-like structures (e.g., "layers.0", "layers.1")
+        if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
+            layers_to_keep = {
+                name.split(".", 1)[1]
+                for name in modules_to_keep
+                if name.startswith(f"{module_name}.")
+            }
+            if layers_to_keep:
+                # Keep only specified layers
+                if isinstance(module_value, nn.ModuleDict):
+                    for layer_name in list(module_value.keys()):
+                        if layer_name not in layers_to_keep:
+                            del module_value[layer_name]
+                elif isinstance(module_value, nn.ModuleList):
+                    indices_to_keep = {
+                        int(idx) for idx in layers_to_keep if idx.isdigit()
+                    }
+                    new_layers = nn.ModuleList(
+                        [
+                            layer
+                            for i, layer in enumerate(module_value)
+                            if i in indices_to_keep
+                        ]
+                    )
+                    setattr(model, module_name, new_layers)
+            else:
+                # No layers from this structure needed, set to empty structure
+                if isinstance(module_value, nn.ModuleDict):
+                    setattr(model, module_name, nn.ModuleDict())
+                elif isinstance(module_value, nn.ModuleList):
+                    setattr(model, module_name, nn.ModuleList())
+        # Handle simple module attributes (e.g., "linear", "norm")
+        elif module_name not in modules_to_keep:
+            # Replace with None
+            setattr(model, module_name, None)
+    return model
+
+
+def get_pp_rank_to_stage_indices_mapping(
+    pp_rank: int,
+    pp_degree: int,
+    pp_schedule: str,
+    num_stages: int,
+) -> tuple[int, ...]:
+    """
+    Returns the stage indices that the given PP rank runs under the given pipeline schedule.
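+
+    Worked example (derived from the formulas below): with pp_degree=4 and
+    num_stages=8, a "loop" schedule assigns rank 1 the stages (1, 5), while a
+    "v" schedule assigns it (1, 6).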
+
+    Args:
+        pp_rank: Pipeline parallel rank
+        pp_degree: Number of pipeline parallel ranks
+        pp_schedule: Name of the pipeline parallel schedule
+        num_stages: Total number of pipeline stages
+
+    Returns:
+        Tuple of stage indices assigned to this PP rank
+    """
+    schedule_class = get_schedule_class(pp_schedule)
+    style = (
+        "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
+    )
+    assert (
+        num_stages % pp_degree == 0
+    ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
+    stages_per_rank = num_stages // pp_degree
+    if style == "loop":
+        return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
+    elif style == "v":
+        assert (
+            stages_per_rank == 2
+        ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
+        stage_v_pairs = list(
+            zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
+        )
+        return tuple(stage_v_pairs[pp_rank])
+    else:
+        raise ValueError(f"Unknown style {style}")
+
+
 def pipeline_module_split(
     whole_model: nn.Module,
     pp_mesh: DeviceMesh,
@@ -385,97 +506,21 @@
     """
     pp_rank = pp_mesh.get_local_rank()
     pp_degree = pp_mesh.size()
-
-    def _build_stage_from_modules(
-        stage_idx: int, module_names: list[str], num_stages: int
-    ) -> tuple[PipelineStage, nn.Module]:
-        model = copy.deepcopy(whole_model)
-
-        # Create a set of modules to keep for faster lookup
-        modules_to_keep = set(module_names)
-        for module_name, module_value in model.named_children():
-            # Handle layer-like structures (e.g., "layers.0", "layers.1")
-            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
-                layers_to_keep = {
-                    name.split(".", 1)[1]
-                    for name in modules_to_keep
-                    if name.startswith(f"{module_name}.")
-                }
-                if layers_to_keep:
-                    # Keep only specified layers
-                    if isinstance(module_value, nn.ModuleDict):
-                        for layer_name in list(module_value.keys()):
-                            if layer_name not in layers_to_keep:
-                                del module_value[layer_name]
-                    elif isinstance(module_value, nn.ModuleList):
-                        indices_to_keep = {
-                            int(idx) for idx in layers_to_keep if idx.isdigit()
-                        }
-                        new_layers = nn.ModuleList(
-                            [
-                                layer
-                                for i, layer in enumerate(module_value)
-                                if i in indices_to_keep
-                            ]
-                        )
-                        setattr(model, module_name, new_layers)
-                else:
-                    # No layers from this structure needed, set to empty structure
-                    if isinstance(module_value, nn.ModuleDict):
-                        setattr(model, module_name, nn.ModuleDict())
-                    elif isinstance(module_value, nn.ModuleList):
-                        setattr(model, module_name, nn.ModuleList())
-            # Handle simple module attributes (e.g., "linear", "norm")
-            elif module_name not in modules_to_keep:
-                # Replace with None
-                setattr(model, module_name, None)
-
-        stage = PipelineStage(
-            model,
-            stage_idx,
-            num_stages,
-            device,
-            group=pp_mesh.get_group("pp"),
-        )
-        return stage, model
 
     num_stages = len(module_names_per_stage)
     stages = []
     models = []
-
-    schedule_class = get_schedule_class(pp_schedule)
-    style = (
-        "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
+    pp_rank_to_stage_indices = get_pp_rank_to_stage_indices_mapping(
+        pp_rank, pp_degree, pp_schedule, num_stages
     )
-
-    def _get_stage_indices() -> tuple[int, ...]:
-        """
-        Compute the stage ids for the stages that will run on this pp rank
-        for either a looped or V style schedule
-        """
-        assert (
-            num_stages % pp_degree == 0
-        ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
-        stages_per_rank = num_stages // pp_degree
-        if style == "loop":
-            return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
style == "v": - assert ( - stages_per_rank == 2 - ), f"v schedules assume 2 stages per rank, got {stages_per_rank}" - stage_v_pairs = list( - zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) - ) - return stage_v_pairs[pp_rank] - else: - raise ValueError(f"Unknown style {style}") - - for stage_idx in _get_stage_indices(): + for stage_idx in pp_rank_to_stage_indices: module_names = module_names_per_stage[stage_idx] - stage, model_chunk = _build_stage_from_modules( + model_chunk = split_module(whole_model, module_names) + stage = PipelineStage( + model_chunk, stage_idx, - module_names, num_stages, + device, + group=pp_mesh.get_group("pp"), ) logger.info( f"PP rank {pp_rank} is building stage_idx {stage_idx} " diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 2ba9c08422..5f9bf3108c 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -494,18 +494,25 @@ def _clip_grad_norm_with_ep( else: non_ep_params.append(p) non_ep_grads.append(p.grad) + + # Either list can be empty depending on the parallelization strategy: + # - In torchtitan with separate dense/sparse meshes, both lists are typically non-empty + # - In autoparallel, all params may live on a single sparse mesh with "ep" dimension, + # so non_ep_grads would be empty + # - In PP + EP setups, certain PP ranks may only own EP or non-EP layers ep_grads_total_norm = torch.nn.utils.get_total_norm( ep_grads, norm_type, error_if_nonfinite, foreach ) - # ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor - # This can occur in PP + EP setups where certain PP ranks only own non-EP layers, for instance. + # get_total_norm returns tensor(0.) for empty list, which is a non-DTensor if isinstance(ep_grads_total_norm, DTensor): ep_grads_total_norm = ep_grads_total_norm.full_tensor() - # pyrefly: ignore [missing-attribute] non_ep_grads_total_norm = torch.nn.utils.get_total_norm( non_ep_grads, norm_type, error_if_nonfinite, foreach - ).full_tensor() + ) + # get_total_norm returns tensor(0.) for empty list, which is a non-DTensor + if isinstance(non_ep_grads_total_norm, DTensor): + non_ep_grads_total_norm = non_ep_grads_total_norm.full_tensor() if math.isinf(norm_type): total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) diff --git a/torchtitan/experiments/autoparallel/graph_pp_builder.py b/torchtitan/experiments/autoparallel/graph_pp_builder.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py index f4915fb708..a620066330 100644 --- a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py @@ -15,4 +15,9 @@ # Need to share same base class with torchtitan models class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol): def __init__(self, model_args: DeepSeekV3ModelArgs): - super().__init__(model_args) + # Call _DeepSeekV3Model.__init__ which calls nn.Module.__init__ + # Note: We don't call ModelProtocol.__init__ separately because: + # 1. nn.Module.__init__() is already called by _DeepSeekV3Model.__init__ + # 2. Calling ModelProtocol.__init__ after would reset all module state + # (nn.Module.__init__ clears _modules, _parameters, etc.) + _DeepSeekV3Model.__init__(self, model_args)