diff --git a/tests/integration_tests/ft.py b/tests/integration_tests/ft.py
index 73807d7f40..e6c9a8e9c3 100644
--- a/tests/integration_tests/ft.py
+++ b/tests/integration_tests/ft.py
@@ -58,7 +58,9 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
     cmd = (
         f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" '
         + f"CUDA_VISIBLE_DEVICES={ranks} "
+        + "TRAIN_FILE=torchtitan.experiments.ft.train "
         + f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} ./run_train.sh "
+        + "--model.name=llama3_ft "
         + "--fault_tolerance.enable "
         + f"--fault_tolerance.replica_id={replica_id} --fault_tolerance.group_size={test_flavor.ngpu}"
     )
diff --git a/docs/torchft.md b/torchtitan/experiments/ft/torchft.md
similarity index 81%
rename from docs/torchft.md
rename to torchtitan/experiments/ft/torchft.md
index afcb50eee6..a32f215041 100644
--- a/docs/torchft.md
+++ b/torchtitan/experiments/ft/torchft.md
@@ -36,12 +36,12 @@ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join
 2. Launch the first TorchTitan instance:
 
 ```bash
-NGPU=4 CUDA_VISIBLE_DEVICES=0,1,2,3 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --fault_tolerance.enable --fault_tolerance.replica_id=0 --fault_tolerance.group_size=2 --parallelism.data_parallel_shard_degree=4
+NGPU=4 CUDA_VISIBLE_DEVICES=0,1,2,3 TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --fault_tolerance.enable --fault_tolerance.replica_id=0 --fault_tolerance.group_size=2 --parallelism.data_parallel_shard_degree=4
 ```
 
 3. Launch the second TorchTitan instance:
 
 ```bash
-NGPU=4 CUDA_VISIBLE_DEVICES=4,5,6,7 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --fault_tolerance.enable --fault_tolerance.replica_id=1 --fault_tolerance.group_size=2 --parallelism.data_parallel_shard_degree=4
+NGPU=4 CUDA_VISIBLE_DEVICES=4,5,6,7 TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --fault_tolerance.enable --fault_tolerance.replica_id=1 --fault_tolerance.group_size=2 --parallelism.data_parallel_shard_degree=4
 ```
 
 ### Explanation
@@ -68,12 +68,12 @@ The `--training.global_batch_size` parameter refers to global batch size that wi
 #### Replica Group 0
 
 ```bash
-CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0
+TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0
 ```
 
 #### Replica Group 1
 
 ```bash
-CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1
+TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1
 ```
 
 ## Fault Tolerance Configuration Options
diff --git a/torchtitan/experiments/ft/train.py b/torchtitan/experiments/ft/train.py
index 891f6c5554..39ecd6fd8c 100644
--- a/torchtitan/experiments/ft/train.py
+++ b/torchtitan/experiments/ft/train.py
@@ -4,13 +4,335 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import dataclasses
+import importlib
+import json
 import os
+import time
+from datetime import timedelta
+from typing import cast, Iterator
 
+import torch
+import torch.distributed.checkpoint.stateful
+
+import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
+from torchtitan.components.checkpoint import CheckpointManager
+from torchtitan.components.dataloader import DataloaderExhaustedError
+from torchtitan.components.ft import FTManager, maybe_semi_sync_training
+from torchtitan.components.loss import IGNORE_INDEX
+from torchtitan.components.metrics import (
+    build_metrics_processor,
+    ensure_pp_loss_visible,
+)
+from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.protocols import ModelProtocol
+from torchtitan.protocols.model_converter import build_model_converters
+from torchtitan.tools import utils
+from torchtitan.tools.logging import logger
+from torchtitan.tools.profiling import (
+    maybe_enable_memory_snapshot,
+    maybe_enable_profiling,
+)
 from torchtitan.train import main, Trainer
 
 
 class FTTrainer(Trainer):
