diff --git a/dolomite_engine/arguments.py b/dolomite_engine/arguments.py
index 210d3b5e0..c24bf4d22 100644
--- a/dolomite_engine/arguments.py
+++ b/dolomite_engine/arguments.py
@@ -283,6 +283,8 @@ class DistributedArgs(BaseArgs):
     pipeline_parallel_schedule: str | None = None
     # whether to use async-TP
     use_async_tensor_parallel: bool = False
+    # world size for each CP group
+    context_parallel_world_size: int = 1

     def model_post_init(self, __context: Any) -> None:
         # communication dtype
diff --git a/dolomite_engine/distributed.py b/dolomite_engine/distributed.py
index a24b67169..6d7a6893b 100644
--- a/dolomite_engine/distributed.py
+++ b/dolomite_engine/distributed.py
@@ -191,8 +191,11 @@ def wrap_model_container_for_distributed_training(
     communication_dtype = None if communication_dtype is None else string_to_torch_dtype(communication_dtype)

     assert stage in [0, 2, 3]
+    if ProcessGroupManager.is_context_parallel_enabled():
+        dp_mesh = ProcessGroupManager.get_data_parallel_context_parallel_mesh()
+    else:
+        dp_mesh = ProcessGroupManager.get_data_parallel_mesh()

-    dp_mesh = ProcessGroupManager.get_data_parallel_mesh()
     block_classes = [
         get_module_class_from_name(model_container[0], name) for name in block_names + teacher_block_names
     ]
diff --git a/dolomite_engine/hf_models/mixins/dense/base.py b/dolomite_engine/hf_models/mixins/dense/base.py
index 1dae3c3a9..a77f027ac 100644
--- a/dolomite_engine/hf_models/mixins/dense/base.py
+++ b/dolomite_engine/hf_models/mixins/dense/base.py
@@ -326,6 +326,7 @@ def _prepare_a_bunch_of_stuff(
         tuple[torch.Tensor],
     ]:
         if use_cache is None:
+            ## TODO: disable cache for cp without padding free transformer
            use_cache = False if self.use_padding_free_transformer else self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
@@ -446,6 +447,8 @@ def _setup_positional_encoding(self) -> None:
                     base=self.config.rope_theta,
                     scale=self.config.rope_scaling["factor"],
                     original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"],
+                    beta_fast=self.config.rope_scaling["beta_fast"],
+                    beta_slow=self.config.rope_scaling["beta_slow"],
                 )
         elif self.position_embedding_type == "nope":
             pass
diff --git a/dolomite_engine/model_wrapper/pretraining.py b/dolomite_engine/model_wrapper/pretraining.py
index 4fde79f52..f1cf59c95 100644
--- a/dolomite_engine/model_wrapper/pretraining.py
+++ b/dolomite_engine/model_wrapper/pretraining.py
@@ -122,9 +122,9 @@ def forward(
         else:
             assert aux_loss_from_pipeline_parallel == 0

-        batch = self._prepare_model_inputs(batch)
-        labels = batch.pop("labels")
-        output: CausalLMOutputWithPast | PipelineParallelOutput = self.model(**batch, return_dict=True)
+        input_ids, labels = batch
+
+        output = self.model(input_ids=input_ids, return_dict=True)

         if self.is_pipeline_parallel_enabled:
             # aux_loss is returned as a 0 dimensional tensor
diff --git a/dolomite_engine/pretrain.py b/dolomite_engine/pretrain.py
index 13b300746..705dbc5a5 100644
--- a/dolomite_engine/pretrain.py
+++ b/dolomite_engine/pretrain.py
@@ -32,6 +32,8 @@
     MetricsTrackingDict,
     ProcessGroupManager,
     StepTracker,
+    create_context_parallel_ctx,
+    get_cp_context,
     init_distributed,
     is_torchao_available,
     log_rank_0,
@@ -185,16 +187,38 @@ def train_step_without_pipeline_parallel(

     gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps()

+    world_mesh = ProcessGroupManager.get_mesh()
+
     with no_sync():
         for _ in range(gradient_accumulation_steps - 1):
-            batch = get_next_batch(train_dataloader)
-            with forward_context():
-                loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
+            if ProcessGroupManager.is_context_parallel_enabled():
+                batch = model._prepare_model_inputs(get_next_batch(train_dataloader))
+                input_ids = batch["input_ids"]
+                labels = batch["labels"]
+                cp_context = get_cp_context(False, False)
+                optional_context_parallel_ctx = create_context_parallel_ctx(
+                    cp_mesh=world_mesh["cp"],
+                    cp_buffers=[input_ids, labels]
+                    + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached],
+                    cp_seq_dims=[1, 1, 0, 0],
+                    cp_no_restore_buffers={input_ids, labels},
+                    cp_rotate_method="allgather",
+                )
+
+                with cp_context(optional_context_parallel_ctx):
+                    loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier)
+                    loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
+                    loss_micro_step_scaled.backward()
+
+            else:
+                batch = get_next_batch(train_dataloader)
+                with forward_context():
+                    loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)

