131 changes: 129 additions & 2 deletions tests/unit_tests/test_checkpoint.py
@@ -167,7 +167,7 @@ def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
sd_to_save[key] = val
torch.save(sd_to_save, os.path.join(checkpoint_id, "state_dict.pt"))

def fake_load(self, states: dict, checkpoint_id=None):
def fake_load(self, states: dict, checkpoint_id=None, **kwargs):
path = os.path.join(checkpoint_id, "state_dict.pt")
loaded = torch.load(path, weights_only=False)
for key, val in loaded.items():
@@ -750,7 +750,7 @@ def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
self.assertNotIn("optimizer", state_dict)
return

def fake_load(state_dict: dict, checkpoint_id=None):
def fake_load(state_dict: dict, checkpoint_id=None, **kwargs):
self.assertIn("bias", state_dict)
self.assertIn("weight", state_dict)
# No model prefix
@@ -777,6 +777,133 @@ def fake_load(state_dict: dict, checkpoint_id=None):
manager.save(curr_step=2, last_step=True)
manager.load(step=1)

@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torchtitan.components.checkpoint.dcp.save")
def test_partial_state_dict_save(self, mock_save, mock_rank):
"""Test that ModelWrapper.state_dict_to_save() filters keys correctly
when the model has a state_dict_to_save method."""

class PartialSaveModel(nn.Module):
def __init__(self):
super().__init__()
self.base_weight = nn.Parameter(torch.randn(2, 2))
self.adapter_weight = nn.Parameter(torch.randn(2, 2))

def state_dict_to_save(self):
# Only save the adapter weight
return {"adapter_weight": self.adapter_weight}

partial_model = PartialSaveModel()
mock_save.side_effect = self.fake_save

cfg = self.job_config.checkpoint
cfg.keep_latest_k = 0

manager = CheckpointManager(
dataloader=self.data_loader,
model_parts=[partial_model],
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states=self.states,
checkpoint_config=cfg,
sd_adapter=None,
base_folder=self.job_config.job.dump_folder,
ft_manager=self.ft_manager,
)

manager.save(curr_step=1)
self.assertEqual(mock_save.call_count, 1)
checkpoint_path = os.path.join(self.test_folder, "step-1", "state_dict.pt")
saved_data = torch.load(checkpoint_path, weights_only=False)

# Only adapter_weight should be saved
self.assertIn("adapter_weight", saved_data)
self.assertNotIn("base_weight", saved_data)
manager.close()

@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torchtitan.components.checkpoint.dcp.save")
@mock.patch("torchtitan.components.checkpoint.dcp.load")
def test_additional_load_paths(self, mock_load, mock_save, mock_rank):
"""Test that additional_load_paths loads from extra checkpoint directories."""
mock_save.side_effect = self.fake_save
mock_load.side_effect = self.fake_load

# Create an additional checkpoint directory with a saved state
additional_dir = os.path.join(self.base_temp_dir, "additional_ckpt")
os.makedirs(additional_dir, exist_ok=True)
torch.save(
{"weight": torch.ones(2, 2), "bias": torch.ones(2)},
os.path.join(additional_dir, "state_dict.pt"),
)

cfg = self.job_config.checkpoint
cfg.keep_latest_k = 0
cfg.additional_load_paths = [additional_dir]

manager = CheckpointManager(
dataloader=self.data_loader,
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states=self.states,
checkpoint_config=cfg,
sd_adapter=None,
base_folder=self.job_config.job.dump_folder,
ft_manager=self.ft_manager,
)

# Save and then load - additional_load_paths should be loaded after main ckpt
manager.save(curr_step=1)
manager.load(step=1)

# dcp.load should be called twice: once for main checkpoint, once for additional
self.assertEqual(mock_load.call_count, 2)
# Verify the second call used the additional path
_, kwargs2 = mock_load.call_args_list[1]
self.assertEqual(kwargs2.get("checkpoint_id"), additional_dir)
manager.close()

@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torchtitan.components.checkpoint.dcp.load")
def test_additional_load_paths_invalid_path_raises(self, mock_load, mock_rank):
"""Test that an invalid additional_load_paths raises ValueError."""
cfg = self.job_config.checkpoint
cfg.keep_latest_k = 0
cfg.additional_load_paths = ["/nonexistent/path"]

manager = CheckpointManager(
dataloader=self.data_loader,
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states=self.states,
checkpoint_config=cfg,
sd_adapter=None,
base_folder=self.job_config.job.dump_folder,
ft_manager=self.ft_manager,
)

# Even without a main checkpoint, loading should try additional paths and fail
with self.assertRaises(ValueError):
manager.load(step=-1)
manager.close()

def test_model_wrapper_default_behavior(self):
"""Test that ModelWrapper works correctly with plain nn.Module (no state_dict_to_save)."""
from torchtitan.components.checkpoint import ModelWrapper

model = nn.Linear(3, 3)
wrapper = ModelWrapper(model)

# For a plain nn.Module without state_dict_to_save, all keys should be included
sd_save = wrapper.state_dict_to_save()
sd_load = wrapper.state_dict_to_load()
sd_full = wrapper._get_state_dict()

self.assertEqual(set(sd_save.keys()), set(sd_full.keys()))
self.assertEqual(set(sd_load.keys()), set(sd_full.keys()))


if __name__ == "__main__":
unittest.main()
114 changes: 80 additions & 34 deletions torchtitan/components/checkpoint.py
@@ -23,6 +23,7 @@
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
consolidate_safetensors_files_on_every_rank,
)
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
@@ -61,16 +62,41 @@ class AsyncMode(str, enum.Enum):
class ModelWrapper(Stateful):
def __init__(self, model: nn.Module | list[nn.Module]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict()

def _get_state_dict(self) -> dict[str, Any]:
state_dict = {
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
}
return state_dict

def state_dict_to_save(self) -> dict[str, Any]:
full_sd = self._get_state_dict()
keys_to_save: set[str] | None = None
for part in self.model:
if hasattr(part, "state_dict_to_save"):
if keys_to_save is None:
keys_to_save = set()
# pyrefly: ignore [not-callable]
keys_to_save.update(part.state_dict_to_save().keys())
if keys_to_save is None:
return full_sd
return {k: v for k, v in full_sd.items() if k in keys_to_save}

def state_dict_to_load(self) -> dict[str, Any]:
full_sd = self._get_state_dict()
keys_to_load: set[str] | None = None
for part in self.model:
if hasattr(part, "state_dict_to_load"):
if keys_to_load is None:
keys_to_load = set()
# pyrefly: ignore [not-callable]
keys_to_load.update(part.state_dict_to_load().keys())
if keys_to_load is None:
return full_sd
return {k: v for k, v in full_sd.items() if k in keys_to_load}

def state_dict(self) -> dict[str, Any]:
return self.cache_state_dict
return self.state_dict_to_save()

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
func = functools.partial(
@@ -79,9 +105,6 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
options=StateDictOptions(strict=False),
)
list(map(func, self.model))
# `set_model_state_dict()` does change the keys of the input state_dict,
# we will need to reinitialize the cache_state_dict.
self.cache_state_dict = self._get_state_dict()


class Terminate:
@@ -279,6 +302,7 @@ def load_state_dict(state_dict):
self.sd_adapter = sd_adapter
self.export_dtype = TORCH_DTYPE_MAP[checkpoint_config.export_dtype]
self.exclude_from_loading = checkpoint_config.exclude_from_loading
self.additional_load_paths = checkpoint_config.additional_load_paths
self.interval = checkpoint_config.interval
self.enable_first_step_checkpoint = (
checkpoint_config.enable_first_step_checkpoint
@@ -438,41 +462,54 @@ def dcp_save(
def dcp_load(
self,
state_dict: dict[str, Any],
checkpoint_id: str,
checkpoint_id: str | list[str],
from_hf: bool,
from_quantized: bool,
) -> None:
"""Load the checkpoint with dcp.
"""Load the checkpoint(s) with dcp.
Args:
state_dict (dict): The state dict to load.
checkpoint_id (str): The checkpoint id to load.
checkpoint_id (str | list[str]): The checkpoint id(s) to load.
When a list is provided, each checkpoint is loaded sequentially
using the same ``from_hf``/``from_quantized`` semantics.
from_hf (bool): Whether to load from HuggingFace checkpoint with
its own model definition and safetensors format.
from_quantized (bool): Whether the HuggingFace checkpoint is quantized.
"""
checkpoint_ids = (
[checkpoint_id] if isinstance(checkpoint_id, str) else checkpoint_id
)
planner = (
DefaultLoadPlanner(allow_partial_load=True)
if len(checkpoint_ids) > 1
else DefaultLoadPlanner()
)

if from_hf:
assert (
self.sd_adapter is not None
), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(state_dict)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
checkpoint_id, from_quantized
)

dcp.load(
hf_state_dict,
storage_reader=hf_storage_reader,
)
for cid in checkpoint_ids:
if from_hf:
assert (
self.sd_adapter is not None
), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(state_dict)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
cid, from_quantized
)

state_dict = self.sd_adapter.from_hf(hf_state_dict)
self.states[MODEL].load_state_dict(state_dict)
else:
dcp.load(state_dict, checkpoint_id=checkpoint_id)
dcp.load(
hf_state_dict,
storage_reader=hf_storage_reader,
planner=planner,
)

# TODO: Since we flatten the model states in state_dict, we need to
# manually call load_state_dict() for the model. Need to fix this.
if MODEL in self.states:
state_dict = self.sd_adapter.from_hf(hf_state_dict)
self.states[MODEL].load_state_dict(state_dict)
else:
dcp.load(state_dict, checkpoint_id=cid, planner=planner)

# TODO: Since we flatten the model states in state_dict, we need to
# manually call load_state_dict() for the model. Need to fix this.
if MODEL in self.states:
self.states[MODEL].load_state_dict(state_dict)

@torch.no_grad()
def save(self, curr_step: int, last_step: bool = False) -> None:
@@ -575,6 +612,12 @@ def load(self, step: int = -1) -> bool:
if not self.enable:
return False

for path in self.additional_load_paths:
if not os.path.isdir(path):
raise ValueError(
f"checkpoint.additional_load_paths contains invalid path: {path}"
)

model_only = False
from_hf = False
from_quantized = False
@@ -646,7 +689,7 @@ def load(self, step: int = -1) -> bool:
states = self._states_to_load(model_only)
self.dcp_load(
states,
checkpoint_id=checkpoint_id,
checkpoint_id=[checkpoint_id] + self.additional_load_paths,
from_hf=from_hf,
from_quantized=from_quantized,
)
@@ -735,7 +778,7 @@ def _ft_load(self) -> None:
)

def _flattened_model_states_sd(
self, state_dict: dict[str, Any] | None = None
self, state_dict: dict[str, Any] | None = None, for_load: bool = False
) -> dict[str, Any]:
"""Flatten the model states into a single dictionary.

@@ -744,7 +787,10 @@ def _flattened_model_states_sd(
states = state_dict if state_dict is not None else self.states
sd = {k: v for k, v in states.items() if k != MODEL}
if MODEL in states:
sd.update(states[MODEL].state_dict())
if for_load:
sd.update(states[MODEL].state_dict_to_load())
else:
sd.update(states[MODEL].state_dict_to_save())
return sd

def _states_to_load(self, model_only: bool) -> dict[str, Any]:
Expand All @@ -761,7 +807,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
"""
# For the first step, we will only load the model.
if model_only:
return self.states[MODEL].state_dict()
return self.states[MODEL].state_dict_to_load()

for exclude_key in self.exclude_from_loading:
if exclude_key not in self.states:
@@ -771,7 +817,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
k: v for k, v in self.states.items() if k not in self.exclude_from_loading
}

states_to_load = self._flattened_model_states_sd(states_to_load)
states_to_load = self._flattened_model_states_sd(states_to_load, for_load=True)

if self.enable_ft_dataloader_checkpoints:
states_to_load.pop(DATALOADER)
@@ -785,7 +831,7 @@ def _save_last_step(self, curr_step: int) -> None:
# is not the same as the export dtype at the end of the training.

if self.last_save_model_only:
states = self.states[MODEL].state_dict()
states = self.states[MODEL].state_dict_to_save()

if self.export_dtype != torch.float32:
states = {k: v.to(self.export_dtype) for k, v in states.items()}
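
For reference, a minimal sketch (not part of this PR) of how a model part might opt into the partial-checkpoint hooks that ModelWrapper now checks for; the LoRALinear class, its attribute names, and the key names are hypothetical and only illustrate the state_dict_to_save/state_dict_to_load protocol shown above.

import torch
import torch.nn as nn

from torchtitan.components.checkpoint import ModelWrapper


class LoRALinear(nn.Module):
    """Hypothetical module: a frozen base weight plus trainable LoRA factors."""

    def __init__(self, dim: int, rank: int = 4):
        super().__init__()
        self.base = nn.Parameter(torch.randn(dim, dim), requires_grad=False)
        self.lora_a = nn.Parameter(torch.randn(dim, rank))
        self.lora_b = nn.Parameter(torch.zeros(rank, dim))

    def state_dict_to_save(self):
        # Checkpoint only the adapter factors; the frozen base weight is omitted.
        return {"lora_a": self.lora_a, "lora_b": self.lora_b}

    def state_dict_to_load(self):
        # Load everything, e.g. base weights from the primary checkpoint and
        # adapter weights from an additional_load_paths entry.
        return self.state_dict()


wrapper = ModelWrapper(LoRALinear(dim=8))
# ModelWrapper.state_dict() now delegates to state_dict_to_save(), so only the
# adapter keys are written, while loads still cover the full state dict.
assert set(wrapper.state_dict().keys()) == {"lora_a", "lora_b"}
assert "base" in wrapper.state_dict_to_load()
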
8 changes: 8 additions & 0 deletions torchtitan/config/job_config.py
@@ -623,6 +623,14 @@ class Checkpoint:
This will load the model only, excluding the specified keys.
"""

additional_load_paths: list[str] = field(default_factory=list)
"""
Additional checkpoint paths to load from after the primary checkpoint.
Useful for loading state dicts from multiple sources, e.g., base model
weights from one checkpoint and LoRA adapter weights from another.
Each path should contain a valid DCP checkpoint directory.
"""

enable_first_step_checkpoint: bool = False
"""
Enable the checkpoint save at first step. This will save a checkpoint immediately
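
Finally, a hedged usage sketch of the new config field; the directory paths below are placeholders, and it assumes every field of the Checkpoint dataclass has a default value.

# Sketch only (not part of this PR).
from torchtitan.config.job_config import Checkpoint

cfg = Checkpoint()
cfg.additional_load_paths = [
    "/checkpoints/base_model",    # full base weights, a DCP checkpoint directory
    "/checkpoints/lora_adapter",  # adapter-only weights, another DCP directory
]

# With this config, CheckpointManager.load() first verifies that each entry is an
# existing directory, loads the primary checkpoint, and then loads every
# additional path sequentially using DefaultLoadPlanner(allow_partial_load=True).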