Commit c92e8a8

Reverse transform of state_dict per submodule (#1011)
1 parent eed91fb commit c92e8a8

File tree: 5 files changed, +152 −40 lines

thunder/core/module.py
thunder/core/transform_common.py
thunder/distributed/__init__.py
thunder/distributed/transforms/fsdp_v2.py
thunder/tests/distributed/test_fsdp.py

thunder/core/module.py

Lines changed: 57 additions & 11 deletions
@@ -1,7 +1,7 @@
+from __future__ import annotations
 from contextlib import contextmanager
 import itertools
-from typing import Any
-from collections.abc import Mapping
+from typing import TYPE_CHECKING
 import collections

 import torch as pytorch
@@ -10,6 +10,22 @@
 import thunder
 from thunder.core.compile_data import get_compile_data

+if TYPE_CHECKING:
+    from collections.abc import Mapping
+    from typing import Any
+    from thunder.core.transform_common import Transform
+
+
+def _convert_state_dict_to_per_module(state_dict: dict[str, Any], module_names: set[str]) -> dict[str, dict[str, Any]]:
+    state_dict_per_module = collections.defaultdict(dict)
+    for k, v in state_dict.items():
+        prefix, sep, _ = k.rpartition(".")
+        # not great but should not happen too often / deep
+        while prefix not in module_names:
+            prefix, sep, _ = prefix.rpartition(".")
+        state_dict_per_module[prefix][k[len(prefix) + len(sep) :]] = v
+    return state_dict_per_module
+

 class ThunderModule(pytorch.nn.Module):
     """A wrapper nn.Module subclass.
@@ -120,14 +136,7 @@ def _get_shared_names(self):

     def load_original_state_dict(self, state_dict):
         # this loads the state dict incrementally to not exhaust memory
-        module_names = {n for n, _ in self._model.named_modules()}
-        sd_per_module = collections.defaultdict(dict)
-        for k, v in state_dict.items():
-            prefix, sep, _ = k.rpartition(".")
-            # not great but should not happen too often / deep
-            while prefix not in module_names:
-                prefix, sep, _ = prefix.rpartition(".")
-            sd_per_module[prefix][k[len(prefix) + len(sep) :]] = v
+        sd_per_module = _convert_state_dict_to_per_module(state_dict, {n for n, _ in self._model.named_modules()})

         shared_names = self._get_shared_names()
         processed_names = set()
@@ -182,7 +191,7 @@ def state_dict(self, *, destination=None, prefix="", keep_vars=False):
         Returns the state dict of the (transformed) Thunder module.

         Args:
-            destination: if given, use this mutuable mapping as the dict container.
+            destination: if given, use this mutable mapping as the dict container.
             prefix: a prefix for the keys.
             keep_vars: do not detach

@@ -215,6 +224,43 @@ def state_dict(self, *, destination=None, prefix="", keep_vars=False):
             destination[extra_state_key] = self.get_extra_state()
         return destination

+    def original_state_dict(
+        self,
+        *,
+        destination: dict[str, Any] | None = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+    ) -> dict[str, Any]:
+        """Returns the state dict of the transformed :class:`ThunderModule` with reverse transform applied.
+
+        For example, :func:`ThunderModule.state_dict` returns a state dict of sharded tensors if
+        a model is :func:`thunder.distributed.fsdp` applied while :func:`ThunderModule.original_state_dict`
+        returns a state dict of unsharded tensors.
+
+        Args:
+            destination: if given, use this mutable mapping as the dict container.
+            prefix: a prefix for the keys.
+            keep_vars: do not detach
+
+        """
+        module_names = {name for name, _ in self._model.named_modules()}
+        state_dict_per_submodule = _convert_state_dict_to_per_module(self.state_dict(), module_names)
+
+        if destination is None:
+            destination = collections.OrderedDict()
+            destination._metadata = collections.OrderedDict()
+
+        transform: Transform
+        for submodule_name, submodule_state_dict in state_dict_per_submodule.items():
+            for transform in reversed(self._lc_transforms):
+                submodule_state_dict = transform.reverse_transform_state_dict_for_submodule(
+                    self,
+                    submodule_name,
+                    submodule_state_dict,
+                )
+            destination.update({f"{prefix}{submodule_name}.{k}": v for k, v in submodule_state_dict.items()})
+        return destination
+
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
         """Loads the state dict to a transformed module.

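The new ThunderModule.original_state_dict splits the transformed module's state dict per submodule and runs each registered transform's reverse_transform_state_dict_for_submodule hook in reverse order, so sharded tensors come back unsharded. A minimal usage sketch, assuming a multi-rank torchrun launch with CUDA devices; the ToyModel below is an illustrative stand-in, not part of this commit:

# Hedged sketch, not part of the commit: recover an unsharded checkpoint from an
# fsdp-wrapped thunder module. Assumes `torchrun --nproc-per-node=2 ...` and CUDA.
import os

import torch
import torch.distributed as dist

import thunder
from thunder.distributed import fsdp


class ToyModel(torch.nn.Module):  # illustrative stand-in for the test's ToyModel
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.net(x)


dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)

with torch.device("cuda"):
    model = ToyModel()

jitted = fsdp(thunder.jit(model), device=device, move_state_dict_to_cpu=True)

sharded = jitted.state_dict()        # each value holds only this rank's shard
full = jitted.original_state_dict()  # the reverse transform all-gathers the shards back
for key, tensor in full.items():
    # the unsharded dim-0 is larger than the shard's; tensors live on CPU here
    print(key, tuple(sharded[key].shape), "->", tuple(tensor.shape), tensor.device)

With move_state_dict_to_cpu=True, each gathered tensor is offloaded as soon as its all-gather finishes, which keeps the extra GPU memory to roughly one unsharded parameter at a time.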

thunder/core/transform_common.py

Lines changed: 19 additions & 10 deletions
@@ -1,11 +1,9 @@
 from __future__ import annotations
 import time
 from typing import TYPE_CHECKING
-from abc import ABC, abstractmethod
-from collections import defaultdict
+from abc import ABC
 from collections.abc import Sequence
-from collections import defaultdict
-from itertools import filterfalse, chain
+from itertools import filterfalse
 from functools import partial

 import thunder
@@ -15,11 +13,11 @@
 from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten
 from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
 from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace, tracectx
-from thunder.core.utils import ProxyDict, producers, check, consumers
+from thunder.core.utils import ProxyDict, producers, check

 if TYPE_CHECKING:
-    from thunder.core.proxies import ProxyInterface
-    from thunder.core.symbol import Symbol, VariableInterface
+    from typing import Any
+    from thunder.core.module import ThunderModule


 #
@@ -346,13 +344,16 @@ def transform_traces_pre_prologue(
         # default to noop
         return prologue_trace, computation_trace, epilogue_trace

-    def transform_module(self, model: thunder.ThunderModule):
+    def transform_module(self, model: ThunderModule) -> None:
         """Transforms the ThunderModule. This is executed once on application of the transform"""
         pass

     def transform_state_dict_for_submodule(
-        self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
-    ) -> dict:
+        self,
+        model: ThunderModule,
+        submodule_name: str,
+        state_dict: dict[str, Any],
+    ) -> dict[str, Any]:
         """
         Implement this to transform the state dict (mostly parameters and buffers) of a module, e.g. when loading
         from a state dict of the original model.
@@ -370,6 +371,14 @@ def transform_trace_post_optimization(self, computation_trace: Trace, **kwargs):
         """
         return computation_trace

+    def reverse_transform_state_dict_for_submodule(
+        self,
+        model: ThunderModule,
+        submodule_name: str,
+        state_dict: dict[str, Any],
+    ) -> dict[str, Any]:
+        return state_dict
+

 def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]:
     """computes a canonical ordering of proxies in the bound symbols based on the order of appearance

thunder/distributed/__init__.py

Lines changed: 14 additions & 5 deletions
@@ -4,7 +4,6 @@
 from itertools import chain
 from contextlib import contextmanager
 from contextvars import ContextVar, Token
-import copy
 from enum import auto, Enum
 from typing import TYPE_CHECKING, Any
 from collections.abc import Generator
@@ -423,13 +422,14 @@ def f(tensor: TensorProxy) -> str:


 def fsdp(
-    model: torch.nn.Module,
+    model: torch.nn.Module | ThunderModule,
     *,
     device: torch.device | None = None,
     broadcast_from: int | None = None,
     sharding_strategy: FSDPType = FSDPType.ZERO2,
     bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
-) -> torch.nn.Module:
+    move_state_dict_to_cpu: bool | None = None,
+) -> torch.nn.Module | ThunderModule:
     """Convert ``model`` into Fully Sharded Data Parallel.

     This splits ``model``'s parameters in their first dimension into ``world_size`` chunks
@@ -458,20 +458,22 @@
             from a checkpoint in a single rank.
         sharding_strategy:
         bucketing_strategy:
+        move_state_dict_to_cpu: Move all-gather'ed parameters of :func:`~thunder.core.module.ThunderModule.original_state_dict` to CPU
+            as each all-gather is finished.

     Returns:
         :class:`torch.nn.Module`

     """
-    import thunder
+    from thunder.core.module import ThunderModule

     utils.check(isinstance(sharding_strategy, FSDPType), lambda: f"FSDPType.ZERO2 and FSDPType.ZERO3 are supported.")
     utils.check(
         tdist.is_available(),
         lambda: "fsdp requires torch distributed to be available (but it's not)",
     )

-    if isinstance(model, thunder.ThunderModule):
+    if isinstance(model, ThunderModule):
         from thunder.core.transforms import add_transform
         from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
         from thunder.transforms import MaterializationTransform
@@ -488,11 +490,18 @@
                     sharding_strategy=sharding_strategy,
                     bucketing_strategy=bucketing_strategy,
                     release_original_parameters=True,
+                    move_state_dict_to_cpu=False if move_state_dict_to_cpu is None else move_state_dict_to_cpu,
                 ),
                 MaterializationTransform(device, init=MaterializationTransform.init_from_original_module_init()),
             ],
         )

+    if move_state_dict_to_cpu is not None:
+        import warnings
+
+        warnings.warn(
+            "`move_state_dict_to_cpu` is only effective when `model` is `ThunderModule`, i.e., `thunder.jit(model)`"
+        )
     process_group = copy_default_process_group()
     utils.check(process_group is not None, lambda: "The default process group is None")
     model.use_fsdp = True
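move_state_dict_to_cpu is deliberately tri-state: None means "not requested", the flag is forwarded to FSDPTransform only on the thunder.jit path, and on a plain nn.Module it merely triggers the warning above. A sketch of the two call patterns, assuming an initialized default process group, a set CUDA device, and arbitrary Linear sizes:

# Hedged sketch of the two call patterns for the new flag. Assumes torch.distributed
# is already initialized (e.g. under torchrun) and the current CUDA device is set.
import warnings

import torch

import thunder
from thunder.distributed import fsdp

# Effective: the flag is forwarded to FSDPTransform on the thunder.jit path.
jitted = fsdp(thunder.jit(torch.nn.Linear(8, 8, device="cuda")), move_state_dict_to_cpu=True)

# Ineffective: on a plain nn.Module the CPU offload cannot apply, so fsdp only warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fsdp(torch.nn.Linear(8, 8, device="cuda"), move_state_dict_to_cpu=True)
print([str(w.message) for w in caught])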

thunder/distributed/transforms/fsdp_v2.py

Lines changed: 40 additions & 14 deletions
@@ -1,7 +1,6 @@
 """Transform for `fsdp(jit(model))` to convert a model to use fsdp."""

 from __future__ import annotations
-import copy
 from dataclasses import dataclass
 from dataclasses import field
 from itertools import chain
@@ -27,14 +26,13 @@
     copy_default_process_group,
     FSDPType,
     FSDPBucketingStrategy,
-    _materialize,
-    _shard_param,
     _shard_tensor,
 )

 if TYPE_CHECKING:
     from typing import Any
     from torch.distributed import ProcessGroup
+    from thunder.core.module import ThunderModule
     from thunder.core.symbol import BoundSymbol
     from thunder.core.trace import VariableInterface

@@ -71,7 +69,7 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
         return VISIT_TYPE.REPLACE


-# When the user calls fsdp(jitted_module), or applies this Transform direcly, it does the following
+# When the user calls fsdp(jitted_module), or applies this Transform directly, it does the following
 # - It transforms the ThunderModule jitted_module, sharding the parameters as `overrides`
 #   in the ThunderModule.
 # - While doing that, it leaves the original user module alone, except when
@@ -86,21 +84,16 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
 #
 # The thunder.distributed.fsdp function calls FSDPTransform followed by MaterializationTransform, the latter does
 # the materialization of submodules previously on the meta device.
-
-
 class FSDPTransform(Transform):
-    sharded_params: dict[str, Any]
-    process_group: ProcessGroup
-    shared_params_name: dict[str, str]
-
     def __init__(
         self,
         device: torch.device | None = None,
         broadcast_from: int | None = None,
         sharding_strategy: FSDPType = FSDPType.ZERO2,
         bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
         release_original_parameters: bool = False,
-    ):
+        move_state_dict_to_cpu: bool = False,
+    ) -> None:
         self.device = device
         self.broadcast_from = broadcast_from
         self.sharding_strategy = sharding_strategy
@@ -109,13 +102,13 @@ def __init__(
         self.sharded_params: dict[str, Any] = {}
         self.process_group: ProcessGroup | None = None
         self.shared_params_name: dict[str, str] = {}
+        self.move_state_dict_to_cpu = move_state_dict_to_cpu

     def transform_module(
         self,
         thunder_model: ThunderModule,
     ):
         from thunder import compile_data as get_compile_data
-        from thunder.core.transforms import add_transform
         from thunder.core.module import ThunderModule

         self.process_group = copy_default_process_group()
@@ -250,8 +243,11 @@ def transform_module(
                 p_orig._thunder_device = self.device

     def transform_state_dict_for_submodule(
-        self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
-    ) -> dict:
+        self,
+        model: ThunderModule,
+        submodule_name: str,
+        state_dict: dict[str, Any],
+    ) -> dict[str, Any]:
         prefix = ""
         if submodule_name:
             prefix = f"{submodule_name}."
@@ -263,6 +259,36 @@ def transform_state_dict_for_submodule(
             new_state_dict[k] = v
         return new_state_dict

+    def reverse_transform_state_dict_for_submodule(
+        self,
+        model: ThunderModule,
+        submodule_name: str,
+        state_dict: dict[str, Any],
+    ) -> dict[str, Any]:
+        from thunder.executors.torchex import _all_gather_prim_impl
+
+        for name, tensor in state_dict.items():
+            fqn: str
+            if submodule_name:
+                fqn = f"{submodule_name}.{name}"
+            else:
+                fqn = name
+
+            if fqn not in self.sharded_params:
+                continue
+
+            old_shape, *_ = self.sharded_params[fqn]
+            unsharded_tensor = _all_gather_prim_impl(
+                tensor,
+                group=self.process_group,
+                do_async=False,
+                dim=None,
+            ).narrow(0, 0, old_shape[0])
+            if self.move_state_dict_to_cpu:
+                unsharded_tensor = unsharded_tensor.cpu()
+            state_dict[name] = unsharded_tensor
+        return state_dict
+
     def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
         from thunder.distributed import prims as dist_prims

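Conceptually, FSDPTransform.reverse_transform_state_dict_for_submodule re-assembles each sharded parameter: all-gather the per-rank shards along dim 0, narrow the result back to the recorded original size to drop sharding padding, and optionally offload to CPU; entries not in self.sharded_params are left untouched. A rough equivalent sketched with plain torch.distributed instead of thunder's internal _all_gather_prim_impl; the unshard helper is illustrative, not thunder API:

# Conceptual sketch using plain torch.distributed rather than _all_gather_prim_impl:
# gather the per-rank shards along dim 0, trim the padding added at sharding time
# back to the recorded original size, and optionally offload to CPU.
import torch
import torch.distributed as dist


def unshard(shard: torch.Tensor, unsharded_dim0: int, group=None, to_cpu: bool = False) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    gathered = torch.empty(
        world_size * shard.shape[0], *shard.shape[1:], dtype=shard.dtype, device=shard.device
    )
    dist.all_gather_into_tensor(gathered, shard.contiguous(), group=group)
    # mirrors .narrow(0, 0, old_shape[0]) in the transform: drop the padding rows
    full = gathered.narrow(0, 0, unsharded_dim0)
    return full.cpu() if to_cpu else full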

thunder/tests/distributed/test_fsdp.py

Lines changed: 22 additions & 0 deletions
@@ -685,6 +685,28 @@ def test_load_original_state_dict(self):

         torch.testing.assert_close(y_1, y_2)

+    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
+    def test_original_state_dict(self):
+        device = torch.device("cuda", self.rank)
+
+        for move_state_dict_to_cpu in (False, True):
+            with torch.device("cuda"):
+                model = ToyModel()
+
+            init_state_dict = model.state_dict()
+            jitted = fsdp(thunder.jit(model), device=device, move_state_dict_to_cpu=move_state_dict_to_cpu)
+
+            sharded_state_dict = jitted.state_dict()
+            original_state_dict = jitted.original_state_dict()
+            for key, unsharded in original_state_dict.items():
+                self.assertTrue(key in init_state_dict and key in sharded_state_dict)
+                self.assertEqual(len(init_state_dict[key]), len(unsharded))
+                self.assertGreater(len(unsharded), len(sharded_state_dict[key]))
+                if move_state_dict_to_cpu:
+                    self.assertEqual(unsharded.device, torch.device("cpu"))
+                else:
+                    self.assertEqual(unsharded.device, device)
+

 common_utils.instantiate_parametrized_tests(FSDPTest)


0 commit comments
