
Commit 21a222b

fsdp(jit(model)) + parameter sharing - dont duplicate allgather (#602)
1 parent 71466cc commit 21a222b

File tree (3 files changed: +49, -3 lines)

  thunder/distributed/__init__.py
  thunder/distributed/transforms/fsdp_v2.py
  thunder/tests/distributed/test_ddp.py

thunder/distributed/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -453,6 +453,19 @@ def fsdp_transform_module(
     # Key to this dictionary is the original parameter from the user's Module.
     # Values are the copied and sharded parameter for the thunder module and meta-data related to sharding.
     shared_params = WeakTensorKeyDictionary()
+
+    # NOTE: Shared Parameters in Trace
+    # Shared parameters in PyTorch eager are module parameters that have different names but share the same underlying tensor.
+    # For a shared parameter, we replace every occurrence of the shared parameter with its corresponding `base` parameter.
+    # In our implementation, the `base` parameter is the parameter (and corresponding name) that we see first while
+    # iterating over the parameters (see below). Subsequent parameters that share the underlying tensor with this `base`
+    # parameter are tracked in the `shared_params_name` dictionary.
+    # Then, while transforming the trace - see `FSDPTraceTransform.transform_traces` - we replace every proxy of a shared
+    # parameter with the corresponding proxy of its base parameter in the computation trace.
+
+    # This is used to track the shared parameters when the transform is applied.
+    # key - parameter name, value - `base` parameter name.
+    shared_params_name: dict[str, str] = {}
     for module_name, _ in thunder_model._model.named_modules():
         submodule = thunder_model.get_submodule(module_name)
 
@@ -500,6 +513,8 @@ def fsdp_transform_module(
             # If there are shared params in the original user Module, we reuse the sharded copy created from the original parameter below.
             # This way we re-create parameter sharing in thunder's copy of the Module.
             if p in shared_params:
+                # Shared param names: current param -> base param
+                shared_params_name[pn] = shared_params[p]["param_name"]
                 # Re-use the previous copy of this parameter.
                 thunder_model._overrides_parameters[pn] = shared_params[p]["param_copy"]
                 sharded_params[pn] = shared_params[p]["param_shard_meta"]
@@ -520,11 +535,13 @@ def fsdp_transform_module(
             shared_params[p] = {
                 "param_copy": thunder_model._overrides_parameters[pn],
                 "param_shard_meta": sharded_params[pn],
+                "param_name": pn,
             }
 
     early_transform_from_trace_to_fsdp_trace = FSDPTraceTransform(
         sharded_params=sharded_params,
         process_group=process_group,
+        shared_params_name=shared_params_name,
     )
     # add prologue + compute transform
     thunder_model = add_transform(thunder_model, early_transform=early_transform_from_trace_to_fsdp_trace)
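
To make the bookkeeping above concrete, here is a minimal, self-contained sketch (not the thunder implementation itself) of how a `shared_params_name` map can be built: the first name seen for a given tensor is treated as the `base`, and every later alias of that tensor is mapped to it. The `fc1`/`fc2` names are illustrative and mirror the test further below.

```python
import torch.nn as nn
from torch.utils.weak import WeakTensorKeyDictionary

# Two linear layers that share one weight tensor (weight tying).
model = nn.ModuleDict({"fc1": nn.Linear(16, 16, bias=False), "fc2": nn.Linear(16, 16, bias=False)})
model.fc2.weight = model.fc1.weight

first_seen = WeakTensorKeyDictionary()   # tensor -> name of its `base` parameter
shared_params_name: dict[str, str] = {}  # alias name -> base name

# remove_duplicate=False keeps every alias in the iteration instead of deduplicating shared tensors.
for name, p in model.named_parameters(remove_duplicate=False):
    if p in first_seen:
        shared_params_name[name] = first_seen[p]
    else:
        first_seen[p] = name

print(shared_params_name)  # {'fc2.weight': 'fc1.weight'}
```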

thunder/distributed/transforms/fsdp_v2.py

Lines changed: 12 additions & 0 deletions
@@ -27,6 +27,7 @@
 class FSDPTraceTransform(EarlyTransform):
     sharded_params: dict[str, Any]
     process_group: ProcessGroup
+    shared_params_name: dict[str, str]
 
     def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
         from thunder.distributed import prims as dist_prims
@@ -49,13 +50,15 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **
         computation_trace.push_scope([])
 
         synchronized_parameters = []
+        param_name_to_comp_trc_proxy = {}  # Map each param_name to its corresponding proxy in computation_trc.
         # todo: deal with epilogue output
         for pro_out_p, comp_inp_p in zip(prologue_trace.output, computation_trace.args):
             bsym = prologue_producers[pro_out_p]
             if bsym.sym == prims.unpack_parameter:
                 param_thunder_module, param_name = bsym.args
                 assert param_thunder_module is thunder_module_proxy
                 if param_name in self.sharded_params:
+                    param_name_to_comp_trc_proxy[param_name] = comp_inp_p
                     old_shape, new_shape, new_torch_device = self.sharded_params[param_name]
                     thunder_device = devices.to_device(new_torch_device)
                     thunder_device_str = str(thunder_device)
@@ -91,6 +94,15 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **
 
         proxies_to_replace = {id(bsym.args[0]): bsym.output for bsym in new_scope}
 
+        # See NOTE: Shared Parameters in Trace
+        for param_name, base_param in self.shared_params_name.items():
+            param_proxy = param_name_to_comp_trc_proxy[param_name]
+            base_param_proxy = param_name_to_comp_trc_proxy[base_param]
+            allgather_base_param_proxy = proxies_to_replace[id(base_param_proxy)]
+            # Update `proxies_to_replace` so that every use of `param_proxy` is replaced
+            # with the output of the `AllGather` on `base_param_proxy`.
+            proxies_to_replace[id(param_proxy)] = allgather_base_param_proxy
+
         new_computation_trace = from_trace(computation_trace)
         for idx, bsym in enumerate(computation_trace.bound_symbols):
             if bsym.sym != prims.unpack_trivial:
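
The proxy redirection added above can be illustrated with a small stand-alone sketch. The `Proxy` class and every name in it are hypothetical stand-ins, not thunder's internal types; the point is that once the alias's proxy id points at the base parameter's all-gather output, the subsequent trace rewrite reuses a single AllGather for both parameters.

```python
class Proxy:
    """Toy stand-in for a trace proxy: just a named object."""
    def __init__(self, name: str) -> None:
        self.name = name

fc1_weight = Proxy("t_fc1_weight")                # base parameter proxy
fc2_weight = Proxy("t_fc2_weight")                # alias that shares fc1's tensor
gathered_fc1 = Proxy("all_gather_t_fc1_weight")   # output of AllGather(fc1_weight)

# State after the AllGather insertion pass: only the base parameter was gathered.
proxies_to_replace = {id(fc1_weight): gathered_fc1}
shared_params_name = {"t_fc2_weight": "t_fc1_weight"}
param_name_to_proxy = {"t_fc1_weight": fc1_weight, "t_fc2_weight": fc2_weight}

# Redirect each alias to the base parameter's gathered output (mirrors the loop above).
for alias, base in shared_params_name.items():
    proxies_to_replace[id(param_name_to_proxy[alias])] = proxies_to_replace[id(param_name_to_proxy[base])]

assert proxies_to_replace[id(fc2_weight)] is gathered_fc1
```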

thunder/tests/distributed/test_ddp.py

Lines changed: 20 additions & 3 deletions
@@ -960,7 +960,7 @@ def __init__(self) -> None:
             def forward(self, x):
                 return self.fc1(x) + self.fc2(x)
 
-        def _test_model_output_and_gradients(model, x):
+        def _test_model_output_and_gradients(model, x, duplicate_all_gather):
             output = model(x)
             with device:
                 grad_output = torch.ones_like(output)
@@ -985,6 +985,23 @@ def _test_model_output_and_gradients(model, x):
             expected_grad = 2 * (grad_output.T @ x)
             torch.testing.assert_close(actual_grad_gathered, expected_grad)
 
+            forward_exec_trace = thunder.last_traces(model)[-1]
+            gathered_params = set()
+            for bsym in forward_exec_trace.bound_symbols:
+                if bsym.sym.id in (
+                    thunder.distributed.prims.PrimIDs.ALL_GATHER,
+                    thunder.executors.torchex.all_gather_prim_impl.id,
+                ):
+                    gathered_params.add(bsym.args[0].name)
+
+            # Check the trace to verify we don't emit duplicate AllGathers for shared parameters.
+            if duplicate_all_gather:
+                # Both params are gathered.
+                assert "t_fc1_weight" in gathered_params and "t_fc2_weight" in gathered_params
+            else:
+                # Exactly one of the two params is gathered, not both.
+                assert ("t_fc1_weight" in gathered_params) ^ ("t_fc2_weight" in gathered_params)
+
         with device:
             jit_fsdp_model = Model()
             fsdp_jit_model = Model()
@@ -995,14 +1012,14 @@ def _test_model_output_and_gradients(model, x):
 
             jit_fsdp_model = thunder.jit(thunder.distributed.fsdp(jit_fsdp_model), executors=["torch"])
 
-            _test_model_output_and_gradients(jit_fsdp_model, x)
+            _test_model_output_and_gradients(jit_fsdp_model, x, duplicate_all_gather=True)
 
             # Check `fsdp(jit(model))` works
             fsdp_jit_model.fc1.weight = fsdp_jit_model.fc2.weight
 
             fsdp_jit_model = thunder.distributed.fsdp(thunder.jit(fsdp_jit_model, executors=["torch"]))
 
-            _test_model_output_and_gradients(fsdp_jit_model, x)
+            _test_model_output_and_gradients(fsdp_jit_model, x, duplicate_all_gather=False)
 
 
 common_utils.instantiate_parametrized_tests(CompileDDPTest)
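
For completeness, a hedged usage sketch of the path this commit targets: tie two weights, wrap the module with `fsdp(jit(model))`, and count AllGather symbols in the final forward trace. It assumes a multi-GPU process group is already initialized (for example via `torchrun` plus `torch.distributed.init_process_group`); `TiedModel` and the tensor shapes are placeholders that mirror the test's `Model`.

```python
import torch
import torch.distributed
import torch.nn as nn

import thunder
import thunder.distributed
import thunder.distributed.prims
import thunder.executors.torchex


class TiedModel(nn.Module):  # placeholder module mirroring the test's Model
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(16, 16, bias=False)
        self.fc2 = nn.Linear(16, 16, bias=False)

    def forward(self, x):
        return self.fc1(x) + self.fc2(x)


device = torch.device("cuda", torch.distributed.get_rank())
with device:
    model = TiedModel()
    x = torch.randn(4, 16)

# Share the underlying weight tensor between the two linear layers.
model.fc1.weight = model.fc2.weight

# fsdp(jit(model)) - the ordering this commit fixes for shared parameters.
tmodel = thunder.distributed.fsdp(thunder.jit(model, executors=["torch"]))
tmodel(x)

# Inspect the final forward execution trace and count AllGather symbols.
forward_exec_trace = thunder.last_traces(tmodel)[-1]
num_all_gathers = sum(
    bsym.sym.id
    in (
        thunder.distributed.prims.PrimIDs.ALL_GATHER,
        thunder.executors.torchex.all_gather_prim_impl.id,
    )
    for bsym in forward_exec_trace.bound_symbols
)
print(num_all_gathers)  # expected: 1, the tied weight is gathered only once
```

With the `jit(fsdp(model))` ordering exercised by the first test call, both tied names are still gathered, which is why the test passes `duplicate_all_gather=True` there and `duplicate_all_gather=False` for `fsdp(jit(model))`.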