-            # compute gradients
-            with backward_context():
-                loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
-                loss_micro_step_scaled.backward()
+                # compute gradients
+                with backward_context():
+                    loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
+                    loss_micro_step_scaled.backward()

             with torch.inference_mode():
                 metrics_tracker = metrics_tracker + loss_micro_step_dict
@@ -202,14 +226,34 @@
     if fsdp_algorithm == 2:
         model.set_requires_gradient_sync(True)

-    batch = get_next_batch(train_dataloader)
-    with forward_context():
-        loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
+    if ProcessGroupManager.is_context_parallel_enabled():
+        batch = model._prepare_model_inputs(get_next_batch(train_dataloader))
+        input_ids = batch["input_ids"]
+        labels = batch["labels"]
+        cp_context = get_cp_context(False, False)
+        optional_context_parallel_ctx = create_context_parallel_ctx(
+            cp_mesh=world_mesh["cp"],
+            cp_buffers=[input_ids, labels]
+            + [model.model.transformer.rope.cos_cached, model.model.transformer.rope.sin_cached],
+            cp_seq_dims=[1, 1, 0, 0],
+            cp_no_restore_buffers={input_ids, labels},
+            cp_rotate_method="allgather",
+        )

-    # compute gradients
-    with backward_context():
-        loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
-        loss_micro_step_scaled.backward()
+        with cp_context(optional_context_parallel_ctx):
+            loss_micro_step_dict = model((input_ids, labels), lm_loss_multiplier=lm_loss_multiplier)
+            loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
+            loss_micro_step_scaled.backward()
+
+    else:
+        batch = get_next_batch(train_dataloader)
+        with forward_context():
+            loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)
+
+        # compute gradients
+        with backward_context():
+            loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
+            loss_micro_step_scaled.backward()

     with torch.inference_mode():
         metrics_tracker = metrics_tracker + loss_micro_step_dict
@@ -393,7 +437,8 @@ def train(
             forward_context=forward_context,
             backward_context=backward_context,
             sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
-            lm_loss_multiplier=1 / (micro_batch_size * sequence_length),
+            lm_loss_multiplier=1
+            / (micro_batch_size * sequence_length / args.distributed_args.context_parallel_world_size),
         )

         metrics_tracker = metrics_tracker + loss_step_dict
@@ -577,6 +622,7 @@ def main(mode: Mode = Mode.training) -> None:
        data_parallel_replication_world_size=args.distributed_args.zero_topology.data_parallel_replication_world_size,
        data_parallel_sharding_world_size=args.distributed_args.zero_topology.data_parallel_sharding_world_size,
        zero_stage=args.distributed_args.stage,
+        context_parallel_world_size=args.distributed_args.context_parallel_world_size,
        timeout_minutes=args.distributed_args.timeout_minutes,
        use_async_tensor_parallel=args.distributed_args.use_async_tensor_parallel,
     )
diff --git a/dolomite_engine/train_utils.py b/dolomite_engine/train_utils.py
index a829a84dd..54f6ad35d 100644
--- a/dolomite_engine/train_utils.py
+++ b/dolomite_engine/train_utils.py
@@ -21,7 +21,11 @@ def all_reduce_metrics_tracker(metrics_tracker: MetricsTrackingDict) -> MetricsT
     # tensor = torch.stack(tensor) / ProcessGroupManager.get_data_parallel_world_size()
     # tensor = tensor.cpu()
     # gloo op doesn't support averaging so we do sum and divide by world size above
