From b4a04c017f1a108e90086d290c81ad40206b9f18 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sat, 17 Aug 2024 09:50:06 +0200 Subject: [PATCH 1/9] Move FSDP module transformation to transform_module --- thunder/core/transforms.py | 3 +- thunder/distributed/__init__.py | 165 ++--------------- .../tensor_parallel/column_wise.py | 1 + .../distributed/tensor_parallel/row_wise.py | 1 + thunder/distributed/transforms/fsdp_v2.py | 170 +++++++++++++++++- 5 files changed, 181 insertions(+), 159 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 4f6c07f687..6fc0099120 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -405,6 +405,7 @@ def add_transform( *, transform: Transform, disable_torch_autograd_support=False, + _legacy_copy_params=False, ) -> Callable: from thunder.common import CompileData @@ -433,7 +434,7 @@ def add_transform( ) from thunder import ThunderModule - if isinstance(jfn, ThunderModule): + if _legacy_copy_params and isinstance(jfn, ThunderModule): jfn._overrides_parameters = cfn._overrides_parameters jfn._overrides_buffers = cfn._overrides_buffers return jfn diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py index d56d911dc7..d72b5d96be 100644 --- a/thunder/distributed/__init__.py +++ b/thunder/distributed/__init__.py @@ -422,155 +422,6 @@ def f(tensor: TensorProxy) -> str: return f -# When the user calls fsdp(jitted_module), this function does the following -# - It transforms the ThunderModule jitted_module, materializing and sharding the parameters as `overrides` -# in the ThunderModule. -# - While doing that, it leaves the original user module alone. -# - It then registers an early transform (callback that runs before prologue is executed) that transforms the -# prologue and compute trace. -# -# Note that for doing so, there are a few constraints / caveats: -# - We do not have prologues/compute traces when we transform the module. -# - We need to record the info from the module transformations because a later transform might modify the module further. 
- - -def fsdp_transform_module( - thunder_model: ThunderModule, - *, - device: torch.device | None = None, - broadcast_from: int | None = None, - sharding_strategy: FSDPType = FSDPType.ZERO2, - bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE, -) -> ThunderModule: - from thunder import compile_data as get_compile_data - from thunder.core.transforms import add_transform - from thunder.core.module import ThunderModule - from thunder.distributed.transforms.fsdp_v2 import FSDPTraceTransform - - process_group = copy_default_process_group() - utils.check(process_group is not None, lambda: "The default process group is None") - global_rank = tdist.get_rank(group=process_group) - world_size = tdist.get_world_size(group=process_group) - if device is None: - local_rank = int(os.environ["LOCAL_RANK"]) - device = torch.device("cuda", local_rank) - - cd = get_compile_data(thunder_model) - # TODO: promote use_fsdp and use_ddp to public members of CompileData - cd.use_fsdp = True - orig_module: torch.nn.Module = cd.fn - utils.check( - isinstance(orig_module, torch.nn.Module) and not isinstance(orig_module, ThunderModule), - lambda: f"CompileData.fn expected to be `nn.Module` but {type(orig_module)}", - ) - orig_module.use_fsdp = True - orig_module.process_group_for_ddp = process_group - orig_module.bucketing_strategy = bucketing_strategy - orig_module.sharding_strategy = sharding_strategy - - # modify module - sharded_params = {} - device_adjustments = {} - # We use `shared_params` dictionary to track the shared parameters. - # Key to this dictionary is the original parameter from the user's Module. - # Values are the copied and sharded parameter for the thunder module and meta-data related to sharding. - shared_params = WeakTensorKeyDictionary() - - # NOTE: Shared Parameters in Trace - # Shared parameters in PyTorch eager are parameters of module which have different name but share the underlying tensor. - # For shared parameter, we replace all occurence shared parameter with it's corresponding `base` parameter. - # In our implementation `base` parameter is the parameter and corresponding name which we see the first time while - # iterating our parameters (see below). We track subsequent parameter which share the underlying Tensor with this `base` parameter - # in `shared_params_name` dictionary. - # Then while, transforming the trace - `see FSDPTraceTransform.transform_traces` - we replace all the proxy of shared parameter - # with the corresponding proxy of base parameter in the computation trace. - - # This is used to track the shared parameters when the transform is applied. - # key - parameter name, value - `base` parameter name. - shared_params_name: dict[str, str] = {} - for module_name, _ in thunder_model._model.named_modules(): - submodule = thunder_model.get_submodule(module_name) - - # we use a copy to let the user's module alone (TODO: does this fully work?) - module_copy = copy.copy(submodule) - # TODO: we should probably populate the module copy with parameters that consider overrides - - # Materialize meta-parameters on-device if necessary. - # This is done before sharding in case the materialization logic depends on the tensor shape. - # The tradeoff is that all of a module's direct parameters need to fit in device. 
- # Each module only initializes its own parameters and not those of its children (recurse=False) - if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))): - # TODO: we could also support calling a "param_init_fn" argument like PyTorch - _materialize(module_copy, device) - for n, p in module_copy.named_parameters(recurse=False, prefix=module_name): - thunder_model._overrides_parameters[n] = p - device_adjustments[n] = device - for n, b in module_copy.named_buffers(recurse=False, prefix=module_name): - thunder_model._overrides_buffers[n] = b - device_adjustments[n] = device - else: - # Move leftover params and buffers to device. This is at least required to broadcast. - # Cannot `submodule.to(device)` because we don't want it to recurse - for n, p in module_copy.named_parameters(recurse=False, prefix=module_name): - if p.device != device: - thunder_model._overrides_parameters[n] = torch.nn.Parameter( - p.to(device=device), requires_grad=p.requires_grad - ) - device_adjustments[n] = device - for n, b in module_copy.named_buffers(recurse=False, prefix=module_name): - if b.device != device: - thunder_model._overrides_buffers[n] = b.to(device=device) - device_adjustments[n] = device - - # Broadcast parameters if requested - if broadcast_from is not None: - for pn, _ in submodule.named_parameters(recurse=False, prefix=module_name): - tdist.broadcast( - thunder_model.get_parameter(pn), src=broadcast_from, group=process_group, async_op=False - ) - for pn, _ in submodule.named_buffers(recurse=False, prefix=module_name): - tdist.broadcast(thunder_model.get_buffer(pn), src=broadcast_from, group=process_group, async_op=False) - - for pn, p in submodule.named_parameters(recurse=False, prefix=module_name): - # If there are shared params in the original user Module, we reuse the sharded copy created from the original parameter below. - # This way we re-create parameter sharing in thunder's copy of the Module. - if p in shared_params: - # Shared param names : current param - base param - shared_params_name[pn] = shared_params[p]["param_name"] - # Re-use the previous copy of this parameter. - thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"] - sharded_params[pn] = shared_params[p]["param_shard_meta"] - continue - - # if we don't have an override or it is just the original, do create a copy - if thunder_model._overrides_parameters.get(pn, p) is p: - thunder_model._overrides_parameters[pn] = copy.copy(p) - # we collect shapes and devices because we do not know if other transforms also change it... - old_shape = thunder_model._overrides_parameters[pn].shape - _shard_param( - thunder_model._overrides_parameters[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True - ) - new_shape = thunder_model._overrides_parameters[pn].shape - sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device) - - # Track the original param and it's corresponding copied shard and metadata. 
- shared_params[p] = { - "param_copy": thunder_model._overrides_parameters[pn], - "param_shard_meta": sharded_params[pn], - "param_name": pn, - } - - transform_from_trace_to_fsdp_trace = FSDPTraceTransform( - sharded_params=sharded_params, - process_group=process_group, - shared_params_name=shared_params_name, - ) - # add prologue + compute transform - thunder_model = add_transform(thunder_model, transform=transform_from_trace_to_fsdp_trace) - - return thunder_model - - def fsdp( model: torch.nn.Module, *, @@ -621,13 +472,19 @@ def fsdp( ) if isinstance(model, thunder.ThunderModule): - return fsdp_transform_module( + from thunder.core.transforms import add_transform + from thunder.distributed.transforms.fsdp_v2 import FSDPTransform + + new_model = add_transform( model, - device=device, - broadcast_from=broadcast_from, - sharding_strategy=sharding_strategy, - bucketing_strategy=bucketing_strategy, + transform=FSDPTransform( + device=device, + broadcast_from=broadcast_from, + sharding_strategy=sharding_strategy, + bucketing_strategy=bucketing_strategy, + ), ) + return new_model process_group = copy_default_process_group() utils.check(process_group is not None, lambda: "The default process group is None") diff --git a/thunder/distributed/tensor_parallel/column_wise.py b/thunder/distributed/tensor_parallel/column_wise.py index 647c7f830a..204bb2a7e4 100644 --- a/thunder/distributed/tensor_parallel/column_wise.py +++ b/thunder/distributed/tensor_parallel/column_wise.py @@ -284,6 +284,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor: chunked_param_name_to_layer_type=chunked_param_name_to_layer_type, process_group=process_group, ), + _legacy_copy_params=True, ) return colwise_thunder_module diff --git a/thunder/distributed/tensor_parallel/row_wise.py b/thunder/distributed/tensor_parallel/row_wise.py index af84b3d632..1612e8fec1 100644 --- a/thunder/distributed/tensor_parallel/row_wise.py +++ b/thunder/distributed/tensor_parallel/row_wise.py @@ -291,6 +291,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor: chunked_param_name_to_layer_type=chunked_param_name_to_layer_type, process_group=process_group, ), + _legacy_copy_params=True, ) return rowwise_thunder_module diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index 07739ee7c6..a187c39e2e 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -1,10 +1,17 @@ -"""Early transform for `fsdp(jit(model))` to convert a trace into fsdp.""" +"""Transform for `fsdp(jit(model))` to convert a trace into fsdp.""" from __future__ import annotations +import copy from dataclasses import dataclass from dataclasses import field +from itertools import chain +import os from typing import TYPE_CHECKING +import torch +import torch.distributed as tdist +from torch.utils.weak import WeakTensorKeyDictionary + from thunder.core import devices from thunder.core import prims from thunder.core import utils @@ -17,6 +24,7 @@ from thunder.core.transforms import VISIT_TYPE from thunder.core.transforms import visitor_transform from thunder.core.transform_common import Transform +from thunder.distributed import copy_default_process_group, FSDPType, FSDPBucketingStrategy, _shard_param if TYPE_CHECKING: from typing import Any @@ -26,7 +34,7 @@ __all__ = [ - "FSDPTraceTransform", + "FSDPTransform", ] @@ -57,12 +65,166 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE: return VISIT_TYPE.REPLACE -@dataclass(frozen=True) -class FSDPTraceTransform(Transform): 
+# When the user calls fsdp(jitted_module), this function does the following +# - It transforms the ThunderModule jitted_module, materializing and sharding the parameters as `overrides` +# in the ThunderModule. +# - While doing that, it leaves the original user module alone. +# - It then registers an early transform (callback that runs before prologue is executed) that transforms the +# prologue and compute trace. +# +# Note that for doing so, there are a few constraints / caveats: +# - We do not have prologues/compute traces when we transform the module. +# - We need to record the info from the module transformations because a later transform might modify the module further. + + +class FSDPTransform(Transform): sharded_params: dict[str, Any] process_group: ProcessGroup shared_params_name: dict[str, str] + def __init__( + self, + device: torch.device | None = None, + broadcast_from: int | None = None, + sharding_strategy: FSDPType = FSDPType.ZERO2, + bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE, + ): + self.device = device + self.broadcast_from = broadcast_from + self.sharding_strategy = sharding_strategy + self.bucketing_strategy = bucketing_strategy + + def transform_module( + self, + thunder_model: ThunderModule, + ): + from thunder import compile_data as get_compile_data + from thunder.core.transforms import add_transform + from thunder.core.module import ThunderModule + + self.process_group = copy_default_process_group() + utils.check(self.process_group is not None, lambda: "The default process group is None") + global_rank = tdist.get_rank(group=self.process_group) + world_size = tdist.get_world_size(group=self.process_group) + if self.device is None: + local_rank = int(os.environ["LOCAL_RANK"]) + self.device = torch.device("cuda", local_rank) + + cd = get_compile_data(thunder_model) + # TODO: promote use_fsdp and use_ddp to public members of CompileData + cd.use_fsdp = True + orig_module: torch.nn.Module = cd.fn + utils.check( + isinstance(orig_module, torch.nn.Module) and not isinstance(orig_module, ThunderModule), + lambda: f"CompileData.fn expected to be `nn.Module` but {type(orig_module)}", + ) + orig_module.use_fsdp = True + orig_module.process_group_for_ddp = self.process_group + orig_module.bucketing_strategy = self.bucketing_strategy + orig_module.sharding_strategy = self.sharding_strategy + + # modify module + self.sharded_params = {} + device_adjustments = {} + # We use `shared_params` dictionary to track the shared parameters. + # Key to this dictionary is the original parameter from the user's Module. + # Values are the copied and sharded parameter for the thunder module and meta-data related to sharding. + shared_params = WeakTensorKeyDictionary() + + # NOTE: Shared Parameters in Trace + # Shared parameters in PyTorch eager are parameters of module which have different name but share the underlying tensor. + # For shared parameter, we replace all occurence shared parameter with it's corresponding `base` parameter. + # In our implementation `base` parameter is the parameter and corresponding name which we see the first time while + # iterating our parameters (see below). We track subsequent parameter which share the underlying Tensor with this `base` parameter + # in `shared_params_name` dictionary. + # Then while, transforming the trace - `see FSDPTraceTransform.transform_traces` - we replace all the proxy of shared parameter + # with the corresponding proxy of base parameter in the computation trace. 
+ + # This is used to track the shared parameters when the transform is applied. + # key - parameter name, value - `base` parameter name. + self.shared_params_name: dict[str, str] = {} + for module_name, _ in thunder_model._model.named_modules(): + submodule = thunder_model.get_submodule(module_name) + + # we use a copy to let the user's module alone (TODO: does this fully work?) + module_copy = copy.copy(submodule) + # TODO: we should probably populate the module copy with parameters that consider overrides + + # Materialize meta-parameters on-device if necessary. + # This is done before sharding in case the materialization logic depends on the tensor shape. + # The tradeoff is that all of a module's direct parameters need to fit in device. + # Each module only initializes its own parameters and not those of its children (recurse=False) + if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))): + # TODO: we could also support calling a "param_init_fn" argument like PyTorch + _materialize(module_copy, self.device) + for n, p in module_copy.named_parameters(recurse=False, prefix=module_name): + thunder_model._overrides_parameters[n] = p + device_adjustments[n] = self.device + for n, b in module_copy.named_buffers(recurse=False, prefix=module_name): + thunder_model._overrides_buffers[n] = b + device_adjustments[n] = self.device + else: + # Move leftover params and buffers to device. This is at least required to broadcast. + # Cannot `submodule.to(device)` because we don't want it to recurse + for n, p in module_copy.named_parameters(recurse=False, prefix=module_name): + if p.device != self.device: + thunder_model._overrides_parameters[n] = torch.nn.Parameter( + p.to(device=self.device), requires_grad=p.requires_grad + ) + device_adjustments[n] = self.device + for n, b in module_copy.named_buffers(recurse=False, prefix=module_name): + if b.device != self.device: + thunder_model._overrides_buffers[n] = b.to(device=self.device) + device_adjustments[n] = self.device + + # Broadcast parameters if requested + if self.broadcast_from is not None: + for pn, _ in submodule.named_parameters(recurse=False, prefix=module_name): + tdist.broadcast( + thunder_model.get_parameter(pn), + src=self.broadcast_from, + group=self.process_group, + async_op=False, + ) + for pn, _ in submodule.named_buffers(recurse=False, prefix=module_name): + tdist.broadcast( + thunder_model.get_buffer(pn), src=self.broadcast_from, group=self.process_group, async_op=False + ) + + for pn, p in submodule.named_parameters(recurse=False, prefix=module_name): + # If there are shared params in the original user Module, we reuse the sharded copy created from the original parameter below. + # This way we re-create parameter sharing in thunder's copy of the Module. + if p in shared_params: + # Shared param names : current param - base param + self.shared_params_name[pn] = shared_params[p]["param_name"] + # Re-use the previous copy of this parameter. + thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"] + self.sharded_params[pn] = shared_params[p]["param_shard_meta"] + continue + + # if we don't have an override or it is just the original, do create a copy + if thunder_model._overrides_parameters.get(pn, p) is p: + thunder_model._overrides_parameters[pn] = copy.copy(p) + # we collect shapes and devices because we do not know if other transforms also change it... 
+ old_shape = thunder_model._overrides_parameters[pn].shape + _shard_param( + thunder_model._overrides_parameters[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True + ) + new_shape = thunder_model._overrides_parameters[pn].shape + self.sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device) + + # Track the original param and it's corresponding copied shard and metadata. + shared_params[p] = { + "param_copy": thunder_model._overrides_parameters[pn], + "param_shard_meta": self.sharded_params[pn], + "param_name": pn, + } + + def transform_state_dict_for_submodule( + self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict + ) -> dict: + raise NotImplementedError("cannot transform state dict yet") + def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs): from thunder.distributed import prims as dist_prims From 597ece257d8d1daae1b36f636069c88b256f34ae Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sun, 18 Aug 2024 18:18:11 +0200 Subject: [PATCH 2/9] import _materialize --- thunder/distributed/transforms/fsdp_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index a187c39e2e..548cbd9e57 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -24,7 +24,7 @@ from thunder.core.transforms import VISIT_TYPE from thunder.core.transforms import visitor_transform from thunder.core.transform_common import Transform -from thunder.distributed import copy_default_process_group, FSDPType, FSDPBucketingStrategy, _shard_param +from thunder.distributed import copy_default_process_group, FSDPType, FSDPBucketingStrategy, _shard_param, _materialize if TYPE_CHECKING: from typing import Any From 5da2880e656982efc2008b6c92e6bbd4b576fed3 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sun, 18 Aug 2024 22:24:47 +0200 Subject: [PATCH 3/9] Update thunder/distributed/transforms/fsdp_v2.py --- thunder/distributed/transforms/fsdp_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index 548cbd9e57..dc8b99e59b 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -69,7 +69,7 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE: # - It transforms the ThunderModule jitted_module, materializing and sharding the parameters as `overrides` # in the ThunderModule. # - While doing that, it leaves the original user module alone. -# - It then registers an early transform (callback that runs before prologue is executed) that transforms the +# - It then registers an transform (callback that runs before prologue is executed) that transforms the # prologue and compute trace. 
# # Note that for doing so, there are a few constraints / caveats: From 96d026e89a13af94ba7c5db079a15cc8d78d8cb9 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sun, 18 Aug 2024 22:26:50 +0200 Subject: [PATCH 4/9] Update thunder/distributed/transforms/fsdp_v2.py Co-authored-by: Masaki Kozuki --- thunder/distributed/transforms/fsdp_v2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index dc8b99e59b..33f4a184e2 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -93,6 +93,9 @@ def __init__( self.broadcast_from = broadcast_from self.sharding_strategy = sharding_strategy self.bucketing_strategy = bucketing_strategy + self.sharded_params: dict[str, Any] = {} + self.process_group: ProcessGroup | None = None + self.shared_params_name: dict[str, str] = {} def transform_module( self, From f05768ec3b3fb15e84b00404835764341a7e847c Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sun, 18 Aug 2024 18:21:00 +0200 Subject: [PATCH 5/9] free orig parameters after sharding --- thunder/distributed/__init__.py | 1 + thunder/distributed/transforms/fsdp_v2.py | 11 +++++++++ thunder/tests/distributed/test_fsdp.py | 27 +++++++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py index d72b5d96be..57cd533786 100644 --- a/thunder/distributed/__init__.py +++ b/thunder/distributed/__init__.py @@ -482,6 +482,7 @@ def fsdp( broadcast_from=broadcast_from, sharding_strategy=sharding_strategy, bucketing_strategy=bucketing_strategy, + release_original_parameters=True, ), ) return new_model diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index 33f4a184e2..fec2b23370 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -88,11 +88,13 @@ def __init__( broadcast_from: int | None = None, sharding_strategy: FSDPType = FSDPType.ZERO2, bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE, + release_original_parameters: bool = False, ): self.device = device self.broadcast_from = broadcast_from self.sharding_strategy = sharding_strategy self.bucketing_strategy = bucketing_strategy + self.release_original_parameters = release_original_parameters self.sharded_params: dict[str, Any] = {} self.process_group: ProcessGroup | None = None self.shared_params_name: dict[str, str] = {} @@ -215,6 +217,15 @@ def transform_module( ) new_shape = thunder_model._overrides_parameters[pn].shape self.sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device) + if self.release_original_parameters: + base_pn = pn.rsplit(".", 1)[-1] + p_orig = getattr(submodule, base_pn) + if p_orig.device.type != "meta": + p_meta = torch.nn.Parameter(p.to(device="meta"), requires_grad=p.requires_grad) + p_meta._thunder_device = p_orig.device + setattr(submodule, base_pn, p_meta) + else: + p_orig._thunder_device = self.device # Track the original param and it's corresponding copied shard and metadata. 
shared_params[p] = { diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index 8b3d12698d..f694ad9d12 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -638,6 +638,33 @@ def _test_model_output_and_gradients(model, x, duplicate_all_gather): _test_model_output_and_gradients(fsdp_jit_model, x, duplicate_all_gather=False) + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices") + @common_utils.parametrize("model_device", ["cuda", "meta"]) + def test_memory_consumption(self, model_device): + import gc + + device = torch.device("cuda", self.rank) + with device: + x_1 = torch.randn((2, ToyModel.N_IN)) + with torch.device(model_device): + model = ToyModel() + jit_fsdp_model = thunder.jit(fsdp(model, device=device)) + y_1 = jit_fsdp_model(x_1) + active_mem_jit_fsdp = torch.cuda.memory_stats()["active_bytes.all.current"] + + del x_1, y_1, jit_fsdp_model, model + gc.collect() + torch.cuda.empty_cache() + + with device: + x_2 = torch.randn((2, ToyModel.N_IN)) + with torch.device(model_device): + model = ToyModel() + fsdp_jit_model = fsdp(thunder.jit(model), device=device) + y_2 = fsdp_jit_model(x_2) + active_mem_fsdp_jit = torch.cuda.memory_stats()["active_bytes.all.current"] + self.assertAlmostEqual(active_mem_fsdp_jit, active_mem_jit_fsdp) + common_utils.instantiate_parametrized_tests(FSDPTest) From 2357aada2f35af047a5e7944bf84c74bc1138606 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sun, 18 Aug 2024 22:32:57 +0200 Subject: [PATCH 6/9] Update thunder/distributed/transforms/fsdp_v2.py --- thunder/distributed/transforms/fsdp_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index fec2b23370..d1f2d414be 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -223,7 +223,7 @@ def transform_module( if p_orig.device.type != "meta": p_meta = torch.nn.Parameter(p.to(device="meta"), requires_grad=p.requires_grad) p_meta._thunder_device = p_orig.device - setattr(submodule, base_pn, p_meta) + submodule.register_parameter(base_pn, p_meta) else: p_orig._thunder_device = self.device From 22a356a6352785df9fb846317c5c966de4cdaeb2 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sun, 18 Aug 2024 22:15:42 +0200 Subject: [PATCH 7/9] implement FSDPTransform.transform_state_dict_for_submodule --- thunder/distributed/__init__.py | 36 ++++++++++++++++++----- thunder/distributed/transforms/fsdp_v2.py | 20 +++++++++++-- thunder/tests/distributed/test_fsdp.py | 20 +++++++++++++ 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py index 57cd533786..c767a305e6 100644 --- a/thunder/distributed/__init__.py +++ b/thunder/distributed/__init__.py @@ -556,7 +556,7 @@ def _shard_params( sharded_params[param] = True -def _shard_param( +def _shard_tensor( param: torch.Tensor, rank: int, world_size: int, @@ -564,7 +564,7 @@ def _shard_param( *, allow_padding_for_fsdp: bool = False, dim: int | None = None, -) -> None: +) -> tuple[torch.Tensor, int | None]: dim_to_shard = 0 if dim is None else dim if allow_padding_for_fsdp: @@ -577,11 +577,11 @@ def _shard_param( if _thunder_fsdp_padding_size > 0: padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype) padded_param[:orig_0dim_size].copy_(param) - param.data = padded_param.data.narrow(0, chunk_size * 
rank, chunk_size).clone() - param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size + shard = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone() + return shard, _thunder_fsdp_padding_size else: - param.data = param.data.narrow(0, chunk_size * rank, chunk_size).clone() - param._thunder_fsdp_padding_size = None + shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone() + return shard, None else: utils.check( param.shape[dim_to_shard] % world_size == 0, @@ -594,7 +594,29 @@ def _shard_param( # NOTE This could be a ShardTensor to indicate other parts of the code # that it's sharded and should be treated differently shard = param.data.narrow(dim_to_shard, chunk_size * rank, chunk_size).clone() - param.data = shard + return shard, None + + +def _shard_param( + param: torch.Tensor, + rank: int, + world_size: int, + name: str, + *, + allow_padding_for_fsdp: bool = False, + dim: int | None = None, +) -> None: + shard, padding_size = _shard_tensor( + param, + rank, + world_size, + name, + allow_padding_for_fsdp=allow_padding_for_fsdp, + dim=dim, + ) + param.data = shard + if allow_padding_for_fsdp: + param._thunder_fsdp_padding_size = padding_size @torch.no_grad() diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index d1f2d414be..4f4a64b9fc 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -24,7 +24,14 @@ from thunder.core.transforms import VISIT_TYPE from thunder.core.transforms import visitor_transform from thunder.core.transform_common import Transform -from thunder.distributed import copy_default_process_group, FSDPType, FSDPBucketingStrategy, _shard_param, _materialize +from thunder.distributed import ( + copy_default_process_group, + FSDPType, + FSDPBucketingStrategy, + _materialize, + _shard_param, + _shard_tensor, +) if TYPE_CHECKING: from typing import Any @@ -111,6 +118,8 @@ def transform_module( utils.check(self.process_group is not None, lambda: "The default process group is None") global_rank = tdist.get_rank(group=self.process_group) world_size = tdist.get_world_size(group=self.process_group) + self.global_rank = global_rank + self.world_size = world_size if self.device is None: local_rank = int(os.environ["LOCAL_RANK"]) self.device = torch.device("cuda", local_rank) @@ -237,7 +246,14 @@ def transform_module( def transform_state_dict_for_submodule( self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict ) -> dict: - raise NotImplementedError("cannot transform state dict yet") + prefix = submodule_name + ("." 
if submodule_name else "") + new_state_dict = {} + for k, v in state_dict.items(): + full_k = prefix + k + if full_k in self.sharded_params: + v, _ = _shard_tensor(v, self.global_rank, self.world_size, full_k, allow_padding_for_fsdp=True) + new_state_dict[k] = v + return new_state_dict def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs): from thunder.distributed import prims as dist_prims diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index f694ad9d12..eb9a35e801 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -665,6 +665,26 @@ def test_memory_consumption(self, model_device): active_mem_fsdp_jit = torch.cuda.memory_stats()["active_bytes.all.current"] self.assertAlmostEqual(active_mem_fsdp_jit, active_mem_jit_fsdp) + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices") + def test_load_original_state_dict(self): + device = torch.device("cuda", self.rank) + with device: + x = torch.randn((2, ToyModel.N_IN)) + with torch.device("cuda"): + model1 = ToyModel() + model2 = ToyModel() + + sd = {k: v.clone() for k, v in model1.state_dict().items()} + + jm1 = fsdp(thunder.jit(model1), device=device) + jm2 = fsdp(thunder.jit(model2), device=device) + jm2.load_original_state_dict(sd) + + y_1 = jm1(x) + y_2 = jm2(x) + + torch.testing.assert_close(y_1, y_2) + common_utils.instantiate_parametrized_tests(FSDPTest) From 55033177495f93aa65c23a879883139ec69c0039 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Mon, 19 Aug 2024 14:26:17 +0200 Subject: [PATCH 8/9] Update thunder/distributed/transforms/fsdp_v2.py Co-authored-by: Masaki Kozuki --- thunder/distributed/transforms/fsdp_v2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index 4f4a64b9fc..dd8441d754 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -246,7 +246,9 @@ def transform_module( def transform_state_dict_for_submodule( self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict ) -> dict: - prefix = submodule_name + ("." if submodule_name else "") + prefix = "" + if submodule_name: + prefix = f"{submodule_name}." new_state_dict = {} for k, v in state_dict.items(): full_k = prefix + k From be3c9316b3ceb48565fef6373fa680e86369e73c Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Mon, 19 Aug 2024 14:28:05 +0200 Subject: [PATCH 9/9] Update thunder/distributed/__init__.py --- thunder/distributed/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py index c767a305e6..0bafad550e 100644 --- a/thunder/distributed/__init__.py +++ b/thunder/distributed/__init__.py @@ -556,6 +556,7 @@ def _shard_params( sharded_params[param] = True +@torch.no_grad() def _shard_tensor( param: torch.Tensor, rank: int,
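
[Editor's note, appended after the series excerpt] The refactor in PATCH 7 splits the sharding math out of `_shard_param` into `_shard_tensor`, which returns the shard (and padding size) instead of mutating `param.data`. As a reading aid, here is a minimal standalone sketch of the padded sharding scheme used when `allow_padding_for_fsdp=True`: pad dim 0 up to a multiple of the world size, then each rank keeps its contiguous chunk. The function and variable names are illustrative, not the ones in the patch.

    import torch

    def shard_with_padding(param: torch.Tensor, rank: int, world_size: int) -> tuple[torch.Tensor, int]:
        orig_rows = param.shape[0]
        chunk_size = (orig_rows + world_size - 1) // world_size  # ceil division
        padding = chunk_size * world_size - orig_rows
        if padding > 0:
            # pad the leading dimension so it divides evenly across ranks
            padded = torch.empty((chunk_size * world_size, *param.shape[1:]),
                                 dtype=param.dtype, device=param.device)
            padded[:orig_rows].copy_(param)
            param = padded
        # rank r keeps rows [r * chunk_size, (r + 1) * chunk_size)
        return param.narrow(0, chunk_size * rank, chunk_size).clone(), padding

    # e.g. a (5, 4) weight with world_size=2 yields two (3, 4) shards and padding=1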
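
[Editor's note] The tests added in PATCH 5 and PATCH 7 show what the series enables at the user level: `fsdp` now composes with `thunder.jit` in either order, and an unsharded state dict can be loaded into the sharded module via `load_original_state_dict`. A hedged usage sketch follows; `MyModel` and the multi-process launch are assumptions for illustration, not part of the patch.

    import os
    import torch
    import thunder
    from thunder.distributed import fsdp

    class MyModel(torch.nn.Module):  # stand-in for the user's network
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.fc(x)

    # assumes torch.distributed is already initialized, e.g. launched with torchrun,
    # one process per GPU
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

    # sharding applied on top of an already-jitted module, the path this series reworks
    fsdp_jit_model = fsdp(thunder.jit(MyModel().to(device)), device=device)

    # the pre-existing order: shard the eager module first, then jit it
    jit_fsdp_model = thunder.jit(fsdp(MyModel().to(device), device=device))

    # load a full (unsharded) state dict; FSDPTransform.transform_state_dict_for_submodule
    # shards each tensor for this rank
    full_sd = MyModel().to(device).state_dict()
    fsdp_jit_model.load_original_state_dict(full_sd)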