Move FSDP module transformation to transform_module (#986)
t-vi authored Aug 19, 2024
1 parent d425fe4 commit fbe9f2c
Showing 5 changed files with 183 additions and 159 deletions.
3 changes: 2 additions & 1 deletion thunder/core/transforms.py
@@ -405,6 +405,7 @@ def add_transform(
*,
transform: Transform,
disable_torch_autograd_support=False,
_legacy_copy_params=False,
) -> Callable:
from thunder.common import CompileData

@@ -433,7 +434,7 @@ def add_transform(
)
from thunder import ThunderModule

if isinstance(jfn, ThunderModule):
if _legacy_copy_params and isinstance(jfn, ThunderModule):
jfn._overrides_parameters = cfn._overrides_parameters
jfn._overrides_buffers = cfn._overrides_buffers
return jfn
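With this change, `add_transform` copies `_overrides_parameters` / `_overrides_buffers` from the previous ThunderModule only when `_legacy_copy_params=True` is passed (kept for the tensor-parallel transforms further down); new-style transforms such as `FSDPTransform` are expected to set up their overrides themselves when the module is transformed. A minimal sketch of that pattern, assuming `Transform` is importable from `thunder.core.transforms` and exposes a `transform_module(model)` hook as the commit title suggests; `CloneBuffersTransform` is a hypothetical example, not part of Thunder:

```python
import torch
import thunder
from thunder.core.transforms import Transform, add_transform  # assumed import location for Transform

class CloneBuffersTransform(Transform):
    # Hypothetical transform: module-level changes happen in transform_module
    # instead of being copied over from the previous ThunderModule afterwards.
    def transform_module(self, model):
        # `model` is the ThunderModule being transformed; record overrides directly on it.
        for name, buf in model._model.named_buffers():
            model._overrides_buffers[name] = buf.clone()

jitted_model = thunder.jit(torch.nn.BatchNorm1d(4))  # BatchNorm has buffers (running stats)
transformed = add_transform(jitted_model, transform=CloneBuffersTransform())
```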
164 changes: 10 additions & 154 deletions thunder/distributed/__init__.py
@@ -422,155 +422,6 @@ def f(tensor: TensorProxy) -> str:
return f


# When the user calls fsdp(jitted_module), this function does the following
# - It transforms the ThunderModule jitted_module, materializing and sharding the parameters as `overrides`
# in the ThunderModule.
# - While doing that, it leaves the original user module alone.
# - It then registers an early transform (a callback that runs before the prologue is executed) that transforms the
# prologue and compute traces.
#
# Note that doing so comes with a few constraints / caveats:
# - We do not yet have the prologue/compute traces when we transform the module.
# - We need to record the info from the module transformations because a later transform might modify the module further.


def fsdp_transform_module(
thunder_model: ThunderModule,
*,
device: torch.device | None = None,
broadcast_from: int | None = None,
sharding_strategy: FSDPType = FSDPType.ZERO2,
bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
) -> ThunderModule:
from thunder import compile_data as get_compile_data
from thunder.core.transforms import add_transform
from thunder.core.module import ThunderModule
from thunder.distributed.transforms.fsdp_v2 import FSDPTraceTransform

process_group = copy_default_process_group()
utils.check(process_group is not None, lambda: "The default process group is None")
global_rank = tdist.get_rank(group=process_group)
world_size = tdist.get_world_size(group=process_group)
if device is None:
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)

cd = get_compile_data(thunder_model)
# TODO: promote use_fsdp and use_ddp to public members of CompileData
cd.use_fsdp = True
orig_module: torch.nn.Module = cd.fn
utils.check(
isinstance(orig_module, torch.nn.Module) and not isinstance(orig_module, ThunderModule),
lambda: f"CompileData.fn expected to be `nn.Module` but {type(orig_module)}",
)
orig_module.use_fsdp = True
orig_module.process_group_for_ddp = process_group
orig_module.bucketing_strategy = bucketing_strategy
orig_module.sharding_strategy = sharding_strategy

# modify module
sharded_params = {}
device_adjustments = {}
# We use the `shared_params` dictionary to track shared parameters.
# Keys of this dictionary are the original parameters from the user's Module.
# Values are the copied and sharded parameters for the thunder module plus metadata related to sharding.
shared_params = WeakTensorKeyDictionary()

# NOTE: Shared Parameters in Trace
# Shared parameters in PyTorch eager are module parameters that have different names but share the underlying tensor.
# For a shared parameter, we replace every occurrence of the shared parameter with its corresponding `base` parameter.
# In our implementation, the `base` parameter is the parameter (and corresponding name) that we see first while
# iterating over our parameters (see below). Subsequent parameters that share the underlying Tensor with this `base`
# parameter are tracked in the `shared_params_name` dictionary; a standalone sketch of this name tracking appears
# after this function.
# Then, while transforming the trace - see `FSDPTraceTransform.transform_traces` - we replace every proxy of a shared
# parameter with the corresponding proxy of the base parameter in the computation trace.

# This is used to track the shared parameters when the transform is applied.
# key - parameter name, value - `base` parameter name.
shared_params_name: dict[str, str] = {}
for module_name, _ in thunder_model._model.named_modules():
submodule = thunder_model.get_submodule(module_name)

# we use a copy to leave the user's module alone (TODO: does this fully work?)
module_copy = copy.copy(submodule)
# TODO: we should probably populate the module copy with parameters that consider overrides

# Materialize meta-parameters on-device if necessary.
# This is done before sharding in case the materialization logic depends on the tensor shape.
# The tradeoff is that all of a module's direct parameters need to fit on the device.
# Each module only initializes its own parameters and not those of its children (recurse=False).
# A simplified standalone sketch of this materialization step appears after this function.
if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))):
# TODO: we could also support calling a "param_init_fn" argument like PyTorch
_materialize(module_copy, device)
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
thunder_model._overrides_parameters[n] = p
device_adjustments[n] = device
for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
thunder_model._overrides_buffers[n] = b
device_adjustments[n] = device
else:
# Move leftover params and buffers to the device. This is required at least for broadcasting.
# We cannot call `submodule.to(device)` because we don't want it to recurse.
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
if p.device != device:
thunder_model._overrides_parameters[n] = torch.nn.Parameter(
p.to(device=device), requires_grad=p.requires_grad
)
device_adjustments[n] = device
for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
if b.device != device:
thunder_model._overrides_buffers[n] = b.to(device=device)
device_adjustments[n] = device

# Broadcast parameters if requested
if broadcast_from is not None:
for pn, _ in submodule.named_parameters(recurse=False, prefix=module_name):
tdist.broadcast(
thunder_model.get_parameter(pn), src=broadcast_from, group=process_group, async_op=False
)
for pn, _ in submodule.named_buffers(recurse=False, prefix=module_name):
tdist.broadcast(thunder_model.get_buffer(pn), src=broadcast_from, group=process_group, async_op=False)

for pn, p in submodule.named_parameters(recurse=False, prefix=module_name):
# If this parameter is shared in the original user Module, we reuse the sharded copy that was created from the
# original parameter on an earlier iteration of this loop.
# This way we re-create the parameter sharing in thunder's copy of the Module.
if p in shared_params:
# Shared param names : current param - base param
shared_params_name[pn] = shared_params[p]["param_name"]
# Re-use the previous copy of this parameter.
thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"]
sharded_params[pn] = shared_params[p]["param_shard_meta"]
continue

# if we don't have an override, or the override is just the original parameter, create a copy
if thunder_model._overrides_parameters.get(pn, p) is p:
thunder_model._overrides_parameters[pn] = copy.copy(p)
# we collect shapes and devices because we do not know if other transforms also change them...
old_shape = thunder_model._overrides_parameters[pn].shape
_shard_param(
thunder_model._overrides_parameters[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True
)
new_shape = thunder_model._overrides_parameters[pn].shape
sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device)

# Track the original param and its corresponding copied shard and metadata.
shared_params[p] = {
"param_copy": thunder_model._overrides_parameters[pn],
"param_shard_meta": sharded_params[pn],
"param_name": pn,
}

transform_from_trace_to_fsdp_trace = FSDPTraceTransform(
sharded_params=sharded_params,
process_group=process_group,
shared_params_name=shared_params_name,
)
# add prologue + compute transform
thunder_model = add_transform(thunder_model, transform=transform_from_trace_to_fsdp_trace)

return thunder_model
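The meta-parameter materialization step above (the `_materialize(module_copy, device)` call) can be sketched in plain PyTorch. This is a simplified stand-in rather than Thunder's `_materialize` helper, and it assumes the whole module lives on the meta device:

```python
import torch
import torch.nn as nn

def materialize_on_device(module: nn.Module, device: torch.device) -> None:
    # Allocate real but uninitialized storage on `device` for the meta tensors ...
    module.to_empty(device=device)
    # ... then re-run each submodule's own initializer where one exists.
    for submodule in module.modules():
        if callable(getattr(submodule, "reset_parameters", None)):
            submodule.reset_parameters()

# Example: build on the meta device (no memory allocated for weights), then materialize on CPU.
with torch.device("meta"):
    mlp = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
materialize_on_device(mlp, torch.device("cpu"))
```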

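The shared-parameter bookkeeping described in the NOTE above can also be illustrated standalone. The sketch below only derives the name mapping (current name -> `base` name); in the removed function, each base parameter's sharded copy is then reused for every alias, which re-creates the weight tying in Thunder's copy of the module:

```python
import torch.nn as nn
from torch.utils.weak import WeakTensorKeyDictionary

def find_shared_parameter_names(module: nn.Module) -> dict[str, str]:
    # Maps each additional name of a tied parameter to its `base` name,
    # i.e. the first name under which the underlying tensor was seen.
    base_name_by_tensor = WeakTensorKeyDictionary()
    shared_names: dict[str, str] = {}
    # remove_duplicate=False keeps every name that refers to a tied tensor.
    for name, param in module.named_parameters(remove_duplicate=False):
        if param in base_name_by_tensor:
            shared_names[name] = base_name_by_tensor[param]
        else:
            base_name_by_tensor[param] = name
    return shared_names

class TiedLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embedding.weight  # weight tying

print(find_shared_parameter_names(TiedLM()))  # {'lm_head.weight': 'embedding.weight'}
```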

def fsdp(
model: torch.nn.Module,
*,
@@ -621,12 +472,17 @@ def fsdp(
)

if isinstance(model, thunder.ThunderModule):
return fsdp_transform_module(
from thunder.core.transforms import add_transform
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform

return add_transform(
model,
device=device,
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
transform=FSDPTransform(
device=device,
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
),
)

process_group = copy_default_process_group()
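After this commit, applying `fsdp` to an already-jitted model routes through `add_transform` with an `FSDPTransform` instead of the removed `fsdp_transform_module`. A usage sketch, assuming the default process group has been initialized (e.g. the script is launched with `torchrun`); the explicit keyword arguments simply restate the defaults from the removed signature above:

```python
import torch.nn as nn
import thunder
from thunder.distributed import fsdp, FSDPType, FSDPBucketingStrategy

model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
jitted = thunder.jit(model)
jitted = fsdp(
    jitted,                            # a ThunderModule, so this goes through add_transform(FSDPTransform(...))
    broadcast_from=0,                  # replicate rank 0's weights across ranks before sharding
    sharding_strategy=FSDPType.ZERO2,
    bucketing_strategy=FSDPBucketingStrategy.NONE,
)
```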
1 change: 1 addition & 0 deletions thunder/distributed/tensor_parallel/column_wise.py
@@ -284,6 +284,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
chunked_param_name_to_layer_type=chunked_param_name_to_layer_type,
process_group=process_group,
),
_legacy_copy_params=True,
)

return colwise_thunder_module
1 change: 1 addition & 0 deletions thunder/distributed/tensor_parallel/row_wise.py
@@ -291,6 +291,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
chunked_param_name_to_layer_type=chunked_param_name_to_layer_type,
process_group=process_group,
),
_legacy_copy_params=True,
)

return rowwise_thunder_module