diff --git a/thunder/core/module.py b/thunder/core/module.py
index b67d7d3795..9dd9292dab 100644
--- a/thunder/core/module.py
+++ b/thunder/core/module.py
@@ -112,7 +112,7 @@ def _get_shared_names(self):
             self.named_parameters(remove_duplicate=False), self.named_buffers(remove_duplicate=False)
         ):
             parameters_to_names.setdefault(v, set()).add(name)
-        shared_names = {}
+        shared_names: dict[str, set[str]] = {}
         for s in parameters_to_names.values():
             for n in s:
                 shared_names[n] = s
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 585da752df..e5131adb6c 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -404,7 +404,7 @@ def visitor_transform(trace_from: Trace, visit: Callable, *, provenance: None |
 def add_transform(
     cfn: Callable,
     *,
-    transform: Transform,
+    transform: Transform | list[Transform],
     disable_torch_autograd_support=False,
     _legacy_copy_params=False,
 ) -> Callable:
@@ -414,15 +414,20 @@ def add_transform(
     utils.check(cd is not None, lambda: f"Can only transform compiled thunder functions")
     utils.check(isinstance(cd, CompileData), lambda: f"Found an unknown compile data attribute {cd}")
 
-    utils.check_type(transform, Transform)
+    if isinstance(transform, Transform):
+        transform = [transform]
+    else:
+        utils.check(
+            all(isinstance(t, Transform) for t in transform),
+            lambda: "transform must be an instance of Transform or a list of Transform instances.",
+        )
 
     assert cd.using_jit
     from thunder import jit
 
     # todo: move _lc_transforms to compile_data
-    transforms = cfn._lc_transforms[:]
-    transforms.append(transform)
+    transforms = cfn._lc_transforms + transform
     jfn = jit(
         cd.fn,
         langctx=cd.langctx,
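Note on the `add_transform` change above: `transform` now accepts either a single `Transform` or a list, and the list is concatenated onto the compiled function's existing `_lc_transforms` in order. A minimal sketch of the call shape, assuming the `Transform` base class can be subclassed as a no-op and is importable from `thunder.core.transform_common`; the two transform classes below are hypothetical placeholders, not part of this patch:

    import torch
    import thunder
    from thunder.core.transforms import add_transform
    from thunder.core.transform_common import Transform  # assumed import location of the base class

    class FirstTransform(Transform):
        """Hypothetical no-op transform, used only to illustrate the call shape."""

    class SecondTransform(Transform):
        """Hypothetical second no-op transform."""

    jmodel = thunder.jit(torch.nn.Linear(4, 4))
    jmodel = add_transform(jmodel, transform=FirstTransform())                       # unchanged: single Transform
    jmodel = add_transform(jmodel, transform=[FirstTransform(), SecondTransform()])  # new: list, applied in order
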
diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index f7c4b7f1b4..96bff58464 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -474,16 +474,23 @@ def fsdp(
     if isinstance(model, thunder.ThunderModule):
         from thunder.core.transforms import add_transform
         from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
+        from thunder.transforms import MaterializationTransform
 
+        if device is None:
+            local_rank = int(os.environ["LOCAL_RANK"])
+            device = torch.device("cuda", local_rank)
         return add_transform(
             model,
-            transform=FSDPTransform(
-                device=device,
-                broadcast_from=broadcast_from,
-                sharding_strategy=sharding_strategy,
-                bucketing_strategy=bucketing_strategy,
-                release_original_parameters=True,
-            ),
+            transform=[
+                FSDPTransform(
+                    device=device,
+                    broadcast_from=broadcast_from,
+                    sharding_strategy=sharding_strategy,
+                    bucketing_strategy=bucketing_strategy,
+                    release_original_parameters=True,
+                ),
+                MaterializationTransform(device, init=MaterializationTransform.init_from_original_module_init()),
+            ],
         )
 
     process_group = copy_default_process_group()
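With the hunk above, calling `fsdp` on an already-jitted module defaults the device to `cuda:LOCAL_RANK` and now stacks a `MaterializationTransform` behind the `FSDPTransform`, so parameters left on the meta device are sharded first and materialized afterwards. A sketch of the intended flow, assuming a `torchrun`-launched process; the toy module is illustrative:

    import torch
    import torch.distributed as dist
    import thunder
    from thunder.distributed import fsdp

    dist.init_process_group(backend="nccl")  # torchrun provides RANK / WORLD_SIZE / LOCAL_RANK

    with torch.device("meta"):
        model = torch.nn.Linear(8192, 8192, bias=False)  # parameters stay unmaterialized

    jmodel = thunder.jit(model)
    # device=None -> cuda:LOCAL_RANK; FSDPTransform shards the (still meta) parameters,
    # then MaterializationTransform materializes the shards on the target device.
    jmodel = fsdp(jmodel)
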
diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py
index dd8441d754..1afe796a92 100644
--- a/thunder/distributed/transforms/fsdp_v2.py
+++ b/thunder/distributed/transforms/fsdp_v2.py
@@ -1,4 +1,4 @@
-"""Transform for `fsdp(jit(model))` to convert a trace into fsdp."""
+"""Transform for `fsdp(jit(model))` to convert a model to use fsdp."""
 
 from __future__ import annotations
 import copy
@@ -10,7 +10,6 @@
 import torch
 import torch.distributed as tdist
-from torch.utils.weak import WeakTensorKeyDictionary
 
 from thunder.core import devices
 from thunder.core import prims
 
@@ -72,16 +71,21 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
         return VISIT_TYPE.REPLACE
 
 
-# When the user calls fsdp(jitted_module), this function does the following
-#   - It transforms the ThunderModule jitted_module, materializing and sharding the parameters as `overrides`
-#     in the ThunderModule.
-#   - While doing that, it leaves the original user module alone.
-#   - It then registers an transform (callback that runs before prologue is executed) that transforms the
-#     prologue and compute trace.
+# When the user calls fsdp(jitted_module), or applies this Transform directly, it does the following:
+#   - It transforms the ThunderModule jitted_module, sharding the parameters as `overrides`
+#     in the ThunderModule.
+#   - While doing that, it leaves the original user module alone, except when
+#     releasing the original tensors is requested (to reduce memory consumption).
+#   - When needed, a submodule state dict from the unsharded submodule can be transformed into one for the sharded
+#     submodule. This is used by MaterializationTransform and thunder_model.load_original_state_dict.
+#   - The prologue and compute trace are transformed, inserting communication and reflecting the weight shape changes.
 #
 # Note that for doing so, there are a few constraints / caveats:
 #   - We do not have prologues/compute traces when we transform the module.
 #   - We need to record the info from the module transformations because a later transform might modify the module further.
+#
+# The thunder.distributed.fsdp function calls FSDPTransform followed by MaterializationTransform, where the latter does
+# the materialization of submodules previously on the meta device.
 
 
 class FSDPTransform(Transform):
@@ -140,10 +144,6 @@ def transform_module(
         # modify module
         self.sharded_params = {}
         device_adjustments = {}
-        # We use `shared_params` dictionary to track the shared parameters.
-        # Key to this dictionary is the original parameter from the user's Module.
-        # Values are the copied and sharded parameter for the thunder module and meta-data related to sharding.
-        shared_params = WeakTensorKeyDictionary()
 
         # NOTE: Shared Parameters in Trace
         # Shared parameters in PyTorch eager are parameters of module which have different name but share the underlying tensor.
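Since the rewritten `transform_module` below keys everything on parameter names rather than tensor identity, it may help to see what the name-based bookkeeping looks like. A small illustration of `_get_shared_names()` and the derived `shared_params_name` for a weight-tied module; the toy module and the example values are illustrative, not from the patch:

    import torch
    import thunder

    class TiedMLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Linear(16, 16, bias=False)
            self.head = torch.nn.Linear(16, 16, bias=False)
            self.head.weight = self.embed.weight  # weight tying

        def forward(self, x):
            return self.head(self.embed(x))

    jm = thunder.jit(TiedMLP())
    # _get_shared_names maps every name to the set of names sharing its tensor, e.g.
    # {"embed.weight": {"embed.weight", "head.weight"}, "head.weight": {"embed.weight", "head.weight"}}
    shared_names = jm._get_shared_names()
    # FSDPTransform.shared_params_name keeps only the aliases, mapping each duplicate name to the
    # canonical one returned by named_parameters(), e.g. {"head.weight": "embed.weight"}.
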
@@ -156,92 +156,98 @@
         # This is used to track the shared parameters when the transform is applied.
         # key - parameter name, value - `base` parameter name.
-        self.shared_params_name: dict[str, str] = {}
-        for module_name, _ in thunder_model._model.named_modules():
-            submodule = thunder_model.get_submodule(module_name)
-
-            # we use a copy to let the user's module alone (TODO: does this fully work?)
-            module_copy = copy.copy(submodule)
-            # TODO: we should probably populate the module copy with parameters that consider overrides
-
-            # Materialize meta-parameters on-device if necessary.
-            # This is done before sharding in case the materialization logic depends on the tensor shape.
-            # The tradeoff is that all of a module's direct parameters need to fit in device.
-            # Each module only initializes its own parameters and not those of its children (recurse=False)
-            if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))):
-                # TODO: we could also support calling a "param_init_fn" argument like PyTorch
-                _materialize(module_copy, self.device)
-                for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
-                    thunder_model._overrides_parameters[n] = p
-                    device_adjustments[n] = self.device
-                for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
-                    thunder_model._overrides_buffers[n] = b
-                    device_adjustments[n] = self.device
-            else:
-                # Move leftover params and buffers to device. This is at least required to broadcast.
-                # Cannot `submodule.to(device)` because we don't want it to recurse
-                for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
-                    if p.device != self.device:
-                        thunder_model._overrides_parameters[n] = torch.nn.Parameter(
-                            p.to(device=self.device), requires_grad=p.requires_grad
-                        )
-                        device_adjustments[n] = self.device
-                for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
-                    if b.device != self.device:
-                        thunder_model._overrides_buffers[n] = b.to(device=self.device)
-                        device_adjustments[n] = self.device
-
-            # Broadcast parameters if requested
-            if self.broadcast_from is not None:
-                for pn, _ in submodule.named_parameters(recurse=False, prefix=module_name):
+        # Note that .named_parameters / .named_buffers used below only return a duplicated parameter once.
+
+        shared_names = thunder_model._get_shared_names()
+        self.shared_params_name = {}
+
+        # For materialized parameters and buffers, we move them to the target device as necessary;
+        # for un-materialized parameters and buffers, we set the ._thunder_device
+
+        is_fully_materialized = True
+        for n, p in thunder_model.named_parameters():
+            for n2 in shared_names[n]:
+                if n2 != n:
+                    self.shared_params_name[n2] = n
+            try:
+                orig_p = thunder_model._model.get_parameter(n)
+            except AttributeError:
+                orig_p = None
+            if p.is_meta:
+                is_fully_materialized = False
+                p._thunder_device = self.device
+                if orig_p is not None:
+                    orig_p._thunder_device = self.device
+                # TODO: check if device_adjustments are still needed
+                for n2 in shared_names[n]:
+                    device_adjustments[n2] = self.device
+            elif p.device != self.device:
+                with torch.no_grad():
+                    new_p = torch.nn.Parameter(p.to(device=self.device), requires_grad=p.requires_grad)
+                for n2 in shared_names[n]:
+                    thunder_model._overrides_parameters[n2] = new_p
+                    device_adjustments[n2] = self.device
+
+        for n, b in thunder_model.named_buffers():
+            try:
+                orig_b = thunder_model._model.get_buffer(n)
+            except AttributeError:
+                orig_b = None
+            if b.is_meta:
+                is_fully_materialized = False
+                b._thunder_device = self.device
+                if orig_b is not None:
+                    orig_b._thunder_device = self.device
+                # TODO: check if this is still needed
+                device_adjustments[n] = self.device
+            elif b.device != self.device:
+                new_b = b.to(device=self.device)
+                for n2 in shared_names[n]:
+                    thunder_model._overrides_buffers[n2] = new_b
+                    device_adjustments[n2] = self.device
+
+        # Broadcast parameters if requested
+        # (todos shared with thunder/distributed/__init__.py)
+        # TODO Make these broadcast asyncs
+        # TODO Perform up to two broadcasts at a time
+        # See issue "Update ddp to use async broadcasts"
+        # TODO "Bucket" small tensors together before broadcasting
+        if self.broadcast_from is not None:
+            if not is_fully_materialized:
+                # Note: we could move broadcasting into its own transform coming
+                # after materialization (in thunder.distributed.fsdp) to
+                # support this, if it is useful.
+                raise RuntimeError("cannot broadcast from non-materialized model")
+            with torch.no_grad():
+                for pn, p in chain(thunder_model.named_parameters(), thunder_model.named_buffers()):
                     tdist.broadcast(
-                        thunder_model.get_parameter(pn),
+                        p,
                         src=self.broadcast_from,
                         group=self.process_group,
                         async_op=False,
                     )
-                for pn, _ in submodule.named_buffers(recurse=False, prefix=module_name):
-                    tdist.broadcast(
-                        thunder_model.get_buffer(pn), src=self.broadcast_from, group=self.process_group, async_op=False
-                    )
-            for pn, p in submodule.named_parameters(recurse=False, prefix=module_name):
-                # If there are shared params in the original user Module, we reuse the sharded copy created from the original parameter below.
-                # This way we re-create parameter sharing in thunder's copy of the Module.
-                if p in shared_params:
-                    # Shared param names : current param - base param
-                    self.shared_params_name[pn] = shared_params[p]["param_name"]
-                    # Re-use the previous copy of this parameter.
-                    thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"]
-                    self.sharded_params[pn] = shared_params[p]["param_shard_meta"]
-                    continue
-
-                # if we don't have an override or it is just the original, do create a copy
-                if thunder_model._overrides_parameters.get(pn, p) is p:
-                    thunder_model._overrides_parameters[pn] = copy.copy(p)
-                # we collect shapes and devices because we do not know if other transforms also change it...
-                old_shape = thunder_model._overrides_parameters[pn].shape
-                _shard_param(
-                    thunder_model._overrides_parameters[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True
-                )
-                new_shape = thunder_model._overrides_parameters[pn].shape
-                self.sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device)
-                if self.release_original_parameters:
-                    base_pn = pn.rsplit(".", 1)[-1]
-                    p_orig = getattr(submodule, base_pn)
-                    if p_orig.device.type != "meta":
-                        p_meta = torch.nn.Parameter(p.to(device="meta"), requires_grad=p.requires_grad)
-                        p_meta._thunder_device = p_orig.device
+        # do the actual sharding. Note that meta tensors will give sharded meta tensors.
+        for pn, p in list(thunder_model.named_parameters()):
+            # we collect shapes and devices because we do not know if other transforms also change it.
+            old_shape = p.shape
+            p_new, _ = _shard_tensor(p, global_rank, world_size, pn, allow_padding_for_fsdp=True)
+            p_new = torch.nn.Parameter(p_new.clone(), requires_grad=p.requires_grad)
+            new_shape = p_new.shape
+            for n2 in shared_names[pn]:
+                thunder_model._overrides_parameters[n2] = p_new
+                self.sharded_params[n2] = (old_shape, new_shape, getattr(p, "_thunder_device", p.device))
+            if self.release_original_parameters:
+                p_orig = thunder_model._model.get_parameter(pn)
+                if p_orig.device.type != "meta":
+                    p_meta = torch.nn.Parameter(p_orig.to(device="meta"), requires_grad=p_orig.requires_grad)
+                    p_meta._thunder_device = p_orig.device
+                    for n2 in shared_names[pn]:
+                        submodule_name, _, base_pn = n2.rpartition(".")
+                        submodule = thunder_model._model.get_submodule(submodule_name)
                         submodule.register_parameter(base_pn, p_meta)
-                    else:
-                        p_orig._thunder_device = self.device
-
-                # Track the original param and it's corresponding copied shard and metadata.
-                shared_params[p] = {
-                    "param_copy": thunder_model._overrides_parameters[pn],
-                    "param_shard_meta": self.sharded_params[pn],
-                    "param_name": pn,
-                }
+                else:
+                    p_orig._thunder_device = self.device
 
     def transform_state_dict_for_submodule(
         self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
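For the `(old_shape, new_shape, device)` bookkeeping recorded above, the essential effect of `_shard_tensor(..., allow_padding_for_fsdp=True)` is a contiguous dim-0 split, padded so the first dimension divides evenly across ranks. A rough stand-alone illustration of that shape arithmetic; this helper is not the library implementation:

    import torch

    def shard_dim0_with_padding(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # pad dim 0 up to a multiple of world_size, then keep this rank's contiguous chunk
        padded_rows = (t.shape[0] + world_size - 1) // world_size * world_size
        if padded_rows != t.shape[0]:
            pad = torch.zeros(padded_rows - t.shape[0], *t.shape[1:], dtype=t.dtype, device=t.device)
            t = torch.cat([t, pad], dim=0)
        chunk = padded_rows // world_size
        return t[rank * chunk : (rank + 1) * chunk]

    w = torch.randn(10, 4)
    print(shard_dim0_with_padding(w, rank=0, world_size=4).shape)  # torch.Size([3, 4])

Meta tensors go through the same shape logic, which is why the sharded overrides can be created before materialization, as the comment in the hunk above notes.
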
diff --git a/thunder/transforms/materialization.py b/thunder/transforms/materialization.py
index e2f9f76e2b..a4cf3badbf 100644
--- a/thunder/transforms/materialization.py
+++ b/thunder/transforms/materialization.py
@@ -47,6 +47,7 @@ def transform_module(self, model: ThunderModule):
                 p._thunder_device = self.device
 
         shared_names = model._get_shared_names()
+        self.have_materialized: set[str] = set()
 
         # note: the iterations below are without duplicates
         for n, p in list(model.named_parameters()):
@@ -55,16 +56,18 @@ def transform_module(self, model: ThunderModule):
                     torch.empty_like(p, device=getattr(p, "_thunder_device", self.device)),
                     requires_grad=p.requires_grad,
                 )
-                for nn in shared_names[n]:
-                    model._overrides_parameters[nn] = p
+                for n2 in shared_names[n]:
+                    model._overrides_parameters[n2] = p
+                    self.have_materialized.add(n2)
 
         for n, b in list(model.named_buffers()):
             if b.device.type == "meta":
                 b = torch.empty_like(
                     b, device=getattr(b, "_thunder_device", self.device), requires_grad=b.requires_grad
                 )
-                for nn in shared_names[n]:
-                    model._overrides_parameters[nn] = b
+                for n2 in shared_names[n]:
+                    model._overrides_parameters[n2] = b
+                    self.have_materialized.add(n2)
 
         self.init(self, model)
 
@@ -101,16 +104,28 @@ def module_init_from_original_module_init(transform: MaterializationTransform, t
         prefix = module_name if not module_name else f"{module_name}."
         submodule = tm.get_submodule(module_name)
 
-        # we use a copy to let the user's module alone
-        module_copy = copy.copy(submodule)
-
         # Materialize meta-parameters on-device if necessary.
         # This is done before sharding in case the materialization logic depends on the tensor shape.
         # The tradeoff is that all of a module's direct parameters need to fit in device.
         # Each module only initializes its own parameters and not those of its children (recurse=False)
         if any(
-            t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))
+            chain(
+                (
+                    n in transform.have_materialized
+                    for n, _ in submodule.named_parameters(recurse=False, prefix=module_name)
+                ),
+                (
+                    n in transform.have_materialized
+                    for n, _ in submodule.named_buffers(recurse=False, prefix=module_name)
+                ),
+            )
         ):
+            # we use a copy to leave the user's module alone
+            module_copy = copy.copy(submodule)
+            module_copy._parameters = module_copy._parameters.copy()
+            module_copy._buffers = module_copy._buffers.copy()
+            module_copy._modules = module_copy._modules.__class__()
+            # we need to initialize the module unless all parameters are duplicates
             need_init = not all(
                 shared_names[n] & processed_names
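Taken together, the patch lets the FSDP path defer initialization until after sharding. A sketch of composing the two transforms by hand, mirroring what `thunder.distributed.fsdp()` now does internally; it assumes a default process group is already initialized and that the remaining `FSDPTransform` arguments keep their defaults:

    import torch
    import thunder
    from thunder.core.transforms import add_transform
    from thunder.distributed.transforms.fsdp_v2 import FSDPTransform
    from thunder.transforms import MaterializationTransform

    with torch.device("meta"):
        model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024))

    device = torch.device("cuda", 0)
    jmodel = thunder.jit(model)
    jmodel = add_transform(
        jmodel,
        transform=[
            # shards first: meta parameters become sharded meta parameters
            FSDPTransform(device=device, release_original_parameters=True),
            # then materializes the still-meta shards using the original module's init
            MaterializationTransform(device, init=MaterializationTransform.init_from_original_module_init()),
        ],
    )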