Move FSDP module transformation to transform_module #986

Merged
merged 6 commits into from Aug 19, 2024
Changes from 2 commits
3 changes: 2 additions & 1 deletion thunder/core/transforms.py
@@ -405,6 +405,7 @@ def add_transform(
*,
transform: Transform,
disable_torch_autograd_support=False,
_legacy_copy_params=False,
) -> Callable:
from thunder.common import CompileData

@@ -433,7 +434,7 @@ def add_transform(
)
from thunder import ThunderModule

if isinstance(jfn, ThunderModule):
if _legacy_copy_params and isinstance(jfn, ThunderModule):
jfn._overrides_parameters = cfn._overrides_parameters
jfn._overrides_buffers = cfn._overrides_buffers
return jfn
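For context, the new `_legacy_copy_params` flag gates the old behaviour of copying the parameter/buffer overrides from the original `ThunderModule` onto the freshly jitted one; transforms that move their module changes into `transform_module` no longer need that copy. A minimal usage sketch, assuming `jitted_model` is a `ThunderModule` and `my_transform` is some `Transform` instance (both placeholders, not from this diff):

```python
from thunder.core.transforms import add_transform

# Default (new behaviour): the transform is expected to set up the new module's
# parameter/buffer overrides itself, e.g. in its transform_module hook.
new_jm = add_transform(jitted_model, transform=my_transform)

# Legacy behaviour: copy _overrides_parameters / _overrides_buffers from the
# original ThunderModule, for transforms that still mutate the module eagerly
# (the tensor-parallel transforms below opt into this).
new_jm = add_transform(jitted_model, transform=my_transform, _legacy_copy_params=True)
```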
165 changes: 11 additions & 154 deletions thunder/distributed/__init__.py
@@ -422,155 +422,6 @@ def f(tensor: TensorProxy) -> str:
return f


# When the user calls fsdp(jitted_module), this function does the following
# - It transforms the ThunderModule jitted_module, materializing and sharding the parameters as `overrides`
# in the ThunderModule.
# - While doing that, it leaves the original user module alone.
# - It then registers an early transform (callback that runs before prologue is executed) that transforms the
# prologue and compute trace.
#
# Note that for doing so, there are a few constraints / caveats:
# - We do not have prologues/compute traces when we transform the module.
# - We need to record the info from the module transformations because a later transform might modify the module further.


def fsdp_transform_module(
thunder_model: ThunderModule,
*,
device: torch.device | None = None,
broadcast_from: int | None = None,
sharding_strategy: FSDPType = FSDPType.ZERO2,
bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
) -> ThunderModule:
from thunder import compile_data as get_compile_data
from thunder.core.transforms import add_transform
from thunder.core.module import ThunderModule
from thunder.distributed.transforms.fsdp_v2 import FSDPTraceTransform

process_group = copy_default_process_group()
utils.check(process_group is not None, lambda: "The default process group is None")
global_rank = tdist.get_rank(group=process_group)
world_size = tdist.get_world_size(group=process_group)
if device is None:
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)

cd = get_compile_data(thunder_model)
# TODO: promote use_fsdp and use_ddp to public members of CompileData
cd.use_fsdp = True
orig_module: torch.nn.Module = cd.fn
utils.check(
isinstance(orig_module, torch.nn.Module) and not isinstance(orig_module, ThunderModule),
lambda: f"CompileData.fn expected to be `nn.Module` but {type(orig_module)}",
)
orig_module.use_fsdp = True
orig_module.process_group_for_ddp = process_group
orig_module.bucketing_strategy = bucketing_strategy
orig_module.sharding_strategy = sharding_strategy

# modify module
sharded_params = {}
device_adjustments = {}
# We use the `shared_params` dictionary to track the shared parameters.
# The key to this dictionary is the original parameter from the user's Module.
# The values are the copied and sharded parameter for the thunder module and metadata related to sharding.
shared_params = WeakTensorKeyDictionary()

# NOTE: Shared Parameters in Trace
# Shared parameters in PyTorch eager are parameters of a module which have different names but share the underlying tensor.
# For shared parameters, we replace all occurrences of a shared parameter with its corresponding `base` parameter.
# In our implementation, the `base` parameter is the parameter (and corresponding name) that we see first while
# iterating over our parameters (see below). We track subsequent parameters which share the underlying tensor with this
# `base` parameter in the `shared_params_name` dictionary.
# Then, while transforming the trace - see `FSDPTraceTransform.transform_traces` - we replace every proxy of a shared parameter
# with the corresponding proxy of its base parameter in the computation trace.

# This is used to track the shared parameters when the transform is applied.
# key - parameter name, value - `base` parameter name.
shared_params_name: dict[str, str] = {}
for module_name, _ in thunder_model._model.named_modules():
submodule = thunder_model.get_submodule(module_name)

# we use a copy to leave the user's module alone (TODO: does this fully work?)
module_copy = copy.copy(submodule)
# TODO: we should probably populate the module copy with parameters that consider overrides

# Materialize meta-parameters on-device if necessary.
# This is done before sharding in case the materialization logic depends on the tensor shape.
# The tradeoff is that all of a module's direct parameters need to fit in device.
# Each module only initializes its own parameters and not those of its children (recurse=False)
if any(t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))):
# TODO: we could also support calling a "param_init_fn" argument like PyTorch
_materialize(module_copy, device)
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
thunder_model._overrides_parameters[n] = p
device_adjustments[n] = device
for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
thunder_model._overrides_buffers[n] = b
device_adjustments[n] = device
else:
# Move leftover params and buffers to device. This is at least required to broadcast.
# Cannot `submodule.to(device)` because we don't want it to recurse
for n, p in module_copy.named_parameters(recurse=False, prefix=module_name):
if p.device != device:
thunder_model._overrides_parameters[n] = torch.nn.Parameter(
p.to(device=device), requires_grad=p.requires_grad
)
device_adjustments[n] = device
for n, b in module_copy.named_buffers(recurse=False, prefix=module_name):
if b.device != device:
thunder_model._overrides_buffers[n] = b.to(device=device)
device_adjustments[n] = device

# Broadcast parameters if requested
if broadcast_from is not None:
for pn, _ in submodule.named_parameters(recurse=False, prefix=module_name):
tdist.broadcast(
thunder_model.get_parameter(pn), src=broadcast_from, group=process_group, async_op=False
)
for pn, _ in submodule.named_buffers(recurse=False, prefix=module_name):
tdist.broadcast(thunder_model.get_buffer(pn), src=broadcast_from, group=process_group, async_op=False)

for pn, p in submodule.named_parameters(recurse=False, prefix=module_name):
# If there are shared params in the original user Module, we reuse the sharded copy created from the original parameter below.
# This way we re-create parameter sharing in thunder's copy of the Module.
if p in shared_params:
# Shared param names : current param - base param
shared_params_name[pn] = shared_params[p]["param_name"]
# Re-use the previous copy of this parameter.
thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"]
sharded_params[pn] = shared_params[p]["param_shard_meta"]
continue

# if we don't have an override or it is just the original, do create a copy
if thunder_model._overrides_parameters.get(pn, p) is p:
thunder_model._overrides_parameters[pn] = copy.copy(p)
# we collect shapes and devices because we do not know if other transforms also change it...
old_shape = thunder_model._overrides_parameters[pn].shape
_shard_param(
thunder_model._overrides_parameters[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True
)
new_shape = thunder_model._overrides_parameters[pn].shape
sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device)

# Track the original param and its corresponding copied shard and metadata.
shared_params[p] = {
"param_copy": thunder_model._overrides_parameters[pn],
"param_shard_meta": sharded_params[pn],
"param_name": pn,
}

transform_from_trace_to_fsdp_trace = FSDPTraceTransform(
sharded_params=sharded_params,
process_group=process_group,
shared_params_name=shared_params_name,
)
# add prologue + compute transform
thunder_model = add_transform(thunder_model, transform=transform_from_trace_to_fsdp_trace)

return thunder_model


def fsdp(
model: torch.nn.Module,
*,
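The bookkeeping that the deleted function performed — and that the new `FSDPTransform` presumably reproduces inside `transform_module` — boils down to pairing a weak-keyed map from original parameter tensors to their sharded copies with a name-to-base-name map for aliases. A condensed sketch of that pattern (materialization, device movement, and broadcasting are elided; `thunder_model`, `_shard_param`, `global_rank`, and `world_size` are assumed to be set up as in the deleted code):

```python
import copy
from torch.utils.weak import WeakTensorKeyDictionary

sharded_params = {}                        # param name -> (old shape, new shape, device)
shared_params = WeakTensorKeyDictionary()  # original tensor -> {copy, shard meta, base name}
shared_params_name: dict[str, str] = {}    # alias name -> base param name

for pn, p in thunder_model._model.named_parameters():
    if p in shared_params:
        # Alias of an already-processed tensor: reuse its sharded copy.
        shared_params_name[pn] = shared_params[p]["param_name"]
        thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"]
        sharded_params[pn] = shared_params[p]["param_shard_meta"]
        continue

    # First sighting: copy (so the user's module stays untouched), then shard in place.
    if thunder_model._overrides_parameters.get(pn, p) is p:
        thunder_model._overrides_parameters[pn] = copy.copy(p)
    old_shape = thunder_model._overrides_parameters[pn].shape
    _shard_param(thunder_model._overrides_parameters[pn], global_rank, world_size, pn, allow_padding_for_fsdp=True)
    new_shape = thunder_model._overrides_parameters[pn].shape
    sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device)

    shared_params[p] = {
        "param_copy": thunder_model._overrides_parameters[pn],
        "param_shard_meta": sharded_params[pn],
        "param_name": pn,
    }
```

The resulting `sharded_params` and `shared_params_name` maps were then handed to the trace transform so it could rewrite the prologue and computation traces accordingly.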
@@ -621,13 +472,19 @@ def fsdp(
)

if isinstance(model, thunder.ThunderModule):
return fsdp_transform_module(
from thunder.core.transforms import add_transform
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform

new_model = add_transform(
model,
device=device,
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
transform=FSDPTransform(
device=device,
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
),
)
return new_model

process_group = copy_default_process_group()
utils.check(process_group is not None, lambda: "The default process group is None")
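In user code nothing changes: calling `fsdp()` on a jitted module now routes through `add_transform` with an `FSDPTransform` instead of the removed `fsdp_transform_module`. A usage sketch, assuming a default process group has been initialized (e.g. under `torchrun`); `MyModel` is a placeholder `nn.Module`, and the import path for the `FSDPType`/`FSDPBucketingStrategy` enums is an assumption:

```python
import os

import torch
import thunder
from thunder.core.transforms import add_transform
from thunder.distributed import FSDPBucketingStrategy, FSDPType, fsdp
from thunder.distributed.transforms.fsdp_v2 import FSDPTransform

jitted = thunder.jit(MyModel())  # MyModel is a placeholder module
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

# The branch changed above: fsdp() on a ThunderModule now wraps FSDPTransform.
sharded = fsdp(jitted, device=device, broadcast_from=0)

# Equivalent explicit form, mirroring the new code in the hunk:
sharded = add_transform(
    jitted,
    transform=FSDPTransform(
        device=device,
        broadcast_from=0,
        sharding_strategy=FSDPType.ZERO2,
        bucketing_strategy=FSDPBucketingStrategy.NONE,
    ),
)
```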
1 change: 1 addition & 0 deletions thunder/distributed/tensor_parallel/column_wise.py
@@ -284,6 +284,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
chunked_param_name_to_layer_type=chunked_param_name_to_layer_type,
process_group=process_group,
),
_legacy_copy_params=True,
)

return colwise_thunder_module
1 change: 1 addition & 0 deletions thunder/distributed/tensor_parallel/row_wise.py
@@ -291,6 +291,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
chunked_param_name_to_layer_type=chunked_param_name_to_layer_type,
process_group=process_group,
),
_legacy_copy_params=True,
)

return rowwise_thunder_module
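Both tensor-parallel entry points above keep their eager module transformation for now, so they opt back into the copying behaviour via `_legacy_copy_params=True`; only that one keyword line changes in each file. A sketch of the call-site shape (the transform class name is not shown in the hunks, so `ColumnWiseTransform` below is a placeholder; `thunder_module`, `chunked_param_name_to_layer_type`, and `process_group` follow the diff):

```python
colwise_thunder_module = add_transform(
    thunder_module,
    transform=ColumnWiseTransform(  # placeholder name; the row-wise call is analogous
        chunked_param_name_to_layer_type=chunked_param_name_to_layer_type,
        process_group=process_group,
    ),
    # These transforms still populate the overrides eagerly, so the copied
    # ThunderModule must inherit _overrides_parameters/_overrides_buffers.
    _legacy_copy_params=True,
)
```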