Reverse transform of state_dict per submodule #1011

Merged · 6 commits · Aug 21, 2024
70 changes: 59 additions & 11 deletions thunder/core/module.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import contextmanager
import itertools
from typing import Any
from collections.abc import Mapping
from typing import TYPE_CHECKING
import collections

import torch as pytorch
@@ -10,6 +10,22 @@
import thunder
from thunder.core.compile_data import get_compile_data

if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any
from thunder.core.transform_common import Transform


def _convert_state_dict_to_per_module(state_dict: dict[str, Any], module_names: set[str]) -> dict[str, dict[str, Any]]:
state_dict_per_module = collections.defaultdict(dict)
for k, v in state_dict.items():
prefix, sep, _ = k.rpartition(".")
# not great but should not happen too often / deep
while prefix not in module_names:
prefix, sep, _ = prefix.rpartition(".")
state_dict_per_module[prefix][k[len(prefix) + len(sep) :]] = v
return state_dict_per_module


class ThunderModule(pytorch.nn.Module):
"""A wrapper nn.Module subclass.
@@ -120,14 +136,7 @@ def _get_shared_names(self):

def load_original_state_dict(self, state_dict):
# this loads the state dict incrementally to not exhaust memory
module_names = {n for n, _ in self._model.named_modules()}
sd_per_module = collections.defaultdict(dict)
for k, v in state_dict.items():
prefix, sep, _ = k.rpartition(".")
# not great but should not happen too often / deep
while prefix not in module_names:
prefix, sep, _ = prefix.rpartition(".")
sd_per_module[prefix][k[len(prefix) + len(sep) :]] = v
sd_per_module = _convert_state_dict_to_per_module(state_dict, {n for n, _ in self._model.named_modules()})

shared_names = self._get_shared_names()
processed_names = set()
@@ -182,7 +191,7 @@ def state_dict(self, *, destination=None, prefix="", keep_vars=False):
Returns the state dict of the (transformed) Thunder module.

Args:
destination: if given, use this mutuable mapping as the dict container.
destination: if given, use this mutable mapping as the dict container.
prefix: a prefix for the keys.
keep_vars: do not detach

@@ -215,6 +224,45 @@ def state_dict(self, *, destination=None, prefix="", keep_vars=False):
destination[extra_state_key] = self.get_extra_state()
return destination

def original_state_dict(
self,
*,
destination: dict[str, Any] | None = None,
prefix: str = "",
keep_vars: bool = False,
) -> dict[str, Any]:
"""Returns the state dict of the transformed :class:`ThunderModule` with reverse transform applied.

For example, :func:`ThunderModule.state_dict` returns a state dict of sharded tensors if
a model is :func:`thunder.distributed.fsdp` applied while :func:`ThunderModule.original_state_dict`
returns a state dict of unsharded tensors.

Args:
destination: if given, use this mutable mapping as the dict container.
prefix: a prefix for the keys.
keep_vars: do not detach

"""
module_names = {name for name, _ in self._model.named_modules()}
state_dict = self.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
state_dict_per_submodule = _convert_state_dict_to_per_module(state_dict, module_names)

_state_dict_keys = list(state_dict.keys())
cur_idx: int = 0
transform: Transform
for submodule_name, submodule_state_dict in state_dict_per_submodule.items():
for transform in reversed(self._lc_transforms):
submodule_state_dict = transform.reverse_transform_state_dict_for_submodule(
self,
submodule_name,
submodule_state_dict,
)
state_dict_per_submodule[submodule_name] = submodule_state_dict
for v in submodule_state_dict.values():
state_dict[_state_dict_keys[cur_idx]] = v
cur_idx += 1
return state_dict

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
"""Loads the state dict to a transformed module.

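For context, a minimal usage sketch of the new original_state_dict API — assuming a 2-rank torchrun launch with the default process group already initialized, and an illustrative Linear module (names and shapes are illustrative only):

import torch
import thunder
from thunder.distributed import fsdp

# assumes torch.distributed is already initialized, e.g. via torchrun --nproc-per-node=2
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank)
with torch.device(device):
    model = torch.nn.Linear(16, 16)

jitted = fsdp(thunder.jit(model), device=device, move_state_dict_to_cpu=True)

sharded = jitted.state_dict()        # transformed state dict: parameters sharded along dim 0
full = jitted.original_state_dict()  # reverse transform all-gathers the unsharded tensors

# e.g. with world_size == 2: sharded["weight"].shape == (8, 16), full["weight"].shape == (16, 16),
# and the gathered tensors live on CPU because move_state_dict_to_cpu=True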
29 changes: 19 additions & 10 deletions thunder/core/transform_common.py
@@ -1,11 +1,9 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
from collections import defaultdict
from abc import ABC
from collections.abc import Sequence
from collections import defaultdict
from itertools import filterfalse, chain
from itertools import filterfalse
from functools import partial

import thunder
@@ -15,11 +13,11 @@
from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace, tracectx
from thunder.core.utils import ProxyDict, producers, check, consumers
from thunder.core.utils import ProxyDict, producers, check

if TYPE_CHECKING:
from thunder.core.proxies import ProxyInterface
from thunder.core.symbol import Symbol, VariableInterface
from typing import Any
from thunder.core.module import ThunderModule


#
@@ -346,13 +344,16 @@ def transform_traces_pre_prologue(
# default to noop
return prologue_trace, computation_trace, epilogue_trace

def transform_module(self, model: thunder.ThunderModule):
def transform_module(self, model: ThunderModule) -> None:
"""Transforms the ThunderModule. This is executed once on application of the transform"""
pass

def transform_state_dict_for_submodule(
self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
) -> dict:
self,
model: ThunderModule,
submodule_name: str,
state_dict: dict[str, Any],
) -> dict[str, Any]:
"""
Implement this to transform the state dict (mostly parameters and buffers) of a module, e.g. when loading
from a state dict of the original model.
@@ -370,6 +371,14 @@ def transform_trace_post_optimization(self, computation_trace: Trace, **kwargs):
"""
return computation_trace

def reverse_transform_state_dict_for_submodule(
self,
model: ThunderModule,
submodule_name: str,
state_dict: dict[str, Any],
) -> dict[str, Any]:
return state_dict


def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]:
"""computes a canonical ordering of proxies in the bound symbols based on the order of appearance
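To illustrate the contract of the new hook — reverse_transform_state_dict_for_submodule is expected to undo whatever transform_state_dict_for_submodule did, so that ThunderModule.original_state_dict() matches the source model — here is a hypothetical Transform subclass; it is a sketch, not a transform that exists in the repo:

from typing import Any

import torch
from thunder.core.transform_common import Transform


class HalfPrecisionStateDict(Transform):
    # hypothetical example: the transformed module keeps its state dict in fp16
    def transform_state_dict_for_submodule(self, model, submodule_name: str, state_dict: dict[str, Any]) -> dict[str, Any]:
        return {k: v.half() if torch.is_floating_point(v) else v for k, v in state_dict.items()}

    # undo the forward transform (up to fp16 rounding) for original_state_dict()
    def reverse_transform_state_dict_for_submodule(self, model, submodule_name: str, state_dict: dict[str, Any]) -> dict[str, Any]:
        return {k: v.float() if torch.is_floating_point(v) else v for k, v in state_dict.items()}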
19 changes: 14 additions & 5 deletions thunder/distributed/__init__.py
@@ -4,7 +4,6 @@
from itertools import chain
from contextlib import contextmanager
from contextvars import ContextVar, Token
import copy
from enum import auto, Enum
from typing import TYPE_CHECKING, Any
from collections.abc import Generator
@@ -423,13 +422,14 @@ def f(tensor: TensorProxy) -> str:


def fsdp(
model: torch.nn.Module,
model: torch.nn.Module | ThunderModule,
*,
device: torch.device | None = None,
broadcast_from: int | None = None,
sharding_strategy: FSDPType = FSDPType.ZERO2,
bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
) -> torch.nn.Module:
move_state_dict_to_cpu: bool | None = None,
) -> torch.nn.Module | ThunderModule:
"""Convert ``model`` into Fully Sharded Data Parallel.

This splits ``model``'s parameters in their first dimension into ``world_size`` chunks
@@ -458,20 +458,22 @@
from a checkpoint in a single rank.
sharding_strategy:
bucketing_strategy:
move_state_dict_to_cpu: Move all-gather'ed parameters of :func:`~thunder.core.module.ThunderModule.original_state_dict` to CPU
as each all-gather is finished.

Returns:
:class:`torch.nn.Module`

"""
import thunder
from thunder.core.module import ThunderModule

utils.check(isinstance(sharding_strategy, FSDPType), lambda: f"FSDPType.ZERO2 and FSDPType.ZERO3 are supported.")
utils.check(
tdist.is_available(),
lambda: "fsdp requires torch distributed to be available (but it's not)",
)

if isinstance(model, thunder.ThunderModule):
if isinstance(model, ThunderModule):
from thunder.core.transforms import add_transform
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
from thunder.transforms import MaterializationTransform
@@ -488,11 +490,18 @@
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
release_original_parameters=True,
move_state_dict_to_cpu=False if move_state_dict_to_cpu is None else move_state_dict_to_cpu,
),
MaterializationTransform(device, init=MaterializationTransform.init_from_original_module_init()),
],
)

if move_state_dict_to_cpu is not None:
import warnings

warnings.warn(
"`move_state_dict_to_cpu` is only effective when `model` is `ThunderModule`, i.e., `thunder.jit(model)`"
)
process_group = copy_default_process_group()
utils.check(process_group is not None, lambda: "The default process group is None")
model.use_fsdp = True
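A short note on where the new flag applies: it only takes effect on the fsdp(thunder.jit(model)) path, where it is forwarded to FSDPTransform; on the traditional fsdp(model) path it is ignored and a UserWarning is emitted. Sketch with an illustrative module:

import torch
import thunder
from thunder.distributed import fsdp

# jitted path: the flag reaches FSDPTransform and affects original_state_dict()
jitted = fsdp(thunder.jit(torch.nn.Linear(8, 8, device="cuda")), move_state_dict_to_cpu=True)

# traditional path: the flag has no effect; a UserWarning is emitted whenever it is not None
sharded = fsdp(torch.nn.Linear(8, 8, device="cuda"), move_state_dict_to_cpu=True)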
54 changes: 40 additions & 14 deletions thunder/distributed/transforms/fsdp_v2.py
@@ -1,7 +1,6 @@
"""Transform for `fsdp(jit(model))` to convert a model to use fsdp."""

from __future__ import annotations
import copy
from dataclasses import dataclass
from dataclasses import field
from itertools import chain
@@ -27,14 +26,13 @@
copy_default_process_group,
FSDPType,
FSDPBucketingStrategy,
_materialize,
_shard_param,
_shard_tensor,
)

if TYPE_CHECKING:
from typing import Any
from torch.distributed import ProcessGroup
from thunder.core.module import ThunderModule
from thunder.core.symbol import BoundSymbol
from thunder.core.trace import VariableInterface

@@ -71,7 +69,7 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
return VISIT_TYPE.REPLACE


# When the user calls fsdp(jitted_module), or applies this Transform direcly, it does the following
# When the user calls fsdp(jitted_module), or applies this Transform directly, it does the following
# - It transforms the ThunderModule jitted_module, sharding the parameters as `overrides`
# in the ThunderModule.
# - While doing that, it leaves the original user module alone, except when
Expand All @@ -86,21 +84,16 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
#
# The thunder.distributed.fsdp function calls FSDPTransform followed by MaterializationTransform, the latter does
# the materialization of submodules previously on the meta device.


class FSDPTransform(Transform):
sharded_params: dict[str, Any]
process_group: ProcessGroup
shared_params_name: dict[str, str]

def __init__(
self,
device: torch.device | None = None,
broadcast_from: int | None = None,
sharding_strategy: FSDPType = FSDPType.ZERO2,
bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
release_original_parameters: bool = False,
):
move_state_dict_to_cpu: bool = False,
) -> None:
self.device = device
self.broadcast_from = broadcast_from
self.sharding_strategy = sharding_strategy
@@ -109,13 +102,13 @@ def __init__(
self.sharded_params: dict[str, Any] = {}
self.process_group: ProcessGroup | None = None
self.shared_params_name: dict[str, str] = {}
self.move_state_dict_to_cpu = move_state_dict_to_cpu

def transform_module(
self,
thunder_model: ThunderModule,
):
from thunder import compile_data as get_compile_data
from thunder.core.transforms import add_transform
from thunder.core.module import ThunderModule

self.process_group = copy_default_process_group()
@@ -250,8 +243,11 @@ def transform_module(
p_orig._thunder_device = self.device

def transform_state_dict_for_submodule(
self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
) -> dict:
self,
model: ThunderModule,
submodule_name: str,
state_dict: dict[str, Any],
) -> dict[str, Any]:
prefix = ""
if submodule_name:
prefix = f"{submodule_name}."
Expand All @@ -263,6 +259,36 @@ def transform_state_dict_for_submodule(
new_state_dict[k] = v
return new_state_dict

def reverse_transform_state_dict_for_submodule(
self,
model: ThunderModule,
submodule_name: str,
state_dict: dict[str, Any],
) -> dict[str, Any]:
from thunder.executors.torchex import _all_gather_prim_impl

for name, tensor in state_dict.items():
fqn: str
if submodule_name:
fqn = f"{submodule_name}.{name}"
else:
fqn = name

if fqn not in self.sharded_params:
continue

old_shape, *_ = self.sharded_params[fqn]
unsharded_tensor = _all_gather_prim_impl(
tensor,
group=self.process_group,
do_async=False,
dim=None,
).narrow(0, 0, old_shape[0])
if self.move_state_dict_to_cpu:
unsharded_tensor = unsharded_tensor.cpu()
state_dict[name] = unsharded_tensor
return state_dict

def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
from thunder.distributed import prims as dist_prims

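To make the narrow after the all-gather concrete: sharding pads dim 0 when it does not divide evenly across ranks (which is why old_shape is recorded in self.sharded_params), and the reverse transform drops those padding rows again. A single-process sketch of the shape arithmetic, with illustrative sizes:

import torch

world_size = 4
full = torch.arange(40.0).reshape(10, 4)               # original parameter, old_shape == (10, 4)
padded = torch.nn.functional.pad(full, (0, 0, 0, 2))   # pad dim 0 to 12 so each rank holds 3 rows
shards = list(padded.chunk(world_size, dim=0))         # what each rank keeps after sharding

# what the all_gather + narrow in reverse_transform_state_dict_for_submodule reconstructs
gathered = torch.cat(shards, dim=0)                    # (12, 4)
restored = gathered.narrow(0, 0, full.shape[0])        # (10, 4), padding rows dropped
assert torch.equal(restored, full)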
20 changes: 20 additions & 0 deletions thunder/tests/distributed/test_fsdp.py
@@ -685,6 +685,26 @@ def test_load_original_state_dict(self):

torch.testing.assert_close(y_1, y_2)

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
def test_original_state_dict(self):
device = torch.device("cuda", self.rank)

for move_state_dict_to_cpu in (False, True):
with torch.device("cuda"):
model = ToyModel()

jitted = fsdp(thunder.jit(model), device=device, move_state_dict_to_cpu=move_state_dict_to_cpu)
state_dict = jitted.state_dict()
original_state_dict = jitted.original_state_dict()

for key, sharded_param in state_dict.items():
unsharded = original_state_dict[key]
self.assertEqual(len(sharded_param) * self.world_size, len(unsharded))
if move_state_dict_to_cpu:
self.assertEqual(unsharded.device, torch.device("cpu"))
else:
self.assertEqual(unsharded.device, device)


common_utils.instantiate_parametrized_tests(FSDPTest)
