299 changes: 172 additions & 127 deletions torchtitan/distributed/pipeline_parallel.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy

import math
import os
from typing import Callable
@@ -13,7 +12,6 @@
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining import PipelineStage

from torch.distributed.pipelining.schedules import (
_PipelineSchedule,
_PipelineScheduleRuntime,
@@ -24,7 +22,6 @@
ScheduleDualPipeV,
ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims
@@ -34,6 +31,7 @@

__all__ = [
"pipeline_llm",
"get_pipeline_metadata",
"build_pipeline_schedule",
"generate_llm_fqn_per_model_part",
"pipeline_module_split",
@@ -51,6 +49,68 @@ def pipeline_llm(
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
pp_mesh = parallel_dims.get_mesh("pp")

num_virtual_stages, num_layers, input_weight, output_weight = get_pipeline_metadata(
parallel_dims, job_config, model_args
)

module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
if module_names_per_stage is None:
module_names_per_stage = generate_llm_fqn_per_model_part(
num_virtual_stages, num_layers, input_weight, output_weight
)
for i, stage_ms in enumerate(module_names_per_stage):
logger.debug(f"Stage {i}: {stage_ms}")

stages, model_parts = pipeline_module_split(
model,
pp_mesh,
job_config.parallelism.pipeline_parallel_schedule,
device,
module_names_per_stage,
)

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for i, m in enumerate(model_parts):
# apply SPMD-style PT-D techniques
m = parallelize_fn(m, parallel_dims, job_config)
model_parts[i] = m
# NOTE: this is to update the model in the stage
# in case the model is modified e.g. by torch.compile
stages[i].submod = m

pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True

return pp_schedule, model_parts, has_first_stage, has_last_stage
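
For context, a minimal sketch (illustrative only, not part of this diff and not the actual torchtitan trainer) of how the returned has_first_stage / has_last_stage flags are typically consumed in a training step; inputs and labels are assumed to exist in the caller:

import torch

# Only stages that own the first model chunk feed inputs; only stages that own the
# last chunk receive targets and collect per-microbatch losses, per the
# torch.distributed.pipelining schedule.step(*args, target=..., losses=...) API.
targets = labels if has_last_stage else None
losses = [] if has_last_stage else None
if has_first_stage:
    pp_schedule.step(inputs, target=targets, losses=losses)
else:
    pp_schedule.step(target=targets, losses=losses)
loss = torch.mean(torch.stack(losses)) if has_last_stage else torch.tensor(-1.0)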


def get_pipeline_metadata(
parallel_dims: ParallelDims,
job_config: JobConfig,
model_args: BaseModelArgs,
) -> tuple[int, int, int, int]:
"""
Determine the number of virtual pipeline stages, the number of model layers, and the
input/output module weights used to generate the per-stage module split.

Args:
parallel_dims (ParallelDims): Parallel dimensions.
job_config (JobConfig): Job configuration.
model_args (BaseModelArgs): Model arguments.

Returns:
tuple: A tuple containing the number of virtual stages, the number of layers in the model,
the input weight, and the output weight.
"""
# Determine the number of virtual stages based on schedule type
schedule_class = get_schedule_class(
job_config.parallelism.pipeline_parallel_schedule
@@ -113,46 +173,7 @@ def pipeline_llm(
# For single-stage schedules, default is 1 virtual stage per rank
stages_per_rank = 1 if is_single_stage_schedule else 2
num_virtual_stages = parallel_dims.pp * stages_per_rank

module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
if module_names_per_stage is None:
module_names_per_stage = generate_llm_fqn_per_model_part(
num_virtual_stages, num_layers, input_weight, output_weight
)
for i, stage_ms in enumerate(module_names_per_stage):
logger.debug(f"Stage {i}: {stage_ms}")

stages, model_parts = pipeline_module_split(
model,
pp_mesh,
job_config.parallelism.pipeline_parallel_schedule,
device,
module_names_per_stage,
)

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for i, m in enumerate(model_parts):
# apply SPMD-style PT-D techniques
m = parallelize_fn(m, parallel_dims, job_config)
model_parts[i] = m
# NOTE: this is to update the model in the stage
# in case the model is modified e.g. by torch.compile
stages[i].submod = m

pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True

return pp_schedule, model_parts, has_first_stage, has_last_stage
return num_virtual_stages, num_layers, input_weight, output_weight
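
As a worked example of the default path above (illustrative only, not part of this diff), assuming no explicit per-stage layer count is configured:

# 1F1B            (single-stage schedule), pp = 4 -> stages_per_rank = 1, num_virtual_stages = 4
# Interleaved1F1B (multi-stage schedule),  pp = 4 -> stages_per_rank = 2, num_virtual_stages = 8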


def build_pipeline_schedule(
@@ -344,6 +365,106 @@ def generate_llm_fqn_per_model_part(
return module_names_per_stage
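
The body of generate_llm_fqn_per_model_part is elided from this diff; purely as a hypothetical illustration of the shape of its output (which module_fqns_per_model_part and pipeline_module_split consume), splitting a 4-layer model into 2 stages with input_weight=1 and output_weight=1 might produce something like:

# [["tok_embeddings", "layers.0", "layers.1"],
#  ["layers.2", "layers.3", "norm", "output"]]
# (the exact balancing depends on the weighting logic not shown here)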


def split_module(
whole_model: nn.Module,
module_names: list[str],
) -> nn.Module:
"""
Returns a deep copy of whole_model that keeps only the specified modules; all other
child modules are removed or set to None.

Args:
whole_model: The complete model to be split
module_names: List of module names to include in the split

Returns:
A deep copy of whole_model containing only the requested modules

Example usage:
module_names = ["tok_embeddings", "layers.0", "layers.1", "norm", "output"]
split_module(whole_model, module_names)
"""
model = copy.deepcopy(whole_model)
# Create a set of modules to keep for faster lookup
modules_to_keep = set(module_names)
for module_name, module_value in model.named_children():
# Handle layer-like structures (e.g., "layers.0", "layers.1")
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
layers_to_keep = {
name.split(".", 1)[1]
for name in modules_to_keep
if name.startswith(f"{module_name}.")
}
if layers_to_keep:
# Keep only specified layers
if isinstance(module_value, nn.ModuleDict):
for layer_name in list(module_value.keys()):
if layer_name not in layers_to_keep:
del module_value[layer_name]
elif isinstance(module_value, nn.ModuleList):
indices_to_keep = {
int(idx) for idx in layers_to_keep if idx.isdigit()
}
new_layers = nn.ModuleList(
[
layer
for i, layer in enumerate(module_value)
if i in indices_to_keep
]
)
setattr(model, module_name, new_layers)
else:
# No layers from this structure needed, set to empty structure
if isinstance(module_value, nn.ModuleDict):
setattr(model, module_name, nn.ModuleDict())
elif isinstance(module_value, nn.ModuleList):
setattr(model, module_name, nn.ModuleList())
# Handle simple module attributes (e.g., "linear", "norm")
elif module_name not in modules_to_keep:
# Replace with None
setattr(model, module_name, None)
return model
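
A minimal usage sketch for split_module (illustrative only; ToyModel is a hypothetical module whose attribute names mirror the docstring example):

import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(100, 16)
        self.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
        self.norm = nn.LayerNorm(16)
        self.output = nn.Linear(16, 100)

whole = ToyModel()
part = split_module(whole, ["layers.2", "layers.3", "norm", "output"])
# part.tok_embeddings is None, part.layers is a re-indexed ModuleList holding only
# the last two Linear layers, and part.norm / part.output are kept.
# whole is untouched because split_module operates on a deep copy.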


def get_pp_rank_to_stage_indices_mapping(
pp_rank: int,
pp_degree: int,
pp_schedule: str,
num_stages: int,
) -> tuple[int, ...]:
"""
Returns the stage indices assigned to the given PP rank under the given pipeline schedule.

Args:
pp_rank: Pipeline parallel rank
pp_degree: Number of pipeline parallel ranks
pp_schedule: Name of pipeline parallelism schedule
num_stages: Number of pipeline stages

Returns:
Tuple of stage indices assigned to this PP rank
"""
schedule_class = get_schedule_class(pp_schedule)
style = (
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
)
assert (
num_stages % pp_degree == 0
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
stages_per_rank = num_stages // pp_degree
if style == "loop":
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
elif style == "v":
assert (
stages_per_rank == 2
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
stage_v_pairs = list(
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
)
return tuple(stage_v_pairs[pp_rank])
else:
raise ValueError(f"Unknown style {style}")


def pipeline_module_split(
whole_model: nn.Module,
pp_mesh: DeviceMesh,
@@ -385,97 +506,21 @@
"""
pp_rank = pp_mesh.get_local_rank()
pp_degree = pp_mesh.size()

def _build_stage_from_modules(
stage_idx: int, module_names: list[str], num_stages: int
) -> tuple[PipelineStage, nn.Module]:
model = copy.deepcopy(whole_model)

# Create a set of modules to keep for faster lookup
modules_to_keep = set(module_names)
for module_name, module_value in model.named_children():
# Handle layer-like structures (e.g., "layers.0", "layers.1")
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
layers_to_keep = {
name.split(".", 1)[1]
for name in modules_to_keep
if name.startswith(f"{module_name}.")
}
if layers_to_keep:
# Keep only specified layers
if isinstance(module_value, nn.ModuleDict):
for layer_name in list(module_value.keys()):
if layer_name not in layers_to_keep:
del module_value[layer_name]
elif isinstance(module_value, nn.ModuleList):
indices_to_keep = {
int(idx) for idx in layers_to_keep if idx.isdigit()
}
new_layers = nn.ModuleList(
[
layer
for i, layer in enumerate(module_value)
if i in indices_to_keep
]
)
setattr(model, module_name, new_layers)
else:
# No layers from this structure needed, set to empty structure
if isinstance(module_value, nn.ModuleDict):
setattr(model, module_name, nn.ModuleDict())
elif isinstance(module_value, nn.ModuleList):
setattr(model, module_name, nn.ModuleList())
# Handle simple module attributes (e.g., "linear", "norm")
elif module_name not in modules_to_keep:
# Replace with None
setattr(model, module_name, None)

stage = PipelineStage(
model,
stage_idx,
num_stages,
device,
group=pp_mesh.get_group("pp"),
)
return stage, model

num_stages = len(module_names_per_stage)
stages = []
models = []

schedule_class = get_schedule_class(pp_schedule)
style = (
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
pp_rank_to_stage_indices = get_pp_rank_to_stage_indices_mapping(
pp_rank, pp_degree, pp_schedule, num_stages
)

def _get_stage_indices() -> tuple[int, ...]:
"""
Compute the stage ids for the stages that will run on this pp rank
for either a looped or V style schedule
"""
assert (
num_stages % pp_degree == 0
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
stages_per_rank = num_stages // pp_degree
if style == "loop":
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
elif style == "v":
assert (
stages_per_rank == 2
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
stage_v_pairs = list(
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
)
return stage_v_pairs[pp_rank]
else:
raise ValueError(f"Unknown style {style}")

for stage_idx in _get_stage_indices():
for stage_idx in pp_rank_to_stage_indices:
module_names = module_names_per_stage[stage_idx]
stage, model_chunk = _build_stage_from_modules(
model_chunk = split_module(whole_model, module_names)
stage = PipelineStage(
model_chunk,
stage_idx,
module_names,
num_stages,
device,
group=pp_mesh.get_group("pp"),
)
logger.info(
f"PP rank {pp_rank} is building stage_idx {stage_idx} "
15 changes: 11 additions & 4 deletions torchtitan/distributed/utils.py
@@ -494,18 +494,25 @@ def _clip_grad_norm_with_ep(
else:
non_ep_params.append(p)
non_ep_grads.append(p.grad)

# Either list can be empty depending on the parallelization strategy:
# - In torchtitan with separate dense/sparse meshes, both lists are typically non-empty
# - In autoparallel, all params may live on a single sparse mesh with "ep" dimension,
# so non_ep_grads would be empty
# - In PP + EP setups, certain PP ranks may only own EP or non-EP layers
ep_grads_total_norm = torch.nn.utils.get_total_norm(
ep_grads, norm_type, error_if_nonfinite, foreach
)
# ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor
# This can occur in PP + EP setups where certain PP ranks only own non-EP layers, for instance.
# get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
if isinstance(ep_grads_total_norm, DTensor):
ep_grads_total_norm = ep_grads_total_norm.full_tensor()

# pyrefly: ignore [missing-attribute]
non_ep_grads_total_norm = torch.nn.utils.get_total_norm(
non_ep_grads, norm_type, error_if_nonfinite, foreach
).full_tensor()
)
# get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
if isinstance(non_ep_grads_total_norm, DTensor):
non_ep_grads_total_norm = non_ep_grads_total_norm.full_tensor()

if math.isinf(norm_type):
total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
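
A minimal sketch (illustrative only) of the empty-list behavior the new isinstance guards handle, relying on the fact noted in the comments above that get_total_norm returns a plain tensor(0.) for an empty input:

import torch
from torch.distributed.tensor import DTensor

# e.g. a PP rank that owns only EP layers has no non-EP grads at all
non_ep_grads = []
total = torch.nn.utils.get_total_norm(non_ep_grads, 2.0, False, None)
# total is tensor(0.), a plain Tensor, so unconditionally calling .full_tensor()
# (as the old code did) would fail; the guard only gathers genuine DTensors.
if isinstance(total, DTensor):
    total = total.full_tensor()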