+    ft_manager: FTManager | None = None
+
+    # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
+    @record
+    def __init__(self, job_config: JobConfig):
+        torch._C._log_api_usage_once("torchtitan.train")
+
+        self.job_config = job_config
+
+        logger.info(f"Starting job: {job_config.job.description}")
+
+        if job_config.experimental.custom_import:
+            importlib.import_module(job_config.experimental.custom_import)
+
+        device_module, device_type = utils.device_module, utils.device_type
+        # pyrefly: ignore [read-only]
+        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
+        # Device has to be set before creating TorchFT manager.
+        device_module.set_device(self.device)
+
+        # init distributed and build meshes
+        self.parallel_dims = parallel_dims = self.init_distributed()
+
+        # Logging needs to happen after distributed initialized
+        job_config.maybe_log()
+
+        if parallel_dims.dp_enabled:
+            batch_mesh = parallel_dims.get_mesh("batch")
+            batch_degree, batch_rank = batch_mesh.size(), batch_mesh.get_local_rank()
+        else:
+            batch_degree, batch_rank = 1, 0
+
+        self.ft_manager = FTManager(job_config.fault_tolerance)
+        batch_degree, batch_rank = self.ft_manager.get_dp_info(batch_degree, batch_rank)
+
+        # take control of garbage collection to avoid stragglers
+        self.gc_handler = utils.GarbageCollection(
+            gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
+        )
+
+        # Set random seed, and maybe enable deterministic mode
+        # (mainly for debugging, expect perf loss).
+        dist_utils.set_determinism(
+            parallel_dims,
+            self.device,
+            job_config.debug,
+            distinct_seed_mesh_dims=["pp"],
+        )
+        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
+
+        # build tokenizer and dataloader
+        self.tokenizer = (
+            self.train_spec.build_tokenizer_fn(job_config)
+            if self.train_spec.build_tokenizer_fn is not None
+            else None
+        )
+
+        self.dataloader = self.train_spec.build_dataloader_fn(
+            dp_world_size=batch_degree,
+            dp_rank=batch_rank,
+            tokenizer=self.tokenizer,
+            job_config=job_config,
+        )
+
+        # build model (using meta init)
+        model_args = self.train_spec.model_args[job_config.model.flavor]
+        # set the model args from training job configs
+        model_args.update_from_config(job_config)
+        self.model_args = model_args
+
+        logger.info(
+            f"Building {job_config.model.name} {job_config.model.flavor} "
+            f"with {json.dumps(dataclasses.asdict(model_args), indent=2, ensure_ascii=False)}"
+        )
+        with (
+            torch.device("meta"),
+            utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
+        ):
+            model = self.train_spec.model_cls(model_args)
+
+        # Build the collection of model converters. No-op if `model.converters` empty
+        model_converters = build_model_converters(job_config, parallel_dims)
+        model_converters.convert(model)
+
+        # metrics logging
+        build_metrics_processor_fn = (
+            build_metrics_processor
+            if self.train_spec.build_metrics_processor_fn is None
+            else self.train_spec.build_metrics_processor_fn
+        )
+        self.metrics_processor = build_metrics_processor_fn(
+            job_config, parallel_dims, model_args
+        )
+        color = self.metrics_processor.color
+
+        # calculate model size and flops per token
+        (
+            model_param_count,
+            self.metrics_processor.num_flops_per_token,
+        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
+
+        logger.info(
+            f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} "
+            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
+        )
+
+        # move sharded model to CPU/GPU and initialize weights via DTensor
+        buffer_device: torch.device | None
+        if job_config.checkpoint.create_seed_checkpoint:
+            init_device = "cpu"
+            buffer_device = None
+        elif job_config.training.enable_cpu_offload:
+            init_device = "cpu"
+            buffer_device = torch.device(device_type)
+        else:
+            init_device = device_type
+            buffer_device = None
+
+        self.loss_fn = self.train_spec.build_loss_fn(
+            job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
+        )
+
+        # verify batch sizes
+        global_batch_size = job_config.training.global_batch_size
+        if global_batch_size < 0:
+            # This global batch size results in 1 gradient accumulation
+            # step.
+            global_batch_size = job_config.training.local_batch_size * batch_degree
+        assert global_batch_size > 0
+        assert (
+            global_batch_size % (job_config.training.local_batch_size * batch_degree)
+            == 0
+        ), (
+            f"global batch size must be multiple of local batch size times "
+            f"data-parallel degree ({global_batch_size} "
+            f"% ({job_config.training.local_batch_size} * {batch_degree}) != 0)"
+        )
+
+        # calculate gradient accumulation steps
+        self.gradient_accumulation_steps = global_batch_size // (
+            job_config.training.local_batch_size * batch_degree
+        )
+        assert self.gradient_accumulation_steps > 0
+
+        # apply parallelisms and initialization
+        if parallel_dims.pp_enabled:
+            if not self.train_spec.pipelining_fn:
+                raise RuntimeError(
+                    f"Pipeline Parallel is enabled but {job_config.model.name} "
+                    f"does not support pipelining"
+                )
+
+            # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques
+            (
+                self.pp_schedule,
+                self.model_parts,
+                self.pp_has_first_stage,
+                self.pp_has_last_stage,
+            ) = self.train_spec.pipelining_fn(
+                model,
+                parallel_dims,
+                job_config,
+                self.device,
+                model_args,
+                self.train_spec.parallelize_fn,
+                self.loss_fn,
+            )
+            # when PP is enabled, `model` obj is no longer used after this point,
+            # model_parts is used instead
+            del model
+
+            for m in self.model_parts:
+                m.to_empty(device=init_device)
+                with torch.no_grad():
+                    cast(ModelProtocol, m).init_weights(buffer_device=buffer_device)
+                m.train()
+
+            # confirm that user will be able to view loss metrics on the console
+            ensure_pp_loss_visible(parallel_dims, job_config, color)
+        else:
+            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
+            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)
+
+            model.to_empty(device=init_device)
+            with torch.no_grad():
+                cast(ModelProtocol, model).init_weights(buffer_device=buffer_device)
+            model.train()
+
+            self.model_parts = [model]
+
+        self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)
+
+        # initialize device memory monitor and get peak flops for MFU calculation
+        device_memory_monitor = self.metrics_processor.device_memory_monitor
+        gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
+        logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
+        device_mem_stats = device_memory_monitor.get_peak_stats()
+        logger.info(
+            f"{device_type.upper()} memory usage for model: "
+            f"{device_mem_stats.max_reserved_gib:.2f}GiB"
+            f"({device_mem_stats.max_reserved_pct:.2f}%)"
+        )
+
+        # build optimizer after applying parallelisms to the model
+        self.optimizers = self.train_spec.build_optimizers_fn(
+            self.model_parts, job_config.optimizer, parallel_dims, self.ft_manager
+        )
+        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
+            self.optimizers, job_config.lr_scheduler, job_config.training.steps
+        )
+        # Post optimizer step model converters hook.
+        # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
+        # where it issues a single all-reduce for all parameters at once for better performance
+        self.optimizers.register_step_post_hook(
+            lambda *args, **kwargs: model_converters.post_optimizer_hook(
+                self.model_parts
+            )
+        )
+        self.metrics_processor.optimizers = self.optimizers
+        self.metrics_processor.model_parts = self.model_parts
+
+        # Initialize trainer states that will be saved in checkpoint.
+        # These attributes must be initialized before checkpoint loading.
+        self.step = 0
+        self.ntokens_seen = 0
+
+        self.checkpointer = CheckpointManager(
+            dataloader=self.dataloader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states={"train_state": self},
+            checkpoint_config=job_config.checkpoint,
+            sd_adapter=(
+                self.train_spec.state_dict_adapter(
+                    model_args, job_config.model.hf_assets_path
+                )
+                if self.train_spec.state_dict_adapter
+                else None
+            ),
+            base_folder=job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
+        loss_parallel_enabled = (
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
+        )
+        self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
+        self.maybe_enable_amp = dist_utils.maybe_enable_amp(
+            parallel_dims,
+            job_config.training.mixed_precision_param,
+            device_type,
+        )
+
+        # Build validator if validation is configured
+        if job_config.validation.enable:
+            assert self.train_spec.build_validator_fn is not None
+
+            pp_schedule, pp_has_first_stage, pp_has_last_stage = (
+                (
+                    self.pp_schedule,
+                    self.pp_has_first_stage,
+                    self.pp_has_last_stage,
+                )
+                if parallel_dims.pp_enabled
+                else (None, None, None)
+            )
+
+            self.validator = self.train_spec.build_validator_fn(
+                job_config=job_config,
+                dp_world_size=batch_degree,
+                dp_rank=batch_rank,
+                tokenizer=self.tokenizer,
+                parallel_dims=parallel_dims,
+                loss_fn=self.loss_fn,
+                validation_context=self.train_context,
+                maybe_enable_amp=self.maybe_enable_amp,
+                metrics_processor=self.metrics_processor,
+                pp_schedule=pp_schedule,
+                pp_has_first_stage=pp_has_first_stage,
+                pp_has_last_stage=pp_has_last_stage,
+            )
+
+        logger.info(
+            "Trainer is initialized with "
+            f"local batch size {job_config.training.local_batch_size}, "
+            f"global batch size {global_batch_size}, "
+            f"gradient accumulation steps {self.gradient_accumulation_steps}, "
+            f"sequence length {job_config.training.seq_len}, "
+            f"total steps {job_config.training.steps} "
+            f"(warmup {job_config.lr_scheduler.warmup_steps})"
+        )
+
     def init_distributed(self) -> ParallelDims:
         job_config = self.job_config
@@ -46,6 +368,196 @@ def init_distributed(self) -> ParallelDims:
             world_size=world_size,
         )
 
