Commit

update FSDPTransform to delegate materialization to the MaterializationTransform (#995)
t-vi authored Aug 20, 2024
1 parent 8c5fc73 commit 820a7d2
Showing 5 changed files with 143 additions and 110 deletions.
2 changes: 1 addition & 1 deletion thunder/core/module.py
@@ -112,7 +112,7 @@ def _get_shared_names(self):
self.named_parameters(remove_duplicate=False), self.named_buffers(remove_duplicate=False)
):
parameters_to_names.setdefault(v, set()).add(name)
shared_names = {}
shared_names: dict[str, set[str]] = {}
for s in parameters_to_names.values():
for n in s:
shared_names[n] = s
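For context, the dictionary annotated above maps every parameter or buffer name to the full set of names that alias the same underlying tensor. A minimal, hypothetical illustration (the tied-weight toy model below is an assumption used only to show the expected shape of the result):

import torch

# Hypothetical tied-weight model: one tensor reachable under two names.
emb = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10, bias=False)
head.weight = emb.weight  # tie the parameters

model = torch.nn.ModuleDict({"emb": emb, "head": head})
# For a ThunderModule wrapping this model, _get_shared_names() would return roughly:
# {"emb.weight": {"emb.weight", "head.weight"},
#  "head.weight": {"emb.weight", "head.weight"}}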
13 changes: 9 additions & 4 deletions thunder/core/transforms.py
@@ -404,7 +404,7 @@ def visitor_transform(trace_from: Trace, visit: Callable, *, provenance: None |
def add_transform(
cfn: Callable,
*,
transform: Transform,
transform: Transform | list[Transform],
disable_torch_autograd_support=False,
_legacy_copy_params=False,
) -> Callable:
@@ -414,15 +414,20 @@ def add_transform(

utils.check(cd is not None, lambda: f"Can only transform compiled thunder functions")
utils.check(isinstance(cd, CompileData), lambda: f"Found an unknown compile data attribute {cd}")
utils.check_type(transform, Transform)
if isinstance(transform, Transform):
transform = [transform]
else:
utils.check(
all(isinstance(t, Transform) for t in transform),
lambda: "transform must be an instance of Transform or a list of Transform instances.",
)

assert cd.using_jit

from thunder import jit

# todo: move _lc_transforms to compile_data
transforms = cfn._lc_transforms[:]
transforms.append(transform)
transforms = cfn._lc_transforms + transform
jfn = jit(
cd.fn,
langctx=cd.langctx,
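A hedged usage sketch of the widened signature: it assumes Transform can be imported from thunder.core.transforms and that a bare subclass behaves as a no-op; NoopTransform and f are made up for illustration.

import thunder
from thunder.core.transforms import Transform, add_transform

class NoopTransform(Transform):
    # hypothetical transform that relies on Transform's default (no-op) hooks
    pass

def f(x):
    return x * 2

jfn = thunder.jit(f)
jfn_single = add_transform(jfn, transform=NoopTransform())                     # still accepted
jfn_listed = add_transform(jfn, transform=[NoopTransform(), NoopTransform()])  # new: applied in order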
21 changes: 14 additions & 7 deletions thunder/distributed/__init__.py
@@ -474,16 +474,23 @@ def fsdp(
if isinstance(model, thunder.ThunderModule):
from thunder.core.transforms import add_transform
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
from thunder.transforms import MaterializationTransform

if device is None:
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)
return add_transform(
model,
transform=FSDPTransform(
device=device,
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
release_original_parameters=True,
),
transform=[
FSDPTransform(
device=device,
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
release_original_parameters=True,
),
MaterializationTransform(device, init=MaterializationTransform.init_from_original_module_init()),
],
)

process_group = copy_default_process_group()
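Taken together, the change above means a meta-initialized model can be passed through jit and fsdp and is only materialized after sharding. A hedged end-to-end sketch (assumes a torchrun launch with LOCAL_RANK set, an initialized process group, and available CUDA devices):

import torch
import thunder
import thunder.distributed

with torch.device("meta"):
    model = torch.nn.Linear(4096, 4096)  # parameters start out on the meta device

jitted = thunder.jit(model)
# fsdp() now stacks two transforms: FSDPTransform shards the (still meta) parameters,
# then MaterializationTransform materializes the shards on the target device.
sharded = thunder.distributed.fsdp(jitted)
# Loading a full, unsharded checkpoint afterwards goes through the state-dict transform:
# sharded.load_original_state_dict(torch.load("checkpoint.pt"))  # hypothetical checkpoint path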
186 changes: 96 additions & 90 deletions thunder/distributed/transforms/fsdp_v2.py
@@ -1,4 +1,4 @@
"""Transform for `fsdp(jit(model))` to convert a trace into fsdp."""
"""Transform for `fsdp(jit(model))` to convert a model to use fsdp."""

from __future__ import annotations
import copy
@@ -10,7 +10,6 @@

import torch
import torch.distributed as tdist
from torch.utils.weak import WeakTensorKeyDictionary

from thunder.core import devices
from thunder.core import prims
@@ -72,16 +71,21 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
return VISIT_TYPE.REPLACE


# When the user calls fsdp(jitted_module), this function does the following
# - It transforms the ThunderModule jitted_module, materializing and sharding the parameters as `overrides`
# When the user calls fsdp(jitted_module), or applies this Transform directly, it does the following
# - It transforms the ThunderModule jitted_module, sharding the parameters as `overrides`
# in the ThunderModule.
# - While doing that, it leaves the original user module alone.
# - It then registers a transform (a callback that runs before the prologue is executed) that transforms the
# prologue and compute trace.
# - While doing that, it leaves the original user module alone, except when
# releasing the original tensors is requested (to reduce memory consumption).
# - When needed, a submodule state dict from the unsharded submodule can be transformed into one for the sharded
# submodule. This is used by MaterializationTransform and thunder_model.load_original_state_dict.
# - The prologue and compute trace are transformed, inserting communication and reflecting the weight shape changes.
#
# Note that for doing so, there are a few constraints / caveats:
# - We do not have prologues/compute traces when we transform the module.
# - We need to record the info from the module transformations because a later transform might modify the module further.
#
# The thunder.distributed.fsdp function calls FSDPTransform followed by MaterializationTransform; the latter
# materializes the submodules that were previously on the meta device.


class FSDPTransform(Transform):
@@ -140,10 +144,6 @@ def transform_module(
# modify module
self.sharded_params = {}
device_adjustments = {}
# We use the `shared_params` dictionary to track the shared parameters.
# Key to this dictionary is the original parameter from the user's Module.
# Values are the copied and sharded parameter for the thunder module and meta-data related to sharding.
shared_params = WeakTensorKeyDictionary()

# NOTE: Shared Parameters in Trace
# Shared parameters in PyTorch eager are parameters of module which have different name but share the underlying tensor.
@@ -156,92 +156,98 @@

# This is used to track the shared parameters when the transform is applied.
# key - parameter name, value - `base` parameter name.
self.shared_params_name: dict[str, str] = {}
for module_name, _ in thunder_model._model.named_modules():
submodule = thunder_model.get_submodule(module_name)

# we use a copy to let the user's module alone (TODO: does this fully work?)
module_copy = copy.copy(submodule)
# TODO: we should probably populate the module copy with parameters that consider overrides

# Materialize meta-parameters on-device if necessary.
# This is done before sharding in case the materialization logic depends on the tensor shape.
# The tradeoff is that all of a module's direct parameters need to fit in device.
# Each module only initializes its own parameters and not those of its children (recurse=False)
if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))):
# TODO: we could also support calling a "param_init_fn" argument like PyTorch
_materialize(module_copy, self.device)
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
thunder_model._overrides_parameters[n] = p
device_adjustments[n] = self.device
for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
thunder_model._overrides_buffers[n] = b
device_adjustments[n] = self.device
else:
# Move leftover params and buffers to device. This is at least required to broadcast.
# Cannot `submodule.to(device)` because we don't want it to recurse
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
if p.device != self.device:
thunder_model._overrides_parameters[n] = torch.nn.Parameter(
p.to(device=self.device), requires_grad=p.requires_grad
)
device_adjustments[n] = self.device
for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
if b.device != self.device:
thunder_model._overrides_buffers[n] = b.to(device=self.device)
device_adjustments[n] = self.device

# Broadcast parameters if requested
if self.broadcast_from is not None:
for pn, _ in submodule.named_parameters(recurse=False, prefix=module_name):
# Note that .named_parameters / .named_buffers used below only return a duplicated parameter once.

shared_names = thunder_model._get_shared_names()
self.shared_params_name = {}

# For materialized parameters and buffers, we move them to the target device as necessary;
# for un-materialized parameters and buffers, we set their ._thunder_device attribute.

is_fully_materialized = True
for n, p in thunder_model.named_parameters():
for n2 in shared_names[n]:
if n2 != n:
self.shared_params_name[n2] = n
try:
orig_p = thunder_model._model.get_parameter(n)
except AttributeError:
orig_p = None
if p.is_meta:
is_fully_materialized = False
p._thunder_device = self.device
if orig_p is not None:
orig_p._thunder_device = self.device
# TODO: check if device_adjustments are still needed
for n2 in shared_names[n]:
device_adjustments[n2] = self.device
elif p.device != self.device:
with torch.no_grad():
new_p = torch.nn.Parameter(p.to(device=self.device), requires_grad=p.requires_grad)
for n2 in shared_names[n]:
thunder_model._overrides_parameters[n2] = new_p
device_adjustments[n2] = self.device

for n, b in thunder_model.named_buffers():
try:
orig_b = thunder_model._model.get_buffer(n)
except AttributeError:
orig_b = None
if b.is_meta:
is_fully_materialized = False
b._thunder_device = self.device
if orig_b is not None:
orig_b._thunder_device = self.device
# TODO: check if this is still needed
device_adjustments[n] = self.device
elif b.device != self.device:
new_b = b.to(device=self.device)
for n2 in shared_names[n]:
thunder_model._overrides_buffers[n2] = new_b
device_adjustments[n2] = self.device

# Broadcast parameters if requested
# (todos shared with thunder/distributed/__init__.py)
# TODO Make these broadcast asyncs
# TODO Perform up to two broadcasts at a time
# See issue "Update ddp to use async broadcasts"
# TODO "Bucket" small tensors together before broadcasting
if self.broadcast_from is not None:
if not is_fully_materialized:
# Note: we could move broadcasting into its own transform coming
# after materialization (in thunder.distributed.fsdp) to
# support this, if it is useful.
raise RuntimeError("cannot broadcast from non-materialized model")
with torch.no_grad():
for pn, p in chain(thunder_model.named_parameters(), thunder_model.named_buffers()):
tdist.broadcast(
thunder_model.get_parameter(pn),
p,
src=self.broadcast_from,
group=self.process_group,
async_op=False,
)
for pn, _ in submodule.named_buffers(recurse=False, prefix=module_name):
tdist.broadcast(
thunder_model.get_buffer(pn), src=self.broadcast_from, group=self.process_group, async_op=False
)

for pn, p in submodule.named_parameters(recurse=False, prefix=module_name):
# If there are shared params in the original user Module, we reuse the sharded copy created from the original parameter below.
# This way we re-create parameter sharing in thunder's copy of the Module.
if p in shared_params:
# Shared param names : current param - base param
self.shared_params_name[pn] = shared_params[p]["param_name"]
# Re-use the previous copy of this parameter.
thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"]
self.sharded_params[pn] = shared_params[p]["param_shard_meta"]
continue

# if we don't have an override or it is just the original, do create a copy
if thunder_model._overrides_parameters.get(pn, p) is p:
thunder_model._overrides_parameters[pn] = copy.copy(p)
# we collect shapes and devices because we do not know if other transforms also change it...
old_shape = thunder_model._overrides_parameters[pn].shape
_shard_param(
thunder_model._overrides_parameters[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True
)
new_shape = thunder_model._overrides_parameters[pn].shape
self.sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device)
if self.release_original_parameters:
base_pn = pn.rsplit(".", 1)[-1]
p_orig = getattr(submodule, base_pn)
if p_orig.device.type != "meta":
p_meta = torch.nn.Parameter(p.to(device="meta"), requires_grad=p.requires_grad)
p_meta._thunder_device = p_orig.device
# do the actual sharding. Note that meta tensors will give sharded meta tensors.
for pn, p in list(thunder_model.named_parameters()):
# we collect shapes and devices because we do not know if other transforms also change it.
old_shape = p.shape
p_new, _ = _shard_tensor(p, global_rank, world_size, pn, allow_padding_for_fsdp=True)
p_new = torch.nn.Parameter(p_new.clone(), requires_grad=p.requires_grad)
new_shape = p_new.shape
for n2 in shared_names[pn]:
thunder_model._overrides_parameters[n2] = p_new
self.sharded_params[n2] = (old_shape, new_shape, getattr(p, "_thunder_device", p.device))
if self.release_original_parameters:
p_orig = thunder_model._model.get_parameter(pn)
if p_orig.device.type != "meta":
p_meta = torch.nn.Parameter(p_orig.to(device="meta"), requires_grad=p_orig.requires_grad)
p_meta._thunder_device = p_orig.device
for n2 in shared_names[pn]:
submodule_name, _, base_pn = n2.rpartition(".")
submodule = thunder_model._model.get_submodule(submodule_name)
submodule.register_parameter(base_pn, p_meta)
else:
p_orig._thunder_device = self.device

# Track the original param and its corresponding copied shard and metadata.
shared_params[p] = {
"param_copy": thunder_model._overrides_parameters[pn],
"param_shard_meta": self.sharded_params[pn],
"param_name": pn,
}
else:
p_orig._thunder_device = self.device

def transform_state_dict_for_submodule(
self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
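The per-rank chunking that _shard_tensor applies to each parameter in the new loop can be pictured with this standalone sketch; the pad-then-chunk-along-dim-0 scheme is an assumption for illustration, not thunder's actual helper:

import torch

def shard_dim0(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # pad dim 0 so it divides evenly across ranks, then keep this rank's chunk (as a copy)
    rows = t.shape[0]
    padded_rows = (rows + world_size - 1) // world_size * world_size
    if padded_rows != rows:
        pad = [0, 0] * (t.dim() - 1) + [0, padded_rows - rows]
        t = torch.nn.functional.pad(t, pad)
    return t.chunk(world_size, dim=0)[rank].clone()

full = torch.randn(10, 4)
print(shard_dim0(full, rank=0, world_size=4).shape)  # torch.Size([3, 4])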
31 changes: 23 additions & 8 deletions thunder/transforms/materialization.py
@@ -47,6 +47,7 @@ def transform_module(self, model: ThunderModule):
p._thunder_device = self.device

shared_names = model._get_shared_names()
self.have_materialized: set[str] = set()

# note: the iterations below are without duplicates
for n, p in list(model.named_parameters()):
Expand All @@ -55,16 +56,18 @@ def transform_module(self, model: ThunderModule):
torch.empty_like(p, device=getattr(p, "_thunder_device", self.device)),
requires_grad=p.requires_grad,
)
for nn in shared_names[n]:
model._overrides_parameters[nn] = p
for n2 in shared_names[n]:
model._overrides_parameters[n2] = p
self.have_materialized.add(n2)

for n, b in list(model.named_buffers()):
if b.device.type == "meta":
b = torch.empty_like(
b, device=getattr(b, "_thunder_device", self.device), requires_grad=b.requires_grad
)
for nn in shared_names[n]:
model._overrides_buffers[nn] = b
for n2 in shared_names[n]:
model._overrides_buffers[n2] = b
self.have_materialized.add(n2)

self.init(self, model)

@@ -101,16 +104,28 @@ def module_init_from_original_module_init(transform: MaterializationTransform, t
prefix = module_name if not module_name else f"{module_name}."
submodule = tm.get_submodule(module_name)

# we use a copy to let the user's module alone
module_copy = copy.copy(submodule)

# Materialize meta-parameters on-device if necessary.
# This is done before sharding in case the materialization logic depends on the tensor shape.
# The tradeoff is that all of a module's direct parameters need to fit in device.
# Each module only initializes its own parameters and not those of its children (recurse=False)
if any(
t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))
chain(
(
n in transform.have_materialized
for n, _ in submodule.named_parameters(recurse=False, prefix=module_name)
),
(
n in transform.have_materialized
for n, _ in submodule.named_buffers(recurse=False, prefix=module_name)
),
)
):
# we use a copy to let the user's module alone
module_copy = copy.copy(submodule)
module_copy._parameters = module_copy._parameters.copy()
module_copy._buffers = module_copy._buffers.copy()
module_copy._modules = module_copy._modules.__class__()

# we need to initialize the module unless all parameters are duplicates
need_init = not all(
shared_names[n] & processed_names
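In isolation, the materialization step above boils down to replacing a meta tensor with same-shaped uninitialized storage on the target device, to be filled in by the init callback afterwards; a minimal sketch (cpu stands in for the transform's device):

import torch

meta_p = torch.nn.Parameter(torch.empty(4, 4, device="meta"))
real_p = torch.nn.Parameter(
    torch.empty_like(meta_p, device="cpu"),  # uninitialized storage; the init callback fills in values
    requires_grad=meta_p.requires_grad,
)
print(real_p.shape, real_p.device)  # torch.Size([4, 4]) cpu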