update FSDPTransform to delegate materialization to the MaterializationTransform #995

Merged
merged 21 commits on Aug 20, 2024
Changes from 19 commits

Commits (21)
b4a04c0
Move FSDP module transformation to transform_module
t-vi Aug 17, 2024
597ece2
import _materialize
t-vi Aug 18, 2024
5da2880
Update thunder/distributed/transforms/fsdp_v2.py
t-vi Aug 18, 2024
96d026e
Update thunder/distributed/transforms/fsdp_v2.py
t-vi Aug 18, 2024
98f4692
Merge remote-tracking branch 'origin/main' into tom/fsdp-composable-s…
t-vi Aug 19, 2024
f05768e
free orig parameters after sharding
t-vi Aug 18, 2024
2357aad
Update thunder/distributed/transforms/fsdp_v2.py
t-vi Aug 18, 2024
22a356a
implement FSDPTransform.transform_state_dict_for_submodule
t-vi Aug 18, 2024
5503317
Update thunder/distributed/transforms/fsdp_v2.py
t-vi Aug 19, 2024
be3c931
Update thunder/distributed/__init__.py
t-vi Aug 19, 2024
fd2efea
update FSDPTransform to delegate materialization to the Materializati…
t-vi Aug 19, 2024
96fe840
Apply suggestions from code review
t-vi Aug 19, 2024
4ae9d49
Update thunder/distributed/transforms/fsdp_v2.py
t-vi Aug 19, 2024
7af59d3
don't use nn as a local variable name, thank you Masaki
t-vi Aug 19, 2024
7e821ef
Merge branch 'main' into tom/fsdp-composable-step3
t-vi Aug 19, 2024
a6f9a0e
be more thorough in copying the module before to_empty
t-vi Aug 19, 2024
6564633
Merge branch 'tom/fsdp-composable-step3' into tom/fsdp-composable-step4
t-vi Aug 19, 2024
8e74d23
better keep track of materialization
t-vi Aug 20, 2024
d650f4c
Merge branch 'main' into tom/fsdp-composable-step4
t-vi Aug 20, 2024
d86334a
no-grad for broadcast
t-vi Aug 20, 2024
9f6aee5
Apply suggestions from code review
t-vi Aug 20, 2024
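Before the file-by-file diff, a minimal end-to-end sketch of what this change enables (not part of the diff): fsdp(jit(model)) no longer materializes meta parameters inside FSDPTransform; it registers a MaterializationTransform after it. The model, sizes, and the torchrun/NCCL setup below are placeholder assumptions.

import torch
import torch.distributed as dist
import thunder
from thunder.distributed import fsdp

dist.init_process_group(backend="nccl")  # assumes launch via torchrun so LOCAL_RANK is set

with torch.device("meta"):  # parameters start out unmaterialized
    model = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 4096),
    )

jitted = thunder.jit(model)
sharded = fsdp(jitted)  # FSDPTransform shards the (meta) parameters; MaterializationTransform
                        # then materializes and initializes the shards on the local GPU

x = torch.randn(8, 4096, device="cuda")
y = sharded(x)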
2 changes: 1 addition & 1 deletion thunder/core/module.py
@@ -112,7 +112,7 @@ def _get_shared_names(self):
self.named_parameters(remove_duplicate=False), self.named_buffers(remove_duplicate=False)
):
parameters_to_names.setdefault(v, set()).add(name)
shared_names = {}
shared_names: dict[str, set[str]] = {}
for s in parameters_to_names.values():
for n in s:
shared_names[n] = s
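As an aside, the shared-name map this annotation belongs to is what both transforms below key their overrides on: every alias of a tied tensor resolves to the same name set. A hypothetical tied-weight example (not from this PR):

import torch
import thunder

class TiedLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(100, 16)
        self.head = torch.nn.Linear(16, 100, bias=False)
        self.head.weight = self.emb.weight  # weight tying

    def forward(self, idx):
        return self.head(self.emb(idx))

tm = thunder.jit(TiedLM())
# tm._get_shared_names() is expected to map both aliases to the same set, roughly:
#   {"emb.weight": {"emb.weight", "head.weight"},
#    "head.weight": {"emb.weight", "head.weight"}}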
13 changes: 9 additions & 4 deletions thunder/core/transforms.py
@@ -403,7 +403,7 @@ def visitor_transform(trace_from: Trace, visit: Callable, *, provenance: None |
def add_transform(
cfn: Callable,
*,
transform: Transform,
transform: Transform | list[Transform],
disable_torch_autograd_support=False,
_legacy_copy_params=False,
) -> Callable:
@@ -413,15 +413,20 @@ def add_transform(

utils.check(cd is not None, lambda: f"Can only transform compiled thunder functions")
utils.check(isinstance(cd, CompileData), lambda: f"Found an unknown compile data attribute {cd}")
utils.check_type(transform, Transform)
if isinstance(transform, Transform):
transform = [transform]
else:
utils.check(
all(isinstance(t, Transform) for t in transform),
lambda: "transform must be an instance of Transform or a list of Transform instances.",
)

assert cd.using_jit

from thunder import jit

# todo: move _lc_transforms to compile_data
transforms = cfn._lc_transforms[:]
transforms.append(transform)
transforms = cfn._lc_transforms + transform
jfn = jit(
cd.fn,
langctx=cd.langctx,
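A sketch of the new list form (assuming default sharding/bucketing strategies and an already-initialized process group): transforms passed as a list are appended to the existing ones in order, so FSDPTransform runs its module transform before MaterializationTransform fills in whatever is still on the meta device, mirroring the thunder.distributed.fsdp change further down in this diff.

import torch
import thunder
from thunder.core.transforms import add_transform
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
from thunder.transforms import MaterializationTransform

device = torch.device("cuda", 0)
with torch.device("meta"):
    model = torch.nn.Linear(1024, 1024)

jitted = thunder.jit(model)
jitted = add_transform(
    jitted,
    transform=[
        FSDPTransform(device=device, release_original_parameters=True),
        MaterializationTransform(device, init=MaterializationTransform.init_from_original_module_init()),
    ],
)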
21 changes: 14 additions & 7 deletions thunder/distributed/__init__.py
@@ -474,16 +474,23 @@ def fsdp(
if isinstance(model, thunder.ThunderModule):
from thunder.core.transforms import add_transform
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
from thunder.transforms import MaterializationTransform

if device is None:
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)
return add_transform(
model,
transform=FSDPTransform(
device=device,
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
release_original_parameters=True,
),
transform=[
FSDPTransform(
device=device,
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
release_original_parameters=True,
),
MaterializationTransform(device, init=MaterializationTransform.init_from_original_module_init()),
],
)

process_group = copy_default_process_group()
190 changes: 95 additions & 95 deletions thunder/distributed/transforms/fsdp_v2.py
@@ -1,4 +1,4 @@
"""Transform for `fsdp(jit(model))` to convert a trace into fsdp."""
"""Transform for `fsdp(jit(model))` to convert a model to use fsdp."""

from __future__ import annotations
import copy
@@ -10,7 +10,6 @@

import torch
import torch.distributed as tdist
from torch.utils.weak import WeakTensorKeyDictionary

from thunder.core import devices
from thunder.core import prims
@@ -72,16 +71,21 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
return VISIT_TYPE.REPLACE


# When the user calls fsdp(jitted_module), this function does the following
# - It transforms the ThunderModule jitted_module, materializing and sharding the parameters as `overrides`
# When the user calls fsdp(jitted_module), or applies this Transform directly, it does the following
# - It transforms the ThunderModule jitted_module, sharding the parameters as `overrides`
# in the ThunderModule.
# - While doing that, it leaves the original user module alone.
# - It then registers an transform (callback that runs before prologue is executed) that transforms the
# prologue and compute trace.
# - While doing that, it leaves the original user module alone, except when
# releasing the original tensors is requested (to reduce memory consumption).
# - When needed, a submodule state dict from the unsharded submodule can be transformed into one for the sharded
# submodule. This is used by MaterializationTransform and thunder_model.load_original_state_dict.
# - The prologue and compute trace are transformed, inserting communication and reflecting the weight shape changes.
#
# Note that for doing so, there are a few constraints / caveats:
# - We do not have prologues/compute traces when we transform the module.
# - We need to record the info from the module transformations because a later transform might modify the module further.
#
# The thunder.distributed.fsdp function calls FSDPTransform followed by MaterializationTransform; the latter
# materializes the submodules that are still on the meta device.


class FSDPTransform(Transform):
@@ -140,10 +144,6 @@ def transform_module(
# modify module
self.sharded_params = {}
device_adjustments = {}
# We use `shared_params` dictionary to track the shared parameters.
# Key to this dictionary is the original parameter from the user's Module.
# Values are the copied and sharded parameter for the thunder module and meta-data related to sharding.
shared_params = WeakTensorKeyDictionary()

# NOTE: Shared Parameters in Trace
# Shared parameters in PyTorch eager are parameters of module which have different name but share the underlying tensor.
@@ -156,92 +156,92 @@

# This is used to track the shared parameters when the transform is applied.
# key - parameter name, value - `base` parameter name.
self.shared_params_name: dict[str, str] = {}
for module_name, _ in thunder_model._model.named_modules():
submodule = thunder_model.get_submodule(module_name)

# we use a copy to let the user's module alone (TODO: does this fully work?)
module_copy = copy.copy(submodule)
# TODO: we should probably populate the module copy with parameters that consider overrides

# Materialize meta-parameters on-device if necessary.
# This is done before sharding in case the materialization logic depends on the tensor shape.
# The tradeoff is that all of a module's direct parameters need to fit in device.
# Each module only initializes its own parameters and not those of its children (recurse=False)
if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))):
# TODO: we could also support calling a "param_init_fn" argument like PyTorch
_materialize(module_copy, self.device)
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
thunder_model._overrides_parameters[n] = p
device_adjustments[n] = self.device
for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
thunder_model._overrides_buffers[n] = b
device_adjustments[n] = self.device
else:
# Move leftover params and buffers to device. This is at least required to broadcast.
# Cannot `submodule.to(device)` because we don't want it to recurse
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
if p.device != self.device:
thunder_model._overrides_parameters[n] = torch.nn.Parameter(
p.to(device=self.device), requires_grad=p.requires_grad
)
device_adjustments[n] = self.device
for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
if b.device != self.device:
thunder_model._overrides_buffers[n] = b.to(device=self.device)
device_adjustments[n] = self.device

# Broadcast parameters if requested
if self.broadcast_from is not None:
for pn, _ in submodule.named_parameters(recurse=False, prefix=module_name):
tdist.broadcast(
thunder_model.get_parameter(pn),
src=self.broadcast_from,
group=self.process_group,
async_op=False,
)
for pn, _ in submodule.named_buffers(recurse=False, prefix=module_name):
tdist.broadcast(
thunder_model.get_buffer(pn), src=self.broadcast_from, group=self.process_group, async_op=False
)

for pn, p in submodule.named_parameters(recurse=False, prefix=module_name):
# If there are shared params in the original user Module, we reuse the sharded copy created from the original parameter below.
# This way we re-create parameter sharing in thunder's copy of the Module.
if p in shared_params:
# Shared param names : current param - base param
self.shared_params_name[pn] = shared_params[p]["param_name"]
# Re-use the previous copy of this parameter.
thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"]
self.sharded_params[pn] = shared_params[p]["param_shard_meta"]
continue

# if we don't have an override or it is just the original, do create a copy
if thunder_model._overrides_parameters.get(pn, p) is p:
thunder_model._overrides_parameters[pn] = copy.copy(p)
# we collect shapes and devices because we do not know if other transforms also change it...
old_shape = thunder_model._overrides_parameters[pn].shape
_shard_param(
thunder_model._overrides_parameters[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True
# Note that .named_parameters / .named_buffers used below only return a duplicated parameter once.

shared_names = thunder_model._get_shared_names()
self.shared_params_name = {}

# For materialized parameters and buffers, we move them to the target device as necessary;
# for un-materialized parameters and buffers, we only set their ._thunder_device

is_fully_materialized = True
for n, p in thunder_model.named_parameters():
for n2 in shared_names[n]:
if n2 != n:
self.shared_params_name[n2] = n
try:
orig_p = thunder_model._model.get_parameter(n)
except AttributeError:
orig_p = None
if p.is_meta:
is_fully_materialized = False
p._thunder_device = self.device
if orig_p is not None:
orig_p._thunder_device = self.device
# TODO: check if device_adjustments are still needed
for n2 in shared_names[n]:
device_adjustments[n2] = self.device
elif p.device != self.device:
with torch.no_grad():
new_p = torch.nn.Parameter(p.to(device=self.device), requires_grad=p.requires_grad)
for n2 in shared_names[n]:
thunder_model._overrides_parameters[n2] = new_p
device_adjustments[n2] = self.device

for n, b in thunder_model.named_buffers():
try:
orig_b = thunder_model._model.get_buffer(n)
except AttributeError:
orig_b = None
if b.is_meta:
is_fully_materialized = False
b._thunder_device = self.device
if orig_b is not None:
orig_b._thunder_device = self.device
# TODO: check if this is still needed
device_adjustments[n] = self.device
elif b.device != self.device:
new_b = b.to(device=self.device)
for n2 in shared_names[n]:
thunder_model._overrides_buffers[n2] = new_b
device_adjustments[n2] = self.device

# Broadcast parameters if requested
if self.broadcast_from is not None:
if not is_fully_materialized:
# Note: we could move broadcasting into its own transform coming
# after materialization (in thunder.distributed.fsdp) to
# support this, if it is useful.
raise RuntimeError("cannot broadcast from non-materialized model")
for pn, p in chain(thunder_model.named_parameters(), thunder_model.named_buffers()):
tdist.broadcast(
p,
src=self.broadcast_from,
group=self.process_group,
async_op=False,
)
new_shape = thunder_model._overrides_parameters[pn].shape
self.sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device)
if self.release_original_parameters:
base_pn = pn.rsplit(".", 1)[-1]
p_orig = getattr(submodule, base_pn)
if p_orig.device.type != "meta":
p_meta = torch.nn.Parameter(p.to(device="meta"), requires_grad=p.requires_grad)
p_meta._thunder_device = p_orig.device

# do the actual sharding. Note that meta tensors will give sharded meta tensors.
for pn, p in list(thunder_model.named_parameters()):
# we collect shapes and devices because we do not know if other transforms also change it.
old_shape = p.shape
p_new, _ = _shard_tensor(p, global_rank, world_size, pn, allow_padding_for_fsdp=True)
p_new = torch.nn.Parameter(p_new.clone(), requires_grad=p.requires_grad)
new_shape = p_new.shape
for n2 in shared_names[pn]:
thunder_model._overrides_parameters[n2] = p_new
self.sharded_params[n2] = (old_shape, new_shape, getattr(p, "_thunder_device", p.device))
if self.release_original_parameters:
p_orig = thunder_model._model.get_parameter(pn)
if p_orig.device.type != "meta":
p_meta = torch.nn.Parameter(p_orig.to(device="meta"), requires_grad=p_orig.requires_grad)
p_meta._thunder_device = p_orig.device
for n2 in shared_names[pn]:
submodule_name, _, base_pn = n2.rpartition(".")
submodule = thunder_model._model.get_submodule(submodule_name)
submodule.register_parameter(base_pn, p_meta)
else:
p_orig._thunder_device = self.device

# Track the original param and it's corresponding copied shard and metadata.
shared_params[p] = {
"param_copy": thunder_model._overrides_parameters[pn],
"param_shard_meta": self.sharded_params[pn],
"param_name": pn,
}
else:
p_orig._thunder_device = self.device

def transform_state_dict_for_submodule(
self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
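For orientation, a rough standalone sketch of the dim-0 sharding with padding that the loop above delegates to _shard_tensor. The helper below (shard_dim0_with_padding) is a hypothetical approximation under the assumption that FSDP shards along the first dimension, not thunder's actual implementation.

import torch

def shard_dim0_with_padding(p: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # pad dim 0 so it divides evenly across ranks, then keep this rank's contiguous chunk
    padded_len = -(-p.shape[0] // world_size) * world_size  # ceil to a multiple of world_size
    pad = [0, 0] * (p.dim() - 1) + [0, padded_len - p.shape[0]]  # F.pad pads the last dim first
    chunk = padded_len // world_size
    return torch.nn.functional.pad(p, pad)[rank * chunk : (rank + 1) * chunk].clone()

# e.g. a (10, 4) weight with world_size=4 becomes a (3, 4) shard per rank (the last rank partly padding);
# FSDPTransform records (old_shape, new_shape, device) per name in self.sharded_params so that the later
# prologue/compute trace transform can account for the shape change.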
31 changes: 23 additions & 8 deletions thunder/transforms/materialization.py
@@ -47,6 +47,7 @@ def transform_module(self, model: ThunderModule):
p._thunder_device = self.device

shared_names = model._get_shared_names()
self.have_materialized: set[str] = set()

# note: the iterations below are without duplicates
for n, p in list(model.named_parameters()):
@@ -55,16 +56,18 @@ def transform_module(self, model: ThunderModule):
torch.empty_like(p, device=getattr(p, "_thunder_device", self.device)),
requires_grad=p.requires_grad,
)
for nn in shared_names[n]:
model._overrides_parameters[nn] = p
for n2 in shared_names[n]:
model._overrides_parameters[n2] = p
self.have_materialized.add(n2)

for n, b in list(model.named_buffers()):
if b.device.type == "meta":
b = torch.empty_like(
b, device=getattr(b, "_thunder_device", self.device), requires_grad=b.requires_grad
)
for nn in shared_names[n]:
model._overrides_parameters[nn] = b
for n2 in shared_names[n]:
model._overrides_buffers[n2] = b
self.have_materialized.add(n2)

self.init(self, model)

@@ -101,16 +104,28 @@ def module_init_from_original_module_init(transform: MaterializationTransform, t
prefix = module_name if not module_name else f"{module_name}."
submodule = tm.get_submodule(module_name)

# we use a copy to let the user's module alone
module_copy = copy.copy(submodule)

# Materialize meta-parameters on-device if necessary.
# This is done before sharding in case the materialization logic depends on the tensor shape.
# The tradeoff is that all of a module's direct parameters need to fit in device.
# Each module only initializes its own parameters and not those of its children (recurse=False)
if any(
t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))
chain(
(
n in transform.have_materialized
for n, _ in submodule.named_parameters(recurse=False, prefix=module_name)
),
(
n in transform.have_materialized
for n, _ in submodule.named_buffers(recurse=False, prefix=module_name)
),
)
):
# we use a copy to leave the user's module alone
module_copy = copy.copy(submodule)
module_copy._parameters = module_copy._parameters.copy()
module_copy._buffers = module_copy._buffers.copy()
module_copy._modules = module_copy._modules.__class__()

# we need to initialize the module unless all parameters are duplicates
need_init = not all(
shared_names[n] & processed_names
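To round this off, a standalone usage sketch of the transform that fsdp() now delegates to (module, sizes, and the CUDA device are placeholder assumptions): only parameters and buffers still on the meta device are materialized, and the new have_materialized set is what lets init_from_original_module_init skip submodules that were already materialized.

import torch
import thunder
from thunder.core.transforms import add_transform
from thunder.transforms import MaterializationTransform

with torch.device("meta"):
    model = torch.nn.Linear(256, 256)

tm = thunder.jit(model)
tm = add_transform(
    tm,
    transform=MaterializationTransform(
        torch.device("cuda"), init=MaterializationTransform.init_from_original_module_init()
    ),
)
out = tm(torch.randn(2, 256, device="cuda"))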