-    torch.distributed.all_reduce(tensor, op=ReduceOp.AVG, group=ProcessGroupManager.get_data_parallel_group())
+    torch.distributed.all_reduce(
+        tensor,
+        op=ReduceOp.AVG,
+        group=ProcessGroupManager.get_mesh()["ddp", "fsdp", "cp"]._flatten(mesh_dim_name="dp_cp").get_group(),
+    )
     tensor = tensor.tolist()

     for i, key in enumerate(metrics_tracker):
diff --git a/dolomite_engine/utils/__init__.py b/dolomite_engine/utils/__init__.py
index 1438c7e45..a348717ef 100644
--- a/dolomite_engine/utils/__init__.py
+++ b/dolomite_engine/utils/__init__.py
@@ -23,7 +23,13 @@
     is_zstandard_available,
     log_environment,
 )
-from .parallel import ProcessGroupManager, get_pipeline_stage_ids_on_current_rank, run_rank_n
+from .parallel import (
+    ProcessGroupManager,
+    create_context_parallel_ctx,
+    get_cp_context,
+    get_pipeline_stage_ids_on_current_rank,
+    run_rank_n,
+)
 from .pydantic import BaseArgs
 from .safetensors import SafeTensorsWeightsManager
 from .step_tracker import StepTracker