+    def train_step(
+        self, data_iterator: Iterator[tuple[dict[str, torch.Tensor], torch.Tensor]]
+    ):
+        self.optimizers.zero_grad()
+        # Save the current step learning rate for logging
+        lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]
+
+        # Keep these variables local to shorten the code as these are
+        # the major variables that are used in the training loop.
+        parallel_dims = self.parallel_dims
+
+        # Collect all microbatches on CPU and count total valid tokens
+        microbatches = []
+        local_valid_tokens = torch.tensor(0, dtype=torch.int64)
+        for _microbatch in range(self.gradient_accumulation_steps):
+            input_dict, labels = next(data_iterator)
+            local_valid_tokens += (labels != IGNORE_INDEX).sum()
+            microbatches.append((input_dict, labels))
+
+        # All-reduce to get global token count across DP ranks
+        # Move to GPU for distributed communication
+        local_valid_tokens = local_valid_tokens.to(self.device)
+        if parallel_dims.dp_enabled:
+            batch_mesh = parallel_dims.get_mesh("batch")
+            global_valid_tokens = dist_utils.dist_sum(local_valid_tokens, batch_mesh)
+        else:
+            global_valid_tokens = local_valid_tokens.float()
+
+        # Process each microbatch: move to GPU, forward/backward, then free
+        accumulated_losses = []
+        for input_dict, labels in microbatches:
+            # Move tensors to GPU
+            for k, v in input_dict.items():
+                if isinstance(v, torch.Tensor):
+                    input_dict[k] = v.to(self.device)
+            labels = labels.to(self.device)
+
+            loss = self.forward_backward_step(
+                input_dict=input_dict,
+                labels=labels,
+                # pyrefly: ignore [bad-argument-type]
+                global_valid_tokens=global_valid_tokens,
+            )
+            accumulated_losses.append(loss.detach())
+
+        grad_norm = dist_utils.clip_grad_norm_(
+            [p for m in self.model_parts for p in m.parameters()],
+            self.job_config.training.max_norm,
+            foreach=True,
+            pp_mesh=parallel_dims.get_optional_mesh("pp"),
+            ep_enabled=parallel_dims.ep_enabled,
+        )
+        self.checkpointer.maybe_wait_for_staging()
+        self.optimizers.step()
+        self.lr_schedulers.step()
+
+        # Reduce the data collected over gradient accumulation steps.
+        loss = torch.sum(torch.stack(accumulated_losses))
+
+        # log metrics
+        if not self.metrics_processor.should_log(self.step):
+            return
+
+        if parallel_dims.dp_cp_enabled:
+            loss = loss.detach()
+            ft_pg = self.ft_manager.loss_sync_pg
+            loss_mesh = parallel_dims.get_optional_mesh("loss")
+
+            # For global_avg_loss, we want the average loss across all ranks:
+            #   loss = local_loss_sum / global_valid_tokens
+            #   global_avg_loss = sum(local_loss_sum) / global_valid_tokens
+            #                   = sum(loss)
+            #
+            # For global_max_loss, we want the max of local average losses across ranks:
+            #   local_avg_loss = local_loss_sum / local_valid_tokens
+            #                  = (loss * global_valid_tokens) / local_valid_tokens
+            #   global_max_loss = max(local_avg_loss)
+            local_avg_loss = loss * global_valid_tokens / local_valid_tokens
+            global_avg_loss, global_max_loss, global_ntokens_seen = (
+                dist_utils.dist_sum(loss, loss_mesh, ft_pg),
+                dist_utils.dist_max(local_avg_loss, loss_mesh, ft_pg),
+                dist_utils.dist_sum(
+                    torch.tensor(
+                        self.ntokens_seen, dtype=torch.int64, device=self.device
+                    ),
+                    loss_mesh,
+                    ft_pg,
+                ),
+            )
+        else:
+            global_avg_loss = global_max_loss = loss.detach().item()
+            global_ntokens_seen = self.ntokens_seen
+
+        extra_metrics = {
+            "n_tokens_seen": global_ntokens_seen,
+            "lr": lr,
+        }
+        self.metrics_processor.log(
+            self.step,
+            global_avg_loss,
+            global_max_loss,
+            grad_norm.item(),
+            extra_metrics=extra_metrics,
+        )
+
+    @record
+    def train(self):
+        job_config = self.job_config
+
+        self.checkpointer.load(step=job_config.checkpoint.load_step)
+        logger.info(f"Training starts at step {self.step + 1}")
+
+        leaf_folder = (
+            ""
+            if not self.ft_manager.enabled
+            else f"replica_{self.ft_manager.replica_id}"
+        )
+        with (
+            maybe_enable_profiling(
+                job_config.profiling,
+                global_step=self.step,
+                base_folder=job_config.job.dump_folder,
+                leaf_folder=leaf_folder,
+            ) as torch_profiler,
+            maybe_enable_memory_snapshot(
+                job_config.profiling,
+                global_step=self.step,
+                base_folder=job_config.job.dump_folder,
+                leaf_folder=leaf_folder,
+            ) as memory_profiler,
+            maybe_semi_sync_training(
+                job_config.fault_tolerance,
+                ft_manager=self.ft_manager,
+                model=self.model_parts[0],
+                n_layers=(
+                    self.model_args.n_layers
+                    if hasattr(self.model_args, "n_layers")
+                    else 0
+                ),
+                optimizer=self.optimizers,
+                fragment_fn=(
+                    self.train_spec.fragment_fn
+                    if hasattr(self.train_spec, "fragment_fn")
+                    else None
+                ),
+            ),
+        ):
+            data_iterator = self.batch_generator(self.dataloader)
+            while self.should_continue_training():
+                self.step += 1
+                self.gc_handler.run(self.step)
+                try:
+                    self.train_step(data_iterator)
+                except DataloaderExhaustedError:
+                    logger.warning("Ran out of data; last step was canceled.")
+                    break
+
+                self.checkpointer.save(
+                    self.step, last_step=(self.step == job_config.training.steps)
+                )
+
+                # Run validation if validator is available
+                if (
+                    self.job_config.validation.enable
+                    and self.validator.should_validate(self.step)
+                ):
+                    self.validator.validate(self.model_parts, self.step)
+
+                # signal the profiler that the next profiling step has started
+                if torch_profiler:
+                    torch_profiler.step()
+                if memory_profiler:
+                    memory_profiler.step()
+
+                # reduce timeout after first train step for faster signal
+                # (assuming lazy init and compilation are finished)
+                if self.step == 1:
+                    dist_utils.set_pg_timeouts(
+                        timeout=timedelta(
+                            seconds=job_config.comm.train_timeout_seconds
+                        ),
+                        parallel_dims=self.parallel_dims,
+                    )
+
+        if torch.distributed.get_rank() == 0:
+            logger.info("Sleeping 2 seconds for other ranks to complete")
+            time.sleep(2)
+
+        logger.info("Training completed")
+
 
 if __name__ == "__main__":
     main(FTTrainer)
diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py
index be67c33a75..cde7f88cc2 100644
--- a/torchtitan/models/flux/train.py
+++ b/torchtitan/models/flux/train.py
@@ -246,19 +246,17 @@ def train_step(
         if parallel_dims.dp_cp_enabled:
             loss = loss.detach()
-            ft_pg = self.ft_manager.loss_sync_pg
             loss_mesh = parallel_dims.get_optional_mesh("loss")
 
             # NOTE: the loss returned by train
             global_avg_loss, global_max_loss, global_ntokens_seen = (
-                dist_utils.dist_sum(loss, loss_mesh, ft_pg),
-                dist_utils.dist_max(loss, loss_mesh, ft_pg),
+                dist_utils.dist_sum(loss, loss_mesh),
+                dist_utils.dist_max(loss, loss_mesh),
                 dist_utils.dist_sum(
                     torch.tensor(
                         self.ntokens_seen, dtype=torch.int64, device=self.device
                     ),
                     loss_mesh,
-                    ft_pg,
                 ),
             )
         else:
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9378d742e3..9a1bcfde4d 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -19,7 +19,6 @@
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.dataloader import DataloaderExhaustedError
-from torchtitan.components.ft import FTManager, maybe_semi_sync_training
 from torchtitan.components.loss import IGNORE_INDEX
 from torchtitan.components.metrics import (
     build_metrics_processor,
@@ -59,7 +58,6 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
 
     # non-swappable training components
     checkpointer: CheckpointManager
-    ft_manager: FTManager
 
     # runtime utilities
     device: torch.device
@@ -103,9 +101,6 @@ def __init__(self, job_config: JobConfig):
         else:
             batch_degree, batch_rank = 1, 0
 
-        self.ft_manager = FTManager(job_config.fault_tolerance)
-        batch_degree, batch_rank = self.ft_manager.get_dp_info(batch_degree, batch_rank)
-
         # take control of garbage collection to avoid stragglers
         self.gc_handler = utils.GarbageCollection(
             gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
@@ -190,7 +185,7 @@ def __init__(self, job_config: JobConfig):
             buffer_device = None
 
         self.loss_fn = self.train_spec.build_loss_fn(
-            job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
+            job_config, parallel_dims=parallel_dims
         )
 
         # verify batch sizes
@@ -261,8 +256,6 @@ def __init__(self, job_config: JobConfig):
 
             self.model_parts = [model]
 
-        self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)
-
         # initialize device memory monitor and get peak flops for MFU calculation
         device_memory_monitor = self.metrics_processor.device_memory_monitor
         gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
@@ -276,7 +269,7 @@ def __init__(self, job_config: JobConfig):
 
         # build optimizer after applying parallelisms to the model
         self.optimizers = self.train_spec.build_optimizers_fn(
-            self.model_parts, job_config.optimizer, parallel_dims, self.ft_manager
+            self.model_parts, job_config.optimizer, parallel_dims
         )
         self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
             self.optimizers, job_config.lr_scheduler, job_config.training.steps
@@ -312,7 +305,6 @@ def __init__(self, job_config: JobConfig):
                 else None
             ),
             base_folder=job_config.job.dump_folder,
-            ft_manager=self.ft_manager,
         )
 
         loss_parallel_enabled = (
@@ -613,7 +605,6 @@ def train_step(
 
         if parallel_dims.dp_cp_enabled:
             loss = loss.detach()
-            ft_pg = self.ft_manager.loss_sync_pg
             loss_mesh = parallel_dims.get_optional_mesh("loss")
 
             # For global_avg_loss, we want the average loss across all ranks:
@@ -627,14 +618,13 @@ def train_step(
             #   global_max_loss = max(local_avg_loss)
             local_avg_loss = loss * global_valid_tokens / local_valid_tokens
             global_avg_loss, global_max_loss, global_ntokens_seen = (
-                dist_utils.dist_sum(loss, loss_mesh, ft_pg),
-                dist_utils.dist_max(local_avg_loss, loss_mesh, ft_pg),
+                dist_utils.dist_sum(loss, loss_mesh),
+                dist_utils.dist_max(local_avg_loss, loss_mesh),
                 dist_utils.dist_sum(
                     torch.tensor(
                         self.ntokens_seen, dtype=torch.int64, device=self.device
                     ),
                     loss_mesh,
-                    ft_pg,
                 ),
             )
         else:
@@ -660,40 +650,17 @@ def train(self):
         self.checkpointer.load(step=job_config.checkpoint.load_step)
         logger.info(f"Training starts at step {self.step + 1}")
 
-        leaf_folder = (
-            ""
-            if not self.ft_manager.enabled
-            else f"replica_{self.ft_manager.replica_id}"
-        )
         with (
             maybe_enable_profiling(
                 job_config.profiling,
                 global_step=self.step,
                 base_folder=job_config.job.dump_folder,
-                leaf_folder=leaf_folder,
             ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
                 base_folder=job_config.job.dump_folder,
-                leaf_folder=leaf_folder,
            ) as memory_profiler,
-            maybe_semi_sync_training(
-                job_config.fault_tolerance,
-                ft_manager=self.ft_manager,
-                model=self.model_parts[0],
-                n_layers=(
-                    self.model_args.n_layers
-                    if hasattr(self.model_args, "n_layers")
-                    else 0
-                ),
-                optimizer=self.optimizers,
-                fragment_fn=(
-                    self.train_spec.fragment_fn
-                    if hasattr(self.train_spec, "fragment_fn")
-                    else None
-                ),
-            ),
         ):
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():