From 13779ce4105e89f08205bc67f4b444acfb28eec7 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 9 Feb 2026 13:03:09 -0800
Subject: [PATCH 1/2] Update

[ghstack-poisoned]
---
 tests/unit_tests/test_checkpoint.py     | 131 +++++++++++++++++++++-
 torchtitan/components/checkpoint.py     | 142 ++++++++++++++++++------
 torchtitan/config/job_config.py         |   8 ++
 torchtitan/protocols/model.py           |   9 ++
 torchtitan/protocols/model_converter.py |   5 +-
 5 files changed, 258 insertions(+), 37 deletions(-)

diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py
index b5dc3de8a2..c117ad2368 100644
--- a/tests/unit_tests/test_checkpoint.py
+++ b/tests/unit_tests/test_checkpoint.py
@@ -167,7 +167,7 @@ def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
             sd_to_save[key] = val
         torch.save(sd_to_save, os.path.join(checkpoint_id, "state_dict.pt"))
 
-    def fake_load(self, states: dict, checkpoint_id=None):
+    def fake_load(self, states: dict, checkpoint_id=None, **kwargs):
         path = os.path.join(checkpoint_id, "state_dict.pt")
         loaded = torch.load(path, weights_only="False")
         for key, val in loaded.items():
@@ -750,7 +750,7 @@ def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
             self.assertNotIn("optimizer", state_dict)
             return
 
-        def fake_load(state_dict: dict, checkpoint_id=None):
+        def fake_load(state_dict: dict, checkpoint_id=None, **kwargs):
             self.assertIn("bias", state_dict)
             self.assertIn("weight", state_dict)  # No model prefix
@@ -777,6 +777,133 @@ def fake_load(state_dict: dict, checkpoint_id=None):
         manager.save(curr_step=2, last_step=True)
         manager.load(step=1)
 
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_partial_state_dict_save(self, mock_save, mock_rank):
+        """Test that ModelWrapper.state_dict_to_save() filters keys correctly
+        when the model has a state_dict_to_save method."""
+
+        class PartialSaveModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.base_weight = nn.Parameter(torch.randn(2, 2))
+                self.adapter_weight = nn.Parameter(torch.randn(2, 2))
+
+            def state_dict_to_save(self):
+                # Only save the adapter weight
+                return {"adapter_weight": self.adapter_weight}
+
+        partial_model = PartialSaveModel()
+        mock_save.side_effect = self.fake_save
+
+        cfg = self.job_config.checkpoint
+        cfg.keep_latest_k = 0
+
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=[partial_model],
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            checkpoint_config=cfg,
+            sd_adapter=None,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
+        manager.save(curr_step=1)
+        self.assertEqual(mock_save.call_count, 1)
+        checkpoint_path = os.path.join(self.test_folder, "step-1", "state_dict.pt")
+        saved_data = torch.load(checkpoint_path, weights_only=False)
+
+        # Only adapter_weight should be saved
+        self.assertIn("adapter_weight", saved_data)
+        self.assertNotIn("base_weight", saved_data)
+        manager.close()
+
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    @mock.patch("torchtitan.components.checkpoint.dcp.load")
+    def test_additional_load_paths(self, mock_load, mock_save, mock_rank):
+        """Test that additional_load_paths loads from extra checkpoint directories."""
+        mock_save.side_effect = self.fake_save
+        mock_load.side_effect = self.fake_load
+
+        # Create an additional checkpoint directory with a saved state
+        additional_dir = os.path.join(self.base_temp_dir, "additional_ckpt")
+        os.makedirs(additional_dir, exist_ok=True)
+        torch.save(
+            {"weight": torch.ones(2, 2), "bias": torch.ones(2)},
+            os.path.join(additional_dir, "state_dict.pt"),
+        )
+
+        cfg = self.job_config.checkpoint
+        cfg.keep_latest_k = 0
+        cfg.additional_load_paths = [additional_dir]
+
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            checkpoint_config=cfg,
+            sd_adapter=None,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
+        # Save and then load - additional_load_paths should be loaded after main ckpt
+        manager.save(curr_step=1)
+        manager.load(step=1)
+
+        # dcp.load should be called twice: once for main checkpoint, once for additional
+        self.assertEqual(mock_load.call_count, 2)
+        # Verify the second call used the additional path
+        _, kwargs2 = mock_load.call_args_list[1]
+        self.assertEqual(kwargs2.get("checkpoint_id"), additional_dir)
+        manager.close()
+
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.load")
+    def test_additional_load_paths_invalid_path_raises(self, mock_load, mock_rank):
+        """Test that an invalid additional_load_paths raises ValueError."""
+        cfg = self.job_config.checkpoint
+        cfg.keep_latest_k = 0
+        cfg.additional_load_paths = ["/nonexistent/path"]
+
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            checkpoint_config=cfg,
+            sd_adapter=None,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
+        # Even without a main checkpoint, loading should try additional paths and fail
+        with self.assertRaises(ValueError):
+            manager.load(step=-1)
+        manager.close()
+
+    def test_model_wrapper_default_behavior(self):
+        """Test that ModelWrapper works correctly with plain nn.Module (no state_dict_to_save)."""
+        from torchtitan.components.checkpoint import ModelWrapper
+
+        model = nn.Linear(3, 3)
+        wrapper = ModelWrapper(model)
+
+        # For a plain nn.Module without state_dict_to_save, all keys should be included
+        sd_save = wrapper.state_dict_to_save()
+        sd_load = wrapper.state_dict_to_load()
+        sd_full = wrapper._get_state_dict()
+
+        self.assertEqual(set(sd_save.keys()), set(sd_full.keys()))
+        self.assertEqual(set(sd_load.keys()), set(sd_full.keys()))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 385535a566..51edafcbad 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -23,6 +23,7 @@
 from torch.distributed.checkpoint._consolidate_hf_safetensors import (
     consolidate_safetensors_files_on_every_rank,
 )
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -61,7 +62,6 @@ class AsyncMode(str, enum.Enum):
 class ModelWrapper(Stateful):
     def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
-        self.cache_state_dict = self._get_state_dict()
 
     def _get_state_dict(self) -> dict[str, Any]:
         state_dict = {
@@ -69,8 +69,34 @@ def _get_state_dict(self) -> dict[str, Any]:
         }
         return state_dict
+
+    def state_dict_to_save(self) -> dict[str, Any]:
+        full_sd = self._get_state_dict()
+        keys_to_save: set[str] | None = None
+        for part in self.model:
+            if hasattr(part, "state_dict_to_save"):
+                if keys_to_save is None:
+                    keys_to_save = set()
+                # pyrefly: ignore [not-callable]
+                keys_to_save.update(part.state_dict_to_save().keys())
+        if keys_to_save is None:
+            return full_sd
+        return {k: v for k, v in full_sd.items() if k in keys_to_save}
+
+    def state_dict_to_load(self) -> dict[str, Any]:
+        full_sd = self._get_state_dict()
+        keys_to_load: set[str] | None = None
+        for part in self.model:
+            if hasattr(part, "state_dict_to_load"):
+                if keys_to_load is None:
+                    keys_to_load = set()
+                # pyrefly: ignore [not-callable]
+                keys_to_load.update(part.state_dict_to_load().keys())
+        if keys_to_load is None:
+            return full_sd
+        return {k: v for k, v in full_sd.items() if k in keys_to_load}
+
     def state_dict(self) -> dict[str, Any]:
-        return self.cache_state_dict
+        return self.state_dict_to_save()
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         func = functools.partial(
@@ -79,9 +105,6 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
             options=StateDictOptions(strict=False),
         )
         list(map(func, self.model))
-        # `set_model_state_dict()` does change the keys of the input state_dict,
-        # we will need to reinitialize the cache_state_dict.
-        self.cache_state_dict = self._get_state_dict()
 
 
 class Terminate:
@@ -279,6 +302,7 @@ def load_state_dict(state_dict):
         self.sd_adapter = sd_adapter
         self.export_dtype = TORCH_DTYPE_MAP[checkpoint_config.export_dtype]
         self.exclude_from_loading = checkpoint_config.exclude_from_loading
+        self.additional_load_paths = checkpoint_config.additional_load_paths
         self.interval = checkpoint_config.interval
         self.enable_first_step_checkpoint = (
             checkpoint_config.enable_first_step_checkpoint
@@ -438,41 +462,54 @@ def dcp_save(
     def dcp_load(
         self,
         state_dict: dict[str, Any],
-        checkpoint_id: str,
+        checkpoint_id: str | list[str],
         from_hf: bool,
         from_quantized: bool,
     ) -> None:
-        """Load the checkpoint with dcp.
+        """Load the checkpoint(s) with dcp.
 
         Args:
            state_dict (dict): The state dict to load.
-           checkpoint_id (str): The checkpoint id to load.
+           checkpoint_id (str | list[str]): The checkpoint id(s) to load.
+               When a list is provided, each checkpoint is loaded sequentially
+               using the same ``from_hf``/``from_quantized`` semantics.
           from_hf (bool): Whether to load from HuggingFace checkpoint with
               its own model definition and safetensors format.
+          from_quantized (bool): Whether the HuggingFace checkpoint is quantized.
        """
+        checkpoint_ids = (
+            [checkpoint_id] if isinstance(checkpoint_id, str) else checkpoint_id
+        )
+        planner = (
+            DefaultLoadPlanner(allow_partial_load=True)
+            if len(checkpoint_ids) > 1
+            else DefaultLoadPlanner()
+        )
 
-        if from_hf:
-            assert (
-                self.sd_adapter is not None
-            ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
-            hf_state_dict = self.sd_adapter.to_hf(state_dict)
-            hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
-                checkpoint_id, from_quantized
-            )
-
-            dcp.load(
-                hf_state_dict,
-                storage_reader=hf_storage_reader,
-            )
+        for cid in checkpoint_ids:
+            if from_hf:
+                assert (
+                    self.sd_adapter is not None
+                ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
+                hf_state_dict = self.sd_adapter.to_hf(state_dict)
+                hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
+                    cid, from_quantized
+                )
 
-            state_dict = self.sd_adapter.from_hf(hf_state_dict)
-            self.states[MODEL].load_state_dict(state_dict)
-        else:
-            dcp.load(state_dict, checkpoint_id=checkpoint_id)
+                dcp.load(
+                    hf_state_dict,
+                    storage_reader=hf_storage_reader,
+                    planner=planner,
+                )
 
-        # TODO: Since we flatten the model states in state_dict, we need to
-        # manually call load_state_dict() for the model. Need to fix this.
-        if MODEL in self.states:
+                state_dict = self.sd_adapter.from_hf(hf_state_dict)
+                self.states[MODEL].load_state_dict(state_dict)
+            else:
+                dcp.load(state_dict, checkpoint_id=cid, planner=planner)
+
+                # TODO: Since we flatten the model states in state_dict, we need to
+                # manually call load_state_dict() for the model. Need to fix this.
+                if MODEL in self.states:
+                    self.states[MODEL].load_state_dict(state_dict)
 
     @torch.no_grad()
     def save(self, curr_step: int, last_step: bool = False) -> None:
@@ -575,6 +612,12 @@ def load(self, step: int = -1) -> bool:
         if not self.enable:
             return False
 
+        for path in self.additional_load_paths:
+            if not os.path.isdir(path):
+                raise ValueError(
+                    f"checkpoint.additional_load_paths contains invalid path: {path}"
+                )
+
         model_only = False
         from_hf = False
         from_quantized = False
@@ -618,6 +661,17 @@ def load(self, step: int = -1) -> bool:
                     f"loading HF safetensors from --model.hf_assets_path: {hf_assets_path}"
                 )
             else:
+                if self.additional_load_paths:
+                    additional_states = self.states[MODEL]._get_state_dict()
+                    self.dcp_load(
+                        additional_states,
+                        checkpoint_id=self.additional_load_paths,
+                        from_hf=False,
+                        from_quantized=False,
+                    )
+                    GarbageCollection.collect(
+                        "GC collection for additional checkpoint loading."
+                    )
                 return False
         else:
             if self.initial_load_path:
@@ -632,6 +686,17 @@ def load(self, step: int = -1) -> bool:
                 )
                 step = self._find_load_step() if step == -1 else step
                 if step == -1:
+                    if self.additional_load_paths:
+                        additional_states = self.states[MODEL]._get_state_dict()
+                        self.dcp_load(
+                            additional_states,
+                            checkpoint_id=self.additional_load_paths,
+                            from_hf=False,
+                            from_quantized=False,
+                        )
+                        GarbageCollection.collect(
+                            "GC collection for additional checkpoint loading."
+                        )
                     return False
                 model_only = step == 0
             checkpoint_id = self._create_checkpoint_id(step)
@@ -650,6 +715,14 @@ def load(self, step: int = -1) -> bool:
             from_hf=from_hf,
             from_quantized=from_quantized,
         )
+        if self.additional_load_paths:
+            additional_states = self.states[MODEL]._get_state_dict()
+            self.dcp_load(
+                additional_states,
+                checkpoint_id=self.additional_load_paths,
+                from_hf=False,
+                from_quantized=False,
+            )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -735,7 +808,7 @@ def _ft_load(self) -> None:
         )
 
     def _flattened_model_states_sd(
-        self, state_dict: dict[str, Any] | None = None
+        self, state_dict: dict[str, Any] | None = None, for_load: bool = False
     ) -> dict[str, Any]:
         """Flatten the model states into a single dictionary.
@@ -744,7 +817,10 @@ def _flattened_model_states_sd(
         states = state_dict if state_dict is not None else self.states
         sd = {k: v for k, v in states.items() if k != MODEL}
         if MODEL in states:
-            sd.update(states[MODEL].state_dict())
+            if for_load:
+                sd.update(states[MODEL].state_dict_to_load())
+            else:
+                sd.update(states[MODEL].state_dict_to_save())
         return sd
 
     def _states_to_load(self, model_only: bool) -> dict[str, Any]:
@@ -761,7 +837,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model.
         if model_only:
-            return self.states[MODEL].state_dict()
+            return self.states[MODEL].state_dict_to_load()
 
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in self.states:
@@ -771,7 +847,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
             k: v for k, v in self.states.items() if k not in self.exclude_from_loading
         }
 
-        states_to_load = self._flattened_model_states_sd(states_to_load)
+        states_to_load = self._flattened_model_states_sd(states_to_load, for_load=True)
 
         if self.enable_ft_dataloader_checkpoints:
             states_to_load.pop(DATALOADER)
@@ -785,7 +861,7 @@ def _save_last_step(self, curr_step: int) -> None:
         # is not the same as the export dtype at the end of the training.
 
         if self.last_save_model_only:
-            states = self.states[MODEL].state_dict()
+            states = self.states[MODEL].state_dict_to_save()
 
             if self.export_dtype != torch.float32:
                 states = {k: v.to(self.export_dtype) for k, v in states.items()}
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index cdf41df293..9d217e2c30 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -623,6 +623,14 @@ class Checkpoint:
     This will load the model only, excluding the specified keys.
     """
 
+    additional_load_paths: list[str] = field(default_factory=list)
+    """
+    Additional checkpoint paths to load from after the primary checkpoint.
+    Useful for loading state dicts from multiple sources, e.g., base model
+    weights from one checkpoint and LoRA adapter weights from another.
+    Each path should contain a valid DCP checkpoint directory.
+    """
+
     enable_first_step_checkpoint: bool = False
     """
     Enable the checkpoint save at first step. This will save a checkpoint immediately
diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py
index 99e4c34dc0..09d817f34e 100644
--- a/torchtitan/protocols/model.py
+++ b/torchtitan/protocols/model.py
@@ -6,6 +6,7 @@
 
 from abc import abstractmethod
 from dataclasses import dataclass
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -71,3 +72,11 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def state_dict_to_save(self) -> dict[str, Any]:
+        """Return the state dict subset to save. Override to save partial state (e.g. LoRA)."""
+        return self.state_dict()
+
+    def state_dict_to_load(self) -> dict[str, Any]:
+        """Return the state dict buffer for loading. Override to load partial state."""
+        return self.state_dict()
diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py
index dbfc3a99c3..28879206c6 100644
--- a/torchtitan/protocols/model_converter.py
+++ b/torchtitan/protocols/model_converter.py
@@ -9,6 +9,7 @@
 
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.protocols.model import ModelProtocol
 from torchtitan.tools.logging import logger
 
 
@@ -24,7 +25,7 @@ class ModelConverter(Protocol):
 
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ...
 
-    def convert(self, model: nn.Module):
+    def convert(self, model: ModelProtocol):
         """Inplace conversion of the model."""
         ...
 
@@ -66,7 +67,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ]
         self.print_after_conversion = job_config.model.print_after_conversion
 
-    def convert(self, model: nn.Module):
+    def convert(self, model: ModelProtocol):
         for mh in self.converters:
             mh.convert(model)
         if self.print_after_conversion:

From b3829ce6ceee446ee5dca16abbfbf44389eee458 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 9 Feb 2026 13:19:59 -0800
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 torchtitan/components/checkpoint.py | 32 +----------------------------
 1 file changed, 1 insertion(+), 31 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 51edafcbad..a84950250a 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -661,17 +661,6 @@ def load(self, step: int = -1) -> bool:
                     f"loading HF safetensors from --model.hf_assets_path: {hf_assets_path}"
                 )
             else:
-                if self.additional_load_paths:
-                    additional_states = self.states[MODEL]._get_state_dict()
-                    self.dcp_load(
-                        additional_states,
-                        checkpoint_id=self.additional_load_paths,
-                        from_hf=False,
-                        from_quantized=False,
-                    )
-                    GarbageCollection.collect(
-                        "GC collection for additional checkpoint loading."
-                    )
                 return False
         else:
             if self.initial_load_path:
@@ -686,17 +675,6 @@ def load(self, step: int = -1) -> bool:
                 )
                 step = self._find_load_step() if step == -1 else step
                 if step == -1:
-                    if self.additional_load_paths:
-                        additional_states = self.states[MODEL]._get_state_dict()
-                        self.dcp_load(
-                            additional_states,
-                            checkpoint_id=self.additional_load_paths,
-                            from_hf=False,
-                            from_quantized=False,
-                        )
-                        GarbageCollection.collect(
-                            "GC collection for additional checkpoint loading."
-                        )
                     return False
                 model_only = step == 0
             checkpoint_id = self._create_checkpoint_id(step)
@@ -711,18 +689,10 @@ def load(self, step: int = -1) -> bool:
         states = self._states_to_load(model_only)
         self.dcp_load(
             states,
-            checkpoint_id=checkpoint_id,
+            checkpoint_id=[checkpoint_id] + self.additional_load_paths,
             from_hf=from_hf,
             from_quantized=from_quantized,
         )
-        if self.additional_load_paths:
-            additional_states = self.states[MODEL]._get_state_dict()
-            self.dcp_load(
-                additional_states,
-                checkpoint_id=self.additional_load_paths,
-                from_hf=False,
-                from_quantized=False,
-            )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
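
For reference, a minimal sketch (not part of the patch) of how a model might opt into
the partial save/load hooks introduced above, in the spirit of the LoRA example from
the additional_load_paths docstring. LoRALinear, its rank argument, and the
"lora_" key-filtering rule are hypothetical illustrations, not code from this PR:

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical module: frozen base projection plus a low-rank adapter."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4) -> None:
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)  # only the adapter trains
        self.lora_a = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_b = nn.Parameter(torch.randn(out_features, rank) * 0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # base output plus the low-rank correction x @ A^T @ B^T
        return self.base(x) + x @ self.lora_a.T @ self.lora_b.T

    def state_dict_to_save(self) -> dict:
        # Consumed by ModelWrapper.state_dict_to_save(): only keys returned
        # here are written at checkpoint time, so base weights are never saved.
        return {k: v for k, v in self.state_dict().items() if "lora_" in k}

    def state_dict_to_load(self) -> dict:
        # Load everything: base weights from the primary checkpoint, adapter
        # weights from an entry in checkpoint.additional_load_paths.
        return self.state_dict()


if __name__ == "__main__":
    m = LoRALinear(8, 8)
    print(sorted(m.state_dict_to_save()))  # ['lora_a', 'lora_b']

With such a model, saving writes only the adapter keys, while loading can pull base
weights from checkpoint.initial_load_path and adapter weights from an entry in
checkpoint.additional_load_paths; whenever more than one checkpoint id reaches
dcp_load(), it switches to DefaultLoadPlanner(allow_partial_load=True) so each
source may cover only a subset of the requested keys.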