From 4fa60ce690407ed9e074df4a82029b5ef807d871 Mon Sep 17 00:00:00 2001 From: Gaoyuan-Zhang Date: Thu, 1 May 2025 14:52:22 +0000 Subject: [PATCH 1/3] cp merge --- dolomite_engine/arguments.py | 2 + dolomite_engine/distributed.py | 7 +- .../hf_models/mixins/dense/base.py | 3 + dolomite_engine/model_wrapper/pretraining.py | 6 +- dolomite_engine/pretrain.py | 51 ++++++++---- dolomite_engine/train_utils.py | 2 +- dolomite_engine/utils/__init__.py | 5 +- dolomite_engine/utils/parallel.py | 81 ++++++++++++++++++- 8 files changed, 132 insertions(+), 25 deletions(-) diff --git a/dolomite_engine/arguments.py b/dolomite_engine/arguments.py index 95262d3cf..842838ce0 100644 --- a/dolomite_engine/arguments.py +++ b/dolomite_engine/arguments.py @@ -281,6 +281,8 @@ class DistributedArgs(BaseArgs): pipeline_parallel_schedule: str | None = None # whether to use async-TP use_async_tensor_parallel: bool = False + # world size for each CP group + context_parallel_world_size: int = 1 def model_post_init(self, __context: Any) -> None: # communication dtype diff --git a/dolomite_engine/distributed.py b/dolomite_engine/distributed.py index 91bfe90d2..850cf587a 100644 --- a/dolomite_engine/distributed.py +++ b/dolomite_engine/distributed.py @@ -111,8 +111,11 @@ def wrap_model_container_for_distributed_training( communication_dtype = None if communication_dtype is None else string_to_torch_dtype(communication_dtype) assert stage in [0, 2, 3] - - dp_mesh = ProcessGroupManager.get_data_parallel_mesh() + if ProcessGroupManager.get_context_parallel_world_size() > 1: + dp_mesh = ProcessGroupManager.get_data_parallel_context_parallel_mesh() + else: + dp_mesh = ProcessGroupManager.get_data_parallel_mesh() + block_classes = [ get_module_class_from_name(model_container[0], name) for name in block_names + teacher_block_names ] diff --git a/dolomite_engine/hf_models/mixins/dense/base.py b/dolomite_engine/hf_models/mixins/dense/base.py index eee638120..9e94c57a2 100644 --- a/dolomite_engine/hf_models/mixins/dense/base.py +++ b/dolomite_engine/hf_models/mixins/dense/base.py @@ -325,6 +325,7 @@ def _prepare_a_bunch_of_stuff( tuple[torch.Tensor], ]: if use_cache is None: + ## TODO: disable cache for cp without padding free transformer use_cache = False if self._use_padding_free_transformer else self.config.use_cache if input_ids is not None and inputs_embeds is not None: @@ -445,6 +446,8 @@ def _setup_positional_encoding(self) -> None: base=self.config.rope_theta, scale=self.config.rope_scaling["factor"], original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"], + beta_fast=self.config.rope_scaling["beta_fast"], + beta_slow=self.config.rope_scaling["beta_slow"], ) elif self.position_embedding_type == "nope": pass diff --git a/dolomite_engine/model_wrapper/pretraining.py b/dolomite_engine/model_wrapper/pretraining.py index 5f6f7e21c..83f9a85f5 100644 --- a/dolomite_engine/model_wrapper/pretraining.py +++ b/dolomite_engine/model_wrapper/pretraining.py @@ -118,9 +118,9 @@ def forward( else: assert aux_loss_from_pipeline_parallel == 0 - batch = self._prepare_model_inputs(batch) - labels = batch.pop("labels") - output: CausalLMOutputWithPast | PipelineParallelOutput = self.model(**batch, return_dict=True) + input_ids, labels = batch + + output = self.model(input_ids=input_ids, return_dict=True) if self.is_pipeline_parallel_enabled: # aux_loss is returned as a 0 dimensional tensor diff --git a/dolomite_engine/pretrain.py b/dolomite_engine/pretrain.py index 7c3d2fd53..70f6ec76f 100644 --- a/dolomite_engine/pretrain.py +++ b/dolomite_engine/pretrain.py @@ -32,6 +32,8 @@ is_torchao_available, log_rank_0, setup_tf32, + create_context_parallel_ctx, + get_cp_context, ) @@ -181,14 +183,25 @@ def train_step_without_pipeline_parallel( gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps() + world_mesh = ProcessGroupManager.get_mesh() + with no_sync(): for _ in range(gradient_accumulation_steps - 1): - batch = get_next_batch(train_dataloader) - with forward_context(): - loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) - - # compute gradients - with backward_context(): + batch = model._prepare_model_inputs(get_next_batch(train_dataloader)) + input_ids = batch["input_ids"] + labels = batch["labels"] + optional_context_parallel_ctx = ( + create_context_parallel_ctx( + cp_mesh=world_mesh["cp"], + cp_buffers=[input_ids, labels] + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], + cp_seq_dims=[1, 1, 0, 0], + cp_no_restore_buffers={input_ids, labels}, + cp_rotate_method="allgather", + ) + ) + + with forward_context(optional_context_parallel_ctx): + loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier) loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps loss_micro_step_scaled.backward() @@ -198,12 +211,21 @@ def train_step_without_pipeline_parallel( if fsdp_algorithm == 2: model.set_requires_gradient_sync(True) - batch = get_next_batch(train_dataloader) - with forward_context(): - loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) - - # compute gradients - with backward_context(): + batch = model._prepare_model_inputs(get_next_batch(train_dataloader)) + input_ids = batch["input_ids"] + labels = batch["labels"] + + optional_context_parallel_ctx = ( + create_context_parallel_ctx( + cp_mesh=world_mesh["cp"], + cp_buffers=[input_ids, labels] + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], + cp_seq_dims=[1, 1, 0, 0], + cp_no_restore_buffers={input_ids, labels}, + cp_rotate_method="allgather", + ) + ) + with forward_context(optional_context_parallel_ctx): + loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier) loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps loss_micro_step_scaled.backward() @@ -343,7 +365,7 @@ def train( / ProcessGroupManager.get_world_size() ) - forward_context = nullcontext + forward_context = get_cp_context(False, False) if ProcessGroupManager.get_context_parallel_world_size() > 1 else nullcontext backward_context = loss_parallel if ProcessGroupManager.is_tensor_parallel_enabled() else nullcontext torch_profiler = get_torch_profiler(args.logging_args.torch_profiler_trace_path) @@ -380,7 +402,7 @@ def train( forward_context=forward_context, backward_context=backward_context, sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step, - lm_loss_multiplier=1 / (micro_batch_size * sequence_length), + lm_loss_multiplier=1 / (micro_batch_size * sequence_length / args.distributed_args.context_parallel_world_size), ) metrics_tracker = metrics_tracker + loss_step_dict @@ -542,6 +564,7 @@ def main(mode: Mode = Mode.training) -> None: data_parallel_replication_world_size=args.distributed_args.zero_topology.data_parallel_replication_world_size, data_parallel_sharding_world_size=args.distributed_args.zero_topology.data_parallel_sharding_world_size, zero_stage=args.distributed_args.stage, + context_parallel_world_size=args.distributed_args.context_parallel_world_size, timeout_minutes=args.distributed_args.timeout_minutes, use_async_tensor_parallel=args.distributed_args.use_async_tensor_parallel, ) diff --git a/dolomite_engine/train_utils.py b/dolomite_engine/train_utils.py index f8d408037..9eba97e5d 100644 --- a/dolomite_engine/train_utils.py +++ b/dolomite_engine/train_utils.py @@ -17,7 +17,7 @@ def all_reduce_metrics_tracker(metrics_tracker: MetricsTrackingDict) -> MetricsT # tensor = torch.stack(tensor) / ProcessGroupManager.get_data_parallel_world_size() # tensor = tensor.cpu() # gloo op doesn't support averaging so we do sum and divide by world size above - torch.distributed.all_reduce(tensor, op=ReduceOp.AVG, group=ProcessGroupManager.get_data_parallel_group()) + torch.distributed.all_reduce(tensor, op=ReduceOp.AVG, group=ProcessGroupManager.get_mesh()["ddp", "fsdp", "cp"]._flatten(mesh_dim_name="dp_cp").get_group()) tensor = tensor.tolist() for i, key in enumerate(metrics_tracker): diff --git a/dolomite_engine/utils/__init__.py b/dolomite_engine/utils/__init__.py index 73dfe7f93..84d22dde4 100644 --- a/dolomite_engine/utils/__init__.py +++ b/dolomite_engine/utils/__init__.py @@ -19,7 +19,7 @@ is_zstandard_available, log_environment, ) -from .parallel import ProcessGroupManager, get_pipeline_stage_ids_on_current_rank, run_rank_n +from .parallel import ProcessGroupManager, get_pipeline_stage_ids_on_current_rank, run_rank_n, create_context_parallel_ctx, get_cp_context from .pydantic import BaseArgs from .safetensors import SafeTensorsWeightsManager from .step_tracker import StepTracker @@ -34,6 +34,7 @@ def init_distributed( data_parallel_size: int, data_parallel_replication_world_size: int, data_parallel_sharding_world_size: int, + context_parallel_world_size: int, zero_stage: int, timeout_minutes: int = None, use_async_tensor_parallel: bool = False, @@ -57,6 +58,7 @@ def init_distributed( data_parallel_size=data_parallel_size, data_parallel_replication_world_size=data_parallel_replication_world_size, data_parallel_sharding_world_size=data_parallel_sharding_world_size, + context_parallel_world_size=context_parallel_world_size, zero_stage=zero_stage, timeout_minutes=timeout_minutes, use_async_tensor_parallel=use_async_tensor_parallel, @@ -66,6 +68,7 @@ def init_distributed( log_rank_0(logging.INFO, f"total GPUs = {process_group_manager.get_world_size()}") log_rank_0(logging.INFO, f"tensor parallel size = {process_group_manager.get_tensor_parallel_world_size()}") log_rank_0(logging.INFO, f"data parallel size = {process_group_manager.get_data_parallel_world_size()}") + log_rank_0(logging.INFO, f"context parallel size = {context_parallel_world_size}") def setup_tf32(use_tf32: bool = True) -> None: diff --git a/dolomite_engine/utils/parallel.py b/dolomite_engine/utils/parallel.py index 9092d7914..60f39784b 100644 --- a/dolomite_engine/utils/parallel.py +++ b/dolomite_engine/utils/parallel.py @@ -1,7 +1,7 @@ import os -from contextlib import contextmanager +from contextlib import contextmanager, ExitStack from datetime import timedelta -from typing import Callable +from typing import Callable, List, Set, Optional, Generator import torch import torch.distributed @@ -39,6 +39,9 @@ _DATA_PARALLEL_REPLICATION_WORLD_SIZE: int | None = None _DATA_PARALLEL_SHARDING_WORLD_SIZE: int | None = None +# context parallel +_DATA_PARALLEL_CONTEXT_PARALLEL_MESH: DeviceMesh | None = None + class ProcessGroupManager: def __init__( @@ -48,6 +51,7 @@ def __init__( data_parallel_size: int | None = None, data_parallel_replication_world_size: int | None = None, data_parallel_sharding_world_size: int | None = None, + context_parallel_world_size: int | None = None, zero_stage: int = 3, timeout_minutes: int | None = None, use_async_tensor_parallel: bool = False, @@ -83,7 +87,7 @@ def __init__( else: assert data_parallel_sharding_world_size is not None - assert data_parallel_replication_world_size * data_parallel_sharding_world_size == data_parallel_size + assert data_parallel_replication_world_size * data_parallel_sharding_world_size * context_parallel_world_size == data_parallel_size global _MESH, _TENSOR_PARALLEL_FIRST_RANK, _DATA_PARALLEL_REPLICATION_WORLD_SIZE, _DATA_PARALLEL_SHARDING_WORLD_SIZE @@ -96,11 +100,14 @@ def __init__( pipeline_parallel_world_size, data_parallel_replication_world_size, data_parallel_sharding_world_size, + context_parallel_world_size, tensor_parallel_world_size, ), - mesh_dim_names=("pp", "ddp", "fsdp", "tp"), + mesh_dim_names=("pp", "ddp", "fsdp","cp", "tp"), ) + _MESH["fsdp", "cp"]._flatten(mesh_dim_name="fsdp_cp") + local_rank = int(os.getenv("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank) @@ -293,6 +300,19 @@ def get_data_parallel_mesh() -> DeviceMesh: if _DATA_PARALLEL_MESH is None: _DATA_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp"] return _DATA_PARALLEL_MESH + + # data parallel + context parallel + @staticmethod + def get_data_parallel_context_parallel_mesh() -> DeviceMesh: + global _DATA_PARALLEL_CONTEXT_PARALLEL_MESH + + if _DATA_PARALLEL_CONTEXT_PARALLEL_MESH is None: + _DATA_PARALLEL_CONTEXT_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp_cp"] + return _DATA_PARALLEL_CONTEXT_PARALLEL_MESH + + @staticmethod + def get_context_parallel_world_size() -> int: + return ProcessGroupManager.get_mesh()["cp"].size() @staticmethod def get_data_parallel_group() -> ProcessGroup: @@ -417,3 +437,56 @@ def get_pipeline_stage_ids_on_current_rank(num_pipeline_stages: int) -> int: ) return tuple(pp_rank + i * pp_world_size for i in range(num_pipeline_stages_per_rank)) + + +def create_context_parallel_ctx( + cp_mesh: DeviceMesh, + cp_buffers: List[torch.Tensor], + cp_seq_dims: List[int], + cp_no_restore_buffers: Set[torch.Tensor], + cp_rotate_method: str, +): + try: + from torch.distributed.tensor.experimental import context_parallel + from torch.distributed.tensor.experimental._attention import set_rotate_method + except ImportError: + print( + f"PyTorch version {torch.__version__} does not include the experimental " + "Context Parallel API. Please update to a newer version." + ) + + set_rotate_method(cp_rotate_method) + return context_parallel( + cp_mesh, + buffers=cp_buffers, + buffer_seq_dims=cp_seq_dims, + no_restore_buffers=cp_no_restore_buffers, + ) + +def get_cp_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): + @contextmanager + def context(cp_context: Optional[Generator[None, None, None]] = None): + with ExitStack() as stack: + if enable_loss_parallel: + stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) + + if enable_compiled_autograd: + stack.enter_context( + torch._dynamo.utils.maybe_enable_compiled_autograd(True) + ) + + if cp_context is not None: + from torch.nn.attention import sdpa_kernel, SDPBackend + + # currently we only support these two SDP backends. + # TODO (xilunwu): support cuDNN backend + stack.enter_context( + sdpa_kernel( + [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] + ) + ) + stack.enter_context(cp_context) + + yield + + return context \ No newline at end of file From 9bdff598a4774f567da52954619b59c01ad1d5c8 Mon Sep 17 00:00:00 2001 From: Gaoyuan-Zhang Date: Fri, 2 May 2025 15:47:08 +0000 Subject: [PATCH 2/3] forward context update --- dolomite_engine/pretrain.py | 88 ++++++++++++++++++++----------- dolomite_engine/utils/parallel.py | 3 ++ 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/dolomite_engine/pretrain.py b/dolomite_engine/pretrain.py index 70f6ec76f..e45794b57 100644 --- a/dolomite_engine/pretrain.py +++ b/dolomite_engine/pretrain.py @@ -187,23 +187,35 @@ def train_step_without_pipeline_parallel( with no_sync(): for _ in range(gradient_accumulation_steps - 1): - batch = model._prepare_model_inputs(get_next_batch(train_dataloader)) - input_ids = batch["input_ids"] - labels = batch["labels"] - optional_context_parallel_ctx = ( - create_context_parallel_ctx( - cp_mesh=world_mesh["cp"], - cp_buffers=[input_ids, labels] + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], - cp_seq_dims=[1, 1, 0, 0], - cp_no_restore_buffers={input_ids, labels}, - cp_rotate_method="allgather", + if ProcessGroupManager.get_context_parallel_world_size() > 1: + batch = model._prepare_model_inputs(get_next_batch(train_dataloader)) + input_ids = batch["input_ids"] + labels = batch["labels"] + cp_context = get_cp_context(False, False) + optional_context_parallel_ctx = ( + create_context_parallel_ctx( + cp_mesh=world_mesh["cp"], + cp_buffers=[input_ids, labels] + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], + cp_seq_dims=[1, 1, 0, 0], + cp_no_restore_buffers={input_ids, labels}, + cp_rotate_method="allgather", + ) ) - ) + + with cp_context(optional_context_parallel_ctx): + loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier) + loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps + loss_micro_step_scaled.backward() - with forward_context(optional_context_parallel_ctx): - loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier) - loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps - loss_micro_step_scaled.backward() + else: + batch = get_next_batch(train_dataloader) + with forward_context(): + loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) + + # compute gradients + with backward_context(): + loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps + loss_micro_step_scaled.backward() with torch.inference_mode(): metrics_tracker = metrics_tracker + loss_micro_step_dict @@ -211,23 +223,35 @@ def train_step_without_pipeline_parallel( if fsdp_algorithm == 2: model.set_requires_gradient_sync(True) - batch = model._prepare_model_inputs(get_next_batch(train_dataloader)) - input_ids = batch["input_ids"] - labels = batch["labels"] - - optional_context_parallel_ctx = ( - create_context_parallel_ctx( - cp_mesh=world_mesh["cp"], - cp_buffers=[input_ids, labels] + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], - cp_seq_dims=[1, 1, 0, 0], - cp_no_restore_buffers={input_ids, labels}, - cp_rotate_method="allgather", - ) + if ProcessGroupManager.get_context_parallel_world_size() > 1: + batch = model._prepare_model_inputs(get_next_batch(train_dataloader)) + input_ids = batch["input_ids"] + labels = batch["labels"] + cp_context = get_cp_context(False, False) + optional_context_parallel_ctx = ( + create_context_parallel_ctx( + cp_mesh=world_mesh["cp"], + cp_buffers=[input_ids, labels] + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], + cp_seq_dims=[1, 1, 0, 0], + cp_no_restore_buffers={input_ids, labels}, + cp_rotate_method="allgather", ) - with forward_context(optional_context_parallel_ctx): - loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier) - loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps - loss_micro_step_scaled.backward() +) + + with cp_context(optional_context_parallel_ctx): + loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier) + loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps + loss_micro_step_scaled.backward() + + else: + batch = get_next_batch(train_dataloader) + with forward_context(): + loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) + + # compute gradients + with backward_context(): + loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps + loss_micro_step_scaled.backward() with torch.inference_mode(): metrics_tracker = metrics_tracker + loss_micro_step_dict @@ -365,7 +389,7 @@ def train( / ProcessGroupManager.get_world_size() ) - forward_context = get_cp_context(False, False) if ProcessGroupManager.get_context_parallel_world_size() > 1 else nullcontext + forward_context = nullcontext backward_context = loss_parallel if ProcessGroupManager.is_tensor_parallel_enabled() else nullcontext torch_profiler = get_torch_profiler(args.logging_args.torch_profiler_trace_path) diff --git a/dolomite_engine/utils/parallel.py b/dolomite_engine/utils/parallel.py index 60f39784b..502aa253a 100644 --- a/dolomite_engine/utils/parallel.py +++ b/dolomite_engine/utils/parallel.py @@ -418,8 +418,11 @@ def func_rank_other(*args, **kwargs): def is_tracking_rank() -> bool: + + ## TODO verify cp local rank for logging return ( ProcessGroupManager.get_data_parallel_rank() == 0 + and ProcessGroupManager.get_mesh()["cp"].get_local_rank() == 0 and ProcessGroupManager.is_tensor_parallel_first_rank() and ProcessGroupManager.get_pipeline_parallel_rank() == ProcessGroupManager.get_pipeline_parallel_world_size() - 1 From d3f65e5102a0b3e3a3be045bb4a1a4340a5a3f0d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Wed, 7 May 2025 13:07:59 -0400 Subject: [PATCH 3/3] reformat Signed-off-by: Mayank Mishra --- dolomite_engine/distributed.py | 4 +-- dolomite_engine/pretrain.py | 57 +++++++++++++++---------------- dolomite_engine/train_utils.py | 6 +++- dolomite_engine/utils/__init__.py | 8 ++++- dolomite_engine/utils/parallel.py | 34 +++++++++--------- 5 files changed, 60 insertions(+), 49 deletions(-) diff --git a/dolomite_engine/distributed.py b/dolomite_engine/distributed.py index 850cf587a..9dba6cddf 100644 --- a/dolomite_engine/distributed.py +++ b/dolomite_engine/distributed.py @@ -111,11 +111,11 @@ def wrap_model_container_for_distributed_training( communication_dtype = None if communication_dtype is None else string_to_torch_dtype(communication_dtype) assert stage in [0, 2, 3] - if ProcessGroupManager.get_context_parallel_world_size() > 1: + if ProcessGroupManager.is_context_parallel_enabled(): dp_mesh = ProcessGroupManager.get_data_parallel_context_parallel_mesh() else: dp_mesh = ProcessGroupManager.get_data_parallel_mesh() - + block_classes = [ get_module_class_from_name(model_container[0], name) for name in block_names + teacher_block_names ] diff --git a/dolomite_engine/pretrain.py b/dolomite_engine/pretrain.py index e45794b57..c917af0be 100644 --- a/dolomite_engine/pretrain.py +++ b/dolomite_engine/pretrain.py @@ -28,12 +28,12 @@ MetricsTrackingDict, ProcessGroupManager, StepTracker, + create_context_parallel_ctx, + get_cp_context, init_distributed, is_torchao_available, log_rank_0, setup_tf32, - create_context_parallel_ctx, - get_cp_context, ) @@ -187,31 +187,30 @@ def train_step_without_pipeline_parallel( with no_sync(): for _ in range(gradient_accumulation_steps - 1): - if ProcessGroupManager.get_context_parallel_world_size() > 1: + if ProcessGroupManager.is_context_parallel_enabled(): batch = model._prepare_model_inputs(get_next_batch(train_dataloader)) input_ids = batch["input_ids"] labels = batch["labels"] cp_context = get_cp_context(False, False) - optional_context_parallel_ctx = ( - create_context_parallel_ctx( - cp_mesh=world_mesh["cp"], - cp_buffers=[input_ids, labels] + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], - cp_seq_dims=[1, 1, 0, 0], - cp_no_restore_buffers={input_ids, labels}, - cp_rotate_method="allgather", - ) - ) - + optional_context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=world_mesh["cp"], + cp_buffers=[input_ids, labels] + + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], + cp_seq_dims=[1, 1, 0, 0], + cp_no_restore_buffers={input_ids, labels}, + cp_rotate_method="allgather", + ) + with cp_context(optional_context_parallel_ctx): loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier) loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps loss_micro_step_scaled.backward() - + else: batch = get_next_batch(train_dataloader) with forward_context(): loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) - + # compute gradients with backward_context(): loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps @@ -223,31 +222,30 @@ def train_step_without_pipeline_parallel( if fsdp_algorithm == 2: model.set_requires_gradient_sync(True) - if ProcessGroupManager.get_context_parallel_world_size() > 1: + if ProcessGroupManager.is_context_parallel_enabled(): batch = model._prepare_model_inputs(get_next_batch(train_dataloader)) input_ids = batch["input_ids"] labels = batch["labels"] cp_context = get_cp_context(False, False) - optional_context_parallel_ctx = ( - create_context_parallel_ctx( - cp_mesh=world_mesh["cp"], - cp_buffers=[input_ids, labels] + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], - cp_seq_dims=[1, 1, 0, 0], - cp_no_restore_buffers={input_ids, labels}, - cp_rotate_method="allgather", - ) -) - + optional_context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=world_mesh["cp"], + cp_buffers=[input_ids, labels] + + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached], + cp_seq_dims=[1, 1, 0, 0], + cp_no_restore_buffers={input_ids, labels}, + cp_rotate_method="allgather", + ) + with cp_context(optional_context_parallel_ctx): loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier) loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps loss_micro_step_scaled.backward() - + else: batch = get_next_batch(train_dataloader) with forward_context(): loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier) - + # compute gradients with backward_context(): loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps @@ -426,7 +424,8 @@ def train( forward_context=forward_context, backward_context=backward_context, sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step, - lm_loss_multiplier=1 / (micro_batch_size * sequence_length / args.distributed_args.context_parallel_world_size), + lm_loss_multiplier=1 + / (micro_batch_size * sequence_length / args.distributed_args.context_parallel_world_size), ) metrics_tracker = metrics_tracker + loss_step_dict diff --git a/dolomite_engine/train_utils.py b/dolomite_engine/train_utils.py index 9eba97e5d..af5fc4ad6 100644 --- a/dolomite_engine/train_utils.py +++ b/dolomite_engine/train_utils.py @@ -17,7 +17,11 @@ def all_reduce_metrics_tracker(metrics_tracker: MetricsTrackingDict) -> MetricsT # tensor = torch.stack(tensor) / ProcessGroupManager.get_data_parallel_world_size() # tensor = tensor.cpu() # gloo op doesn't support averaging so we do sum and divide by world size above - torch.distributed.all_reduce(tensor, op=ReduceOp.AVG, group=ProcessGroupManager.get_mesh()["ddp", "fsdp", "cp"]._flatten(mesh_dim_name="dp_cp").get_group()) + torch.distributed.all_reduce( + tensor, + op=ReduceOp.AVG, + group=ProcessGroupManager.get_mesh()["ddp", "fsdp", "cp"]._flatten(mesh_dim_name="dp_cp").get_group(), + ) tensor = tensor.tolist() for i, key in enumerate(metrics_tracker): diff --git a/dolomite_engine/utils/__init__.py b/dolomite_engine/utils/__init__.py index 84d22dde4..c78044f7e 100644 --- a/dolomite_engine/utils/__init__.py +++ b/dolomite_engine/utils/__init__.py @@ -19,7 +19,13 @@ is_zstandard_available, log_environment, ) -from .parallel import ProcessGroupManager, get_pipeline_stage_ids_on_current_rank, run_rank_n, create_context_parallel_ctx, get_cp_context +from .parallel import ( + ProcessGroupManager, + create_context_parallel_ctx, + get_cp_context, + get_pipeline_stage_ids_on_current_rank, + run_rank_n, +) from .pydantic import BaseArgs from .safetensors import SafeTensorsWeightsManager from .step_tracker import StepTracker diff --git a/dolomite_engine/utils/parallel.py b/dolomite_engine/utils/parallel.py index 502aa253a..98150f3f8 100644 --- a/dolomite_engine/utils/parallel.py +++ b/dolomite_engine/utils/parallel.py @@ -1,7 +1,7 @@ import os -from contextlib import contextmanager, ExitStack +from contextlib import ExitStack, contextmanager from datetime import timedelta -from typing import Callable, List, Set, Optional, Generator +from typing import Callable, Generator, List, Optional, Set import torch import torch.distributed @@ -87,7 +87,10 @@ def __init__( else: assert data_parallel_sharding_world_size is not None - assert data_parallel_replication_world_size * data_parallel_sharding_world_size * context_parallel_world_size == data_parallel_size + assert ( + data_parallel_replication_world_size * data_parallel_sharding_world_size * context_parallel_world_size + == data_parallel_size + ) global _MESH, _TENSOR_PARALLEL_FIRST_RANK, _DATA_PARALLEL_REPLICATION_WORLD_SIZE, _DATA_PARALLEL_SHARDING_WORLD_SIZE @@ -103,7 +106,7 @@ def __init__( context_parallel_world_size, tensor_parallel_world_size, ), - mesh_dim_names=("pp", "ddp", "fsdp","cp", "tp"), + mesh_dim_names=("pp", "ddp", "fsdp", "cp", "tp"), ) _MESH["fsdp", "cp"]._flatten(mesh_dim_name="fsdp_cp") @@ -300,7 +303,7 @@ def get_data_parallel_mesh() -> DeviceMesh: if _DATA_PARALLEL_MESH is None: _DATA_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp"] return _DATA_PARALLEL_MESH - + # data parallel + context parallel @staticmethod def get_data_parallel_context_parallel_mesh() -> DeviceMesh: @@ -309,11 +312,15 @@ def get_data_parallel_context_parallel_mesh() -> DeviceMesh: if _DATA_PARALLEL_CONTEXT_PARALLEL_MESH is None: _DATA_PARALLEL_CONTEXT_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp_cp"] return _DATA_PARALLEL_CONTEXT_PARALLEL_MESH - + @staticmethod def get_context_parallel_world_size() -> int: return ProcessGroupManager.get_mesh()["cp"].size() + @staticmethod + def is_context_parallel_enabled() -> bool: + return ProcessGroupManager.get_context_parallel_world_size() > 1 + @staticmethod def get_data_parallel_group() -> ProcessGroup: global _DATA_PARALLEL_GROUP @@ -466,6 +473,7 @@ def create_context_parallel_ctx( no_restore_buffers=cp_no_restore_buffers, ) + def get_cp_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): @contextmanager def context(cp_context: Optional[Generator[None, None, None]] = None): @@ -474,22 +482,16 @@ def context(cp_context: Optional[Generator[None, None, None]] = None): stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) if enable_compiled_autograd: - stack.enter_context( - torch._dynamo.utils.maybe_enable_compiled_autograd(True) - ) + stack.enter_context(torch._dynamo.utils.maybe_enable_compiled_autograd(True)) if cp_context is not None: - from torch.nn.attention import sdpa_kernel, SDPBackend + from torch.nn.attention import SDPBackend, sdpa_kernel # currently we only support these two SDP backends. # TODO (xilunwu): support cuDNN backend - stack.enter_context( - sdpa_kernel( - [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] - ) - ) + stack.enter_context(sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])) stack.enter_context(cp_context) yield - return context \ No newline at end of file + return context