@@ -37,6 +43,7 @@ def init_distributed(
     pipeline_parallel_world_size: int,
     data_parallel_replication_world_size: int,
     data_parallel_sharding_world_size: int,
+    context_parallel_world_size: int,
     zero_stage: int,
     timeout_minutes: int = None,
     use_async_tensor_parallel: bool = False,
@@ -58,6 +65,7 @@ def init_distributed(
         pipeline_parallel_world_size=pipeline_parallel_world_size,
         data_parallel_replication_world_size=data_parallel_replication_world_size,
         data_parallel_sharding_world_size=data_parallel_sharding_world_size,
+        context_parallel_world_size=context_parallel_world_size,
         zero_stage=zero_stage,
         timeout_minutes=timeout_minutes,
         use_async_tensor_parallel=use_async_tensor_parallel,
@@ -67,6 +75,7 @@ def init_distributed(
     log_rank_0(logging.INFO, f"total GPUs = {process_group_manager.get_world_size()}")
     log_rank_0(logging.INFO, f"tensor parallel size = {process_group_manager.get_tensor_parallel_world_size()}")
     log_rank_0(logging.INFO, f"data parallel size = {process_group_manager.get_data_parallel_world_size()}")
+    log_rank_0(logging.INFO, f"context parallel size = {context_parallel_world_size}")


 def setup_tf32(use_tf32: bool = True) -> None:
diff --git a/dolomite_engine/utils/parallel.py b/dolomite_engine/utils/parallel.py
index db0e89c64..dc9c4e465 100644
--- a/dolomite_engine/utils/parallel.py
+++ b/dolomite_engine/utils/parallel.py
@@ -3,9 +3,9 @@
 # **************************************************

 import os
-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
 from datetime import timedelta
-from typing import Callable
+from typing import Callable, Generator, List, Optional, Set

 import torch
 import torch.distributed
@@ -43,6 +43,9 @@
 _DATA_PARALLEL_REPLICATION_WORLD_SIZE: int | None = None
 _DATA_PARALLEL_SHARDING_WORLD_SIZE: int | None = None

+# context parallel
+_DATA_PARALLEL_CONTEXT_PARALLEL_MESH: DeviceMesh | None = None
+

 class ProcessGroupManager:
     def __init__(
@@ -51,6 +54,7 @@ def __init__(
         self,
         pipeline_parallel_world_size: int = 1,
         data_parallel_replication_world_size: int | None = None,
         data_parallel_sharding_world_size: int | None = None,
+        context_parallel_world_size: int | None = None,
         zero_stage: int = 3,
         timeout_minutes: int | None = None,
         use_async_tensor_parallel: bool = False,
@@ -84,7 +88,10 @@ def __init__(
         else:
             assert data_parallel_sharding_world_size is not None

-        assert data_parallel_replication_world_size * data_parallel_sharding_world_size == data_parallel_size
+        assert (
+            data_parallel_replication_world_size * data_parallel_sharding_world_size * context_parallel_world_size
+            == data_parallel_size
+        )

         global _MESH, _TENSOR_PARALLEL_FIRST_RANK, _DATA_PARALLEL_REPLICATION_WORLD_SIZE, _DATA_PARALLEL_SHARDING_WORLD_SIZE

@@ -97,11 +104,14 @@ def __init__(
                 pipeline_parallel_world_size,
                 data_parallel_replication_world_size,
                 data_parallel_sharding_world_size,
+                context_parallel_world_size,
                 tensor_parallel_world_size,
             ),
-            mesh_dim_names=("pp", "ddp", "fsdp", "tp"),
+            mesh_dim_names=("pp", "ddp", "fsdp", "cp", "tp"),
         )

+        _MESH["fsdp", "cp"]._flatten(mesh_dim_name="fsdp_cp")
+
         local_rank = int(os.getenv("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)

@@ -295,6 +305,23 @@ def get_data_parallel_mesh() -> DeviceMesh:
             _DATA_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp"]
         return _DATA_PARALLEL_MESH

+    # data parallel + context parallel
+    @staticmethod
+    def get_data_parallel_context_parallel_mesh() -> DeviceMesh:
+        global _DATA_PARALLEL_CONTEXT_PARALLEL_MESH
+
+        if _DATA_PARALLEL_CONTEXT_PARALLEL_MESH is None:
+            _DATA_PARALLEL_CONTEXT_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp_cp"]
+        return _DATA_PARALLEL_CONTEXT_PARALLEL_MESH
+
+    @staticmethod
+    def get_context_parallel_world_size() -> int:
+        return ProcessGroupManager.get_mesh()["cp"].size()
+
+    @staticmethod
+    def is_context_parallel_enabled() -> bool:
+        return ProcessGroupManager.get_context_parallel_world_size() > 1
+
     @staticmethod
     def get_data_parallel_group() -> ProcessGroup:
         global _DATA_PARALLEL_GROUP
@@ -399,8 +426,11 @@ def func_rank_other(*args, **kwargs):


 def is_tracking_rank() -> bool:
+
+    ## TODO verify cp local rank for logging
     return (
         ProcessGroupManager.get_data_parallel_rank() == 0
+        and ProcessGroupManager.get_mesh()["cp"].get_local_rank() == 0
         and ProcessGroupManager.is_tensor_parallel_first_rank()
         and ProcessGroupManager.get_pipeline_parallel_rank()
         == ProcessGroupManager.get_pipeline_parallel_world_size() - 1
@@ -418,3 +448,51 @@ def get_pipeline_stage_ids_on_current_rank(num_pipeline_stages: int) -> int:
     )

     return tuple(pp_rank + i * pp_world_size for i in range(num_pipeline_stages_per_rank))
+
+
+def create_context_parallel_ctx(
+    cp_mesh: DeviceMesh,
+    cp_buffers: List[torch.Tensor],
+    cp_seq_dims: List[int],
+    cp_no_restore_buffers: Set[torch.Tensor],
+    cp_rotate_method: str,
+):
+    try:
+        from torch.distributed.tensor.experimental import context_parallel
+        from torch.distributed.tensor.experimental._attention import set_rotate_method
+    except ImportError:
+        print(
+            f"PyTorch version {torch.__version__} does not include the experimental "
+            "Context Parallel API. Please update to a newer version."
+        )
+
+    set_rotate_method(cp_rotate_method)
+    return context_parallel(
+        cp_mesh,
+        buffers=cp_buffers,
+        buffer_seq_dims=cp_seq_dims,
+        no_restore_buffers=cp_no_restore_buffers,
+    )
+
+
+def get_cp_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+    @contextmanager
+    def context(cp_context: Optional[Generator[None, None, None]] = None):
+        with ExitStack() as stack:
+            if enable_loss_parallel:
+                stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
+
+            if enable_compiled_autograd:
+                stack.enter_context(torch._dynamo.utils.maybe_enable_compiled_autograd(True))
+
+            if cp_context is not None:
+                from torch.nn.attention import SDPBackend, sdpa_kernel
+
+                # currently we only support these two SDP backends.
+                # TODO (xilunwu): support cuDNN backend
+                stack.enter_context(sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]))
+                stack.enter_context(cp_context)
+
+            yield
+
+    return context