diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index e9e701442..6a5826f27 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -667,20 +667,26 @@ def _ft_load(self) -> None:
         step = self._find_load_step(folder=self._ft_folder())
         if step == -1:
             return
-
-        begin = time.monotonic()
-        logger.info(f"Loading the FT checkpoint at step {step}.")
-        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        self.dcp_load(
-            self.ft_states,
-            checkpoint_id=checkpoint_id,
-            # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
-            from_hf=False,
-        )
-        GarbageCollection.collect("GC collection for checkpoint loading.")
-        logger.info(
-            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
-        )
+        try:
+            begin = time.monotonic()
+            logger.info(f"Loading the FT checkpoint at step {step}.")
+            checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+            logger.info(f"Calling dcp_load for {checkpoint_id}")
+            self.dcp_load(
+                self.ft_states,
+                checkpoint_id=checkpoint_id,
+                # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
+                from_hf=False,
+            )
+            GarbageCollection.collect("GC collection for checkpoint loading.")
+            logger.info(
+                f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+            )
+        except Exception as e:
+            # The checkpoint is corrupt. We'll replay all data here.
+            # TODO: We can try to load checkpoint from previous steps.
+            logger.error(f"Failed to load the FT checkpoint: {e}")
+            return
 
     def _flattened_model_states_sd(
         self, state_dict: dict[str, Any] | None = None
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 0eb06a0d4..14a5218a4 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -34,6 +34,12 @@ class Profiling:
     profile_freq: int = 10
     """How often to collect profile traces, in iterations"""
 
+    profiler_active: int = 1
+    """The number of steps the profiler is active for in each profiling cycle"""
+
+    profiler_warmup: int = 3
+    """The number of warmup steps before the active steps in each profiling cycle"""
+
     enable_memory_snapshot: bool = False
     """Whether to dump memory snapshot"""
 
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index 0e851d335..97c9272e2 100644
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -14,9 +14,6 @@
 from torchtitan.config import Profiling as ProfilingConfig
 from torchtitan.tools.logging import logger
 
-# the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
-
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
@@ -58,7 +55,10 @@ def trace_handler(prof):
         if not os.path.exists(trace_dir):
             os.makedirs(trace_dir, exist_ok=True)
 
-        warmup, active = WARMUP, 1
+        warmup, active = (
+            profiling_config.profiler_warmup,
+            profiling_config.profiler_active,
+        )
         wait = profile_freq - (active + warmup)
         assert (
             wait >= 0
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 0afcac8dc..7f2ceceb7 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -6,6 +6,7 @@
 
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable, Optional
@@ -588,6 +589,17 @@ def train(self):
                 ),
             ),
         ):
+            if torch_profiler:
+
+                def sigabrt_handler(signal, frame):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    for _ in range(config.profiling.profiler_active):
+                        # Step the profiler enough times to trigger a trace
+                        torch_profiler.step()
+                    torch_profiler.stop()
+
+                signal.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1