34 changes: 20 additions & 14 deletions torchtitan/components/checkpoint.py
@@ -667,20 +667,26 @@ def _ft_load(self) -> None:
         step = self._find_load_step(folder=self._ft_folder())
         if step == -1:
             return
-
-        begin = time.monotonic()
-        logger.info(f"Loading the FT checkpoint at step {step}.")
-        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        self.dcp_load(
-            self.ft_states,
-            checkpoint_id=checkpoint_id,
-            # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
-            from_hf=False,
-        )
-        GarbageCollection.collect("GC collection for checkpoint loading.")
-        logger.info(
-            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
-        )
+        try:
+            begin = time.monotonic()
+            logger.info(f"Loading the FT checkpoint at step {step}.")
+            checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+            logger.info(f"Calling dcp_load for {checkpoint_id}")
+            self.dcp_load(
+                self.ft_states,
+                checkpoint_id=checkpoint_id,
+                # FT checkpoints are always DCP because FT checkpoints currently only save/load the dataloader.
+                from_hf=False,
+            )
+            GarbageCollection.collect("GC collection for checkpoint loading.")
+            logger.info(
+                f"Finished loading the FT checkpoint in {time.monotonic() - begin:.2f} seconds."
+            )
+        except Exception as e:
+            # The checkpoint is corrupt; fall back to replaying all the data.
+            # TODO: try loading the checkpoint from an earlier step instead.
+            logger.error(f"Failed to load the FT checkpoint: {e}")
+            return
 
     def _flattened_model_states_sd(
         self, state_dict: dict[str, Any] | None = None
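The change above makes a corrupt FT checkpoint non-fatal: the load is attempted, and on any failure training falls back to replaying all the data. A minimal standalone sketch of that fallback pattern (the load_with_fallback helper and load_fn callback are illustrative, not torchtitan API):

    import logging
    import time

    logger = logging.getLogger(__name__)

    def load_with_fallback(load_fn) -> bool:
        """Attempt a checkpoint load; return False if data must be replayed."""
        begin = time.monotonic()
        try:
            load_fn()
        except Exception as e:
            # A corrupt checkpoint degrades to a full data replay instead of
            # crashing the job.
            logger.error("Failed to load the FT checkpoint: %s", e)
            return False
        logger.info("Loaded the checkpoint in %.2f seconds.", time.monotonic() - begin)
        return True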
6 changes: 6 additions & 0 deletions torchtitan/config/job_config.py
@@ -34,6 +34,12 @@ class Profiling:
     profile_freq: int = 10
     """How often to collect profile traces, in iterations"""
 
+    profiler_active: int = 1
+    """The number of steps the profiler is active for in each profiling cycle"""
+
+    profiler_warmup: int = 3
+    """The number of warmup steps before the active steps in each profiling cycle"""
+
     enable_memory_snapshot: bool = False
     """Whether to dump memory snapshot"""
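Together with profile_freq, the two new fields define the profiling cycle: each window of profile_freq steps is split into wait, warmup, and active phases, so profile_freq must be at least profiler_warmup + profiler_active. A quick standalone sketch of that arithmetic, using the defaults above:

    profile_freq, profiler_warmup, profiler_active = 10, 3, 1  # defaults above

    # Each profile_freq-step cycle splits into wait -> warmup -> active phases.
    wait = profile_freq - (profiler_active + profiler_warmup)
    assert wait >= 0, "profile_freq must cover warmup + active steps"
    print(wait)  # 6 idle (wait) steps per 10-step cycle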
8 changes: 4 additions & 4 deletions torchtitan/tools/profiling.py
@@ -14,9 +14,6 @@
 from torchtitan.config import Profiling as ProfilingConfig
 from torchtitan.tools.logging import logger
 
-# the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
-
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
@@ -58,7 +55,10 @@ def trace_handler(prof):
     if not os.path.exists(trace_dir):
         os.makedirs(trace_dir, exist_ok=True)
 
-    warmup, active = WARMUP, 1
+    warmup, active = (
+        profiling_config.profiler_warmup,
+        profiling_config.profiler_active,
+    )
     wait = profile_freq - (active + warmup)
     assert (
         wait >= 0
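For context, these values feed a standard torch.profiler schedule. A minimal sketch of the equivalent setup outside torchtitan (the on_trace_ready wiring here is simplified relative to the real trace_handler):

    import torch

    profile_freq, warmup, active = 10, 3, 1
    wait = profile_freq - (active + warmup)

    prof = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
        on_trace_ready=lambda p: p.export_chrome_trace("trace.json"),
    )
    prof.start()
    for _ in range(profile_freq):
        # One training step would run here; prof.step() advances the schedule,
        # and the trace is flushed after the final active step.
        prof.step()
    prof.stop()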
12 changes: 12 additions & 0 deletions torchtitan/train.py
@@ -6,6 +6,7 @@
 
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable, Optional
 
@@ -588,6 +589,17 @@ def train(self):
                 ),
             ),
         ):
+            if torch_profiler:
+
+                def sigabrt_handler(signum, frame):
+                    logger.info("SIGABRT received. Stopping profiler.")
+                    for _ in range(config.profiling.profiler_active):
+                        # Step the profiler enough times to trigger a trace.
+                        torch_profiler.step()
+                    torch_profiler.stop()
+
+                signal.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
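To exercise the handler, the process only needs to receive SIGABRT (e.g. kill -ABRT <pid>). A standalone sketch of the same signal wiring, with the profiler interaction reduced to a log line:

    import os
    import signal

    def sigabrt_handler(signum, frame):
        # train.py steps and stops the live torch_profiler at this point; this
        # sketch only demonstrates the signal registration and delivery.
        print("SIGABRT received. Stopping profiler.")

    signal.signal(signal.SIGABRT, sigabrt_handler)
    os.kill(os.getpid(), signal.SIGABRT)  # what `kill -ABRT <pid>` would deliver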