2 changes: 2 additions & 0 deletions dolomite_engine/arguments.py
@@ -283,6 +283,8 @@ class DistributedArgs(BaseArgs):
pipeline_parallel_schedule: str | None = None
# whether to use async-TP
use_async_tensor_parallel: bool = False
# world size for each CP group
context_parallel_world_size: int = 1

def model_post_init(self, __context: Any) -> None:
# communication dtype
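For orientation, a minimal sketch of how the new knob could be set; it assumes the remaining DistributedArgs fields keep their defaults (only the fields visible in this hunk are referenced):

    from dolomite_engine.arguments import DistributedArgs

    # context_parallel_world_size=1 (the default) leaves context parallelism off;
    # any value > 1 makes ProcessGroupManager.is_context_parallel_enabled() return True
    args = DistributedArgs(context_parallel_world_size=2)
    assert args.use_async_tensor_parallel is False  # unrelated defaults are untouched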
5 changes: 4 additions & 1 deletion dolomite_engine/distributed.py
@@ -191,8 +191,11 @@ def wrap_model_container_for_distributed_training(
communication_dtype = None if communication_dtype is None else string_to_torch_dtype(communication_dtype)

assert stage in [0, 2, 3]
if ProcessGroupManager.is_context_parallel_enabled():
dp_mesh = ProcessGroupManager.get_data_parallel_context_parallel_mesh()
else:
dp_mesh = ProcessGroupManager.get_data_parallel_mesh()

dp_mesh = ProcessGroupManager.get_data_parallel_mesh()
block_classes = [
get_module_class_from_name(model_container[0], name) for name in block_names + teacher_block_names
]
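Put differently, with CP enabled the model is wrapped over the ddp x flattened fsdp_cp mesh rather than ddp x fsdp, so CP ranks also take part in parameter sharding. A rough sketch of the shapes, assuming 8 GPUs with ddp=1, fsdp=4, cp=2 and tp=pp=1:

    # hedged sketch of the mesh handed to the FSDP wrapper under context parallelism
    dp_mesh = ProcessGroupManager.get_data_parallel_context_parallel_mesh()
    # dims are ("ddp", "fsdp_cp") with shape (1, 8): the two cp ranks shard parameters
    # and reduce gradients together with the four fsdp ranks
    print(dp_mesh.mesh_dim_names, tuple(dp_mesh.shape))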
3 changes: 3 additions & 0 deletions dolomite_engine/hf_models/mixins/dense/base.py
@@ -326,6 +326,7 @@ def _prepare_a_bunch_of_stuff(
tuple[torch.Tensor],
]:
if use_cache is None:
## TODO: disable cache for cp without padding free transformer
use_cache = False if self.use_padding_free_transformer else self.config.use_cache

if input_ids is not None and inputs_embeds is not None:
@@ -446,6 +447,8 @@ def _setup_positional_encoding(self) -> None:
base=self.config.rope_theta,
scale=self.config.rope_scaling["factor"],
original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"],
beta_fast=self.config.rope_scaling["beta_fast"],
beta_slow=self.config.rope_scaling["beta_slow"],
)
elif self.position_embedding_type == "nope":
pass
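The two new keys are read straight from config.rope_scaling, so a YaRN-style rope_scaling block now has to carry them. An illustrative fragment (the numeric values are common YaRN defaults, not taken from this PR):

    # rope_scaling dict consumed by _setup_positional_encoding above
    rope_scaling = {
        "factor": 4.0,                             # context-extension factor
        "original_max_position_embeddings": 4096,  # pretraining context length
        "beta_fast": 32,                           # new: high-frequency correction bound
        "beta_slow": 1,                            # new: low-frequency correction bound
    }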
6 changes: 3 additions & 3 deletions dolomite_engine/model_wrapper/pretraining.py
@@ -122,9 +122,9 @@ def forward(
else:
assert aux_loss_from_pipeline_parallel == 0

batch = self._prepare_model_inputs(batch)
labels = batch.pop("labels")
output: CausalLMOutputWithPast | PipelineParallelOutput = self.model(**batch, return_dict=True)
input_ids, labels = batch

output = self.model(input_ids=input_ids, return_dict=True)

if self.is_pipeline_parallel_enabled:
# aux_loss is returned as a 0 dimensional tensor
76 changes: 61 additions & 15 deletions dolomite_engine/pretrain.py
@@ -32,6 +32,8 @@
MetricsTrackingDict,
ProcessGroupManager,
StepTracker,
create_context_parallel_ctx,
get_cp_context,
init_distributed,
is_torchao_available,
log_rank_0,
@@ -185,31 +187,73 @@ def train_step_without_pipeline_parallel(

gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps()

world_mesh = ProcessGroupManager.get_mesh()

with no_sync():
for _ in range(gradient_accumulation_steps - 1):
batch = get_next_batch(train_dataloader)
with forward_context():
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
if ProcessGroupManager.is_context_parallel_enabled():
batch = model._prepare_model_inputs(get_next_batch(train_dataloader))
input_ids = batch["input_ids"]
labels = batch["labels"]
cp_context = get_cp_context(False, False)
optional_context_parallel_ctx = create_context_parallel_ctx(
cp_mesh=world_mesh["cp"],
cp_buffers=[input_ids, labels]
+ [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached],
cp_seq_dims=[1, 1, 0, 0],
cp_no_restore_buffers={input_ids, labels},
cp_rotate_method="allgather",
)

with cp_context(optional_context_parallel_ctx):
loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier)
loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
loss_micro_step_scaled.backward()

else:
batch = get_next_batch(train_dataloader)
with forward_context():
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)

# compute gradients
with backward_context():
loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
loss_micro_step_scaled.backward()
# compute gradients
with backward_context():
loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
loss_micro_step_scaled.backward()

with torch.inference_mode():
metrics_tracker = metrics_tracker + loss_micro_step_dict

if fsdp_algorithm == 2:
model.set_requires_gradient_sync(True)

batch = get_next_batch(train_dataloader)
with forward_context():
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
if ProcessGroupManager.is_context_parallel_enabled():
batch = model._prepare_model_inputs(get_next_batch(train_dataloader))
input_ids = batch["input_ids"]
labels = batch["labels"]
cp_context = get_cp_context(False, False)
optional_context_parallel_ctx = create_context_parallel_ctx(
cp_mesh=world_mesh["cp"],
cp_buffers=[input_ids, labels]
+ [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached],
cp_seq_dims=[1, 1, 0, 0],
cp_no_restore_buffers={input_ids, labels},
cp_rotate_method="allgather",
)

# compute gradients
with backward_context():
loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
loss_micro_step_scaled.backward()
with cp_context(optional_context_parallel_ctx):
loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier)
loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
loss_micro_step_scaled.backward()

else:
batch = get_next_batch(train_dataloader)
with forward_context():
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)

# compute gradients
with backward_context():
loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
loss_micro_step_scaled.backward()

with torch.inference_mode():
metrics_tracker = metrics_tracker + loss_micro_step_dict
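It is easy to miss what the cp context does to those buffers: inside the context, every tensor in cp_buffers is swapped for this rank's shard along its declared sequence dim. A shape-only sketch under assumed sizes (cp world size 2, 4096-token sequences):

    # hedged illustration of buffer shapes around the cp context
    cp = ProcessGroupManager.get_context_parallel_world_size()  # assume 2
    # outside the context: input_ids.shape == (mbs, 4096); the rope cos/sin caches keep
    #   their sequence on dim 0, which is why cp_seq_dims is [1, 1, 0, 0]
    # inside the context:  input_ids.shape == (mbs, 4096 // cp), so attention and the loss
    #   only ever see this rank's 2048 tokens
    # input_ids and labels are listed in cp_no_restore_buffers, so they stay sharded on exit;
    # the metrics gathered here are therefore local, which is why train_utils.py now averages
    # them over the flattened dp_cp group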
@@ -393,7 +437,8 @@ def train(
forward_context=forward_context,
backward_context=backward_context,
sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
lm_loss_multiplier=1 / (micro_batch_size * sequence_length),
lm_loss_multiplier=1
/ (micro_batch_size * sequence_length / args.distributed_args.context_parallel_world_size),
)

metrics_tracker = metrics_tracker + loss_step_dict
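The adjusted lm_loss_multiplier simply re-normalizes the summed token loss by the number of tokens a rank actually sees once CP shards the sequence. A small worked example with assumed values:

    # assumed values for illustration
    micro_batch_size, sequence_length, cp_world_size = 2, 4096, 2
    tokens_per_rank = micro_batch_size * sequence_length // cp_world_size        # 4096
    lm_loss_multiplier = 1 / (micro_batch_size * sequence_length / cp_world_size)
    assert lm_loss_multiplier == 1 / tokens_per_rank                              # 1 / 4096
    # without the extra division the loss would be scaled down by a factor of cp_world_size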
@@ -577,6 +622,7 @@ def main(mode: Mode = Mode.training) -> None:
data_parallel_replication_world_size=args.distributed_args.zero_topology.data_parallel_replication_world_size,
data_parallel_sharding_world_size=args.distributed_args.zero_topology.data_parallel_sharding_world_size,
zero_stage=args.distributed_args.stage,
context_parallel_world_size=args.distributed_args.context_parallel_world_size,
timeout_minutes=args.distributed_args.timeout_minutes,
use_async_tensor_parallel=args.distributed_args.use_async_tensor_parallel,
)
6 changes: 5 additions & 1 deletion dolomite_engine/train_utils.py
@@ -21,7 +21,11 @@ def all_reduce_metrics_tracker(metrics_tracker: MetricsTrackingDict) -> MetricsTrackingDict:
# tensor = torch.stack(tensor) / ProcessGroupManager.get_data_parallel_world_size()
# tensor = tensor.cpu()
# gloo op doesn't support averaging so we do sum and divide by world size above
torch.distributed.all_reduce(tensor, op=ReduceOp.AVG, group=ProcessGroupManager.get_data_parallel_group())
torch.distributed.all_reduce(
tensor,
op=ReduceOp.AVG,
group=ProcessGroupManager.get_mesh()["ddp", "fsdp", "cp"]._flatten(mesh_dim_name="dp_cp").get_group(),
)
tensor = tensor.tolist()

for i, key in enumerate(metrics_tracker):
11 changes: 10 additions & 1 deletion dolomite_engine/utils/__init__.py
@@ -23,7 +23,13 @@
is_zstandard_available,
log_environment,
)
from .parallel import ProcessGroupManager, get_pipeline_stage_ids_on_current_rank, run_rank_n
from .parallel import (
ProcessGroupManager,
create_context_parallel_ctx,
get_cp_context,
get_pipeline_stage_ids_on_current_rank,
run_rank_n,
)
from .pydantic import BaseArgs
from .safetensors import SafeTensorsWeightsManager
from .step_tracker import StepTracker
@@ -37,6 +43,7 @@ def init_distributed(
pipeline_parallel_world_size: int,
data_parallel_replication_world_size: int,
data_parallel_sharding_world_size: int,
context_parallel_world_size: int,
zero_stage: int,
timeout_minutes: int = None,
use_async_tensor_parallel: bool = False,
@@ -58,6 +65,7 @@
pipeline_parallel_world_size=pipeline_parallel_world_size,
data_parallel_replication_world_size=data_parallel_replication_world_size,
data_parallel_sharding_world_size=data_parallel_sharding_world_size,
context_parallel_world_size=context_parallel_world_size,
zero_stage=zero_stage,
timeout_minutes=timeout_minutes,
use_async_tensor_parallel=use_async_tensor_parallel,
@@ -67,6 +75,7 @@
log_rank_0(logging.INFO, f"total GPUs = {process_group_manager.get_world_size()}")
log_rank_0(logging.INFO, f"tensor parallel size = {process_group_manager.get_tensor_parallel_world_size()}")
log_rank_0(logging.INFO, f"data parallel size = {process_group_manager.get_data_parallel_world_size()}")
log_rank_0(logging.INFO, f"context parallel size = {context_parallel_world_size}")


def setup_tf32(use_tf32: bool = True) -> None:
86 changes: 82 additions & 4 deletions dolomite_engine/utils/parallel.py
@@ -3,9 +3,9 @@
# **************************************************

import os
from contextlib import contextmanager
from contextlib import ExitStack, contextmanager
from datetime import timedelta
from typing import Callable
from typing import Callable, Generator, List, Optional, Set

import torch
import torch.distributed
@@ -43,6 +49,9 @@
_DATA_PARALLEL_REPLICATION_WORLD_SIZE: int | None = None
_DATA_PARALLEL_SHARDING_WORLD_SIZE: int | None = None

# context parallel
_DATA_PARALLEL_CONTEXT_PARALLEL_MESH: DeviceMesh | None = None


class ProcessGroupManager:
def __init__(
@@ -51,6 +54,7 @@ def __init__(
pipeline_parallel_world_size: int = 1,
data_parallel_replication_world_size: int | None = None,
data_parallel_sharding_world_size: int | None = None,
context_parallel_world_size: int | None = None,
zero_stage: int = 3,
timeout_minutes: int | None = None,
use_async_tensor_parallel: bool = False,
@@ -84,7 +88,10 @@ def __init__(
else:
assert data_parallel_sharding_world_size is not None

assert data_parallel_replication_world_size * data_parallel_sharding_world_size == data_parallel_size
assert (
data_parallel_replication_world_size * data_parallel_sharding_world_size * context_parallel_world_size
== data_parallel_size
)

global _MESH, _TENSOR_PARALLEL_FIRST_RANK, _DATA_PARALLEL_REPLICATION_WORLD_SIZE, _DATA_PARALLEL_SHARDING_WORLD_SIZE

@@ -97,11 +104,14 @@ def __init__(
pipeline_parallel_world_size,
data_parallel_replication_world_size,
data_parallel_sharding_world_size,
context_parallel_world_size,
tensor_parallel_world_size,
),
mesh_dim_names=("pp", "ddp", "fsdp", "tp"),
mesh_dim_names=("pp", "ddp", "fsdp", "cp", "tp"),
)

_MESH["fsdp", "cp"]._flatten(mesh_dim_name="fsdp_cp")

local_rank = int(os.getenv("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

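A worked example of the extended mesh, assuming 16 GPUs with tp=2, pp=1, cp=2, ddp=1 and data_parallel_size = 16 / (tp * pp) = 8: the assert above then forces ddp * fsdp * cp = 8, i.e. fsdp = 4. Sketched outside the ProcessGroupManager for clarity:

    # hedged sketch of what init_device_mesh builds on 16 ranks (e.g. torchrun --nproc-per-node 16)
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh(
        "cuda",
        (1, 1, 4, 2, 2),                                  # pp, ddp, fsdp, cp, tp
        mesh_dim_names=("pp", "ddp", "fsdp", "cp", "tp"),
    )
    mesh["fsdp", "cp"]._flatten(mesh_dim_name="fsdp_cp")  # 8-way group used for FSDP sharding
    # the metric all-reduce in train_utils.py flattens ("ddp", "fsdp", "cp") into "dp_cp" the same way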
@@ -295,6 +305,23 @@ def get_data_parallel_mesh() -> DeviceMesh:
_DATA_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp"]
return _DATA_PARALLEL_MESH

# data parallel + context parallel
@staticmethod
def get_data_parallel_context_parallel_mesh() -> DeviceMesh:
global _DATA_PARALLEL_CONTEXT_PARALLEL_MESH

if _DATA_PARALLEL_CONTEXT_PARALLEL_MESH is None:
_DATA_PARALLEL_CONTEXT_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp_cp"]
return _DATA_PARALLEL_CONTEXT_PARALLEL_MESH

@staticmethod
def get_context_parallel_world_size() -> int:
return ProcessGroupManager.get_mesh()["cp"].size()

@staticmethod
def is_context_parallel_enabled() -> bool:
return ProcessGroupManager.get_context_parallel_world_size() > 1

@staticmethod
def get_data_parallel_group() -> ProcessGroup:
global _DATA_PARALLEL_GROUP
@@ -399,8 +426,11 @@ def func_rank_other(*args, **kwargs):


def is_tracking_rank() -> bool:

## TODO verify cp local rank for logging
return (
ProcessGroupManager.get_data_parallel_rank() == 0
and ProcessGroupManager.get_mesh()["cp"].get_local_rank() == 0
and ProcessGroupManager.is_tensor_parallel_first_rank()
and ProcessGroupManager.get_pipeline_parallel_rank()
== ProcessGroupManager.get_pipeline_parallel_world_size() - 1
@@ -418,3 +448,51 @@ def get_pipeline_stage_ids_on_current_rank(num_pipeline_stages: int) -> int:
)

return tuple(pp_rank + i * pp_world_size for i in range(num_pipeline_stages_per_rank))


def create_context_parallel_ctx(
cp_mesh: DeviceMesh,
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
cp_rotate_method: str,
):
try:
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
except ImportError:
print(
f"PyTorch version {torch.__version__} does not include the experimental "
"Context Parallel API. Please update to a newer version."
)

set_rotate_method(cp_rotate_method)
return context_parallel(
cp_mesh,
buffers=cp_buffers,
buffer_seq_dims=cp_seq_dims,
no_restore_buffers=cp_no_restore_buffers,
)


def get_cp_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextmanager
def context(cp_context: Optional[Generator[None, None, None]] = None):
with ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if enable_compiled_autograd:
stack.enter_context(torch._dynamo.utils.maybe_enable_compiled_autograd(True))

if cp_context is not None:
from torch.nn.attention import SDPBackend, sdpa_kernel

# currently we only support these two SDP backends.
# TODO (xilunwu): support cuDNN backend
stack.enter_context(sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]))
stack.enter_context(cp_context)

yield

return context
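
Putting the two helpers together, a condensed version of how pretrain.py drives them; the tensors, the rope module, and the model call signature are taken from the training-step change above and are otherwise assumed bindings:

    # hedged end-to-end sketch of one CP-enabled micro-step
    world_mesh = ProcessGroupManager.get_mesh()
    cp_ctx = create_context_parallel_ctx(
        cp_mesh=world_mesh["cp"],
        cp_buffers=[input_ids, labels, rope.cos_cached, rope.sin_cached],
        cp_seq_dims=[1, 1, 0, 0],                  # batch-major inputs, seq-major rope caches
        cp_no_restore_buffers={input_ids, labels},
        cp_rotate_method="allgather",              # how KV shards rotate between cp ranks
    )
    train_context = get_cp_context(enable_loss_parallel=False, enable_compiled_autograd=False)
    with train_context(cp_ctx):
        # each rank attends only over its sequence shard
        loss = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier)["loss"]
        loss.backward()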