From 424b23cbbb9460c57036f72c378378ad7d7cd25b Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Thu, 25 Sep 2025 11:22:19 -0700
Subject: [PATCH 1/2] improve profiler: configurable schedule, dump trace on SIGABRT

---
 torchtitan/config/job_config.py |  6 ++++++
 torchtitan/tools/profiling.py   |  8 ++++----
 torchtitan/train.py             | 12 ++++++++++++
 3 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 0eb06a0d48..14a5218a4b 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -34,6 +34,12 @@ class Profiling:
     profile_freq: int = 10
     """How often to collect profile traces, in iterations"""
 
+    profiler_active: int = 1
+    """The number of steps the profiler is active for in each profiling cycle"""
+
+    profiler_warmup: int = 3
+    """The number of warmup steps before the active steps in each profiling cycle"""
+
     enable_memory_snapshot: bool = False
     """Whether to dump memory snapshot"""
 
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index 0e851d335a..97c9272e2a 100644
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -14,9 +14,6 @@
 from torchtitan.config import Profiling as ProfilingConfig
 from torchtitan.tools.logging import logger
 
-# the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
-
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
@@ -58,7 +55,10 @@ def trace_handler(prof):
         if not os.path.exists(trace_dir):
             os.makedirs(trace_dir, exist_ok=True)
 
-        warmup, active = WARMUP, 1
+        warmup, active = (
+            profiling_config.profiler_warmup,
+            profiling_config.profiler_active,
+        )
         wait = profile_freq - (active + warmup)
         assert (
             wait >= 0
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 0afcac8dcf..7f2ceceb74 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -6,6 +6,7 @@
 
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable, Optional
@@ -588,6 +589,17 @@ def train(self):
                 ),
             ),
         ):
+            if torch_profiler:
+
+                def sigabrt_handler(signum, frame):
+                    logger.info("SIGABRT received. Stopping profiler")
Stopping profiler") + for _ in range(config.profiling.profiler_active): + # Step the profiler enough times to trigger a trace + torch_profiler.step() + torch_profiler.stop() + + signal.signal(signal.SIGABRT, sigabrt_handler) + data_iterator = self.batch_generator(self.dataloader) while self.should_continue_training(): self.step += 1 From e2a1a9890f6d21f46e2d2c04ec6cf01d6ef91adc Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Thu, 25 Sep 2025 12:45:19 -0700 Subject: [PATCH 2/2] handle unable to load ft checkpoint Summary: - not being able to load ft checkpoint crashes the trainer - avoid loading the ft checkpoint for now to continue training --- torchtitan/components/checkpoint.py | 34 +++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index e9e7014425..6a5826f278 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -667,20 +667,26 @@ def _ft_load(self) -> None: step = self._find_load_step(folder=self._ft_folder()) if step == -1: return - - begin = time.monotonic() - logger.info(f"Loading the FT checkpoint at step {step}.") - checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - self.dcp_load( - self.ft_states, - checkpoint_id=checkpoint_id, - # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader. - from_hf=False, - ) - GarbageCollection.collect("GC collection for checkpoint loading.") - logger.info( - f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds." - ) + try: + begin = time.monotonic() + logger.info(f"Loading the FT checkpoint at step {step}.") + checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) + logger.info(f"Calling dcp_load for {checkpoint_id}") + self.dcp_load( + self.ft_states, + checkpoint_id=checkpoint_id, + # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader. + from_hf=False, + ) + GarbageCollection.collect("GC collection for checkpoint loading.") + logger.info( + f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds." + ) + except Exception as e: + # The checkpoint is corrupt. We'll replay all data here. + # TODO: We can try to load checkpoint from previous steps. + logger.error("Failed to load the FT checkpoint.") + return def _flattened_model_states_sd( self, state_dict: dict[str, Any] | None = None