Skip to content

Commit

Permalink
Correctly get parameters with unsharded grads and parameters to regis…
Browse files Browse the repository at this point in the history
…ter sharded and sync'ed grads (#1643)
  • Loading branch information
crcrpar authored Jan 14, 2025
1 parent a477a7c commit fa8b80c
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions thunder/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def _sync_grads(module: torch.nn.Module) -> None:
tdist.distributed_c10d.all_reduce(g, group=process_group)
cm.wait()
elif getattr(module, "use_fsdp", False):
from typing import cast
from thunder.core.module import ThunderModule

module = cast(ThunderModule, module)

def prep_shard(
g: torch.Tensor,
Expand All @@ -154,7 +158,16 @@ def prep_shard(

rank: int = tdist.distributed_c10d.get_rank(process_group)
world_size: int = tdist.distributed_c10d.get_world_size(process_group)
params_with_grad = tuple(filter(lambda p: hasattr(p, "_thunder_fsdp_unsharded_grad"), module.parameters()))

params_of_orig_module = list(module._model.get_parameter(name) for name, _ in module.named_parameters())
params_of_fsdp_module = list(module.get_parameter(name) for name, _ in module.named_parameters())

params_with_grad = []
params_to_attach_grad = []
for o, f in zip(params_of_orig_module, params_of_fsdp_module):
if hasattr(o, "_thunder_fsdp_unsharded_grad"):
params_with_grad.append(o)
params_to_attach_grad.append(f)
if not params_with_grad:
return
unsharded_grads = [p._thunder_fsdp_unsharded_grad for p in params_with_grad]
Expand All @@ -165,9 +178,9 @@ def prep_shard(
s, u, op=tdist.distributed_c10d.ReduceOp.AVG, group=process_group
)
cm.wait()
for p, g in zip(params_with_grad, sharded_grads):
for orig, p, g in zip(params_with_grad, params_to_attach_grad, sharded_grads):
p.grad = g
del p._thunder_fsdp_unsharded_grad
del orig._thunder_fsdp_unsharded_grad
else:
import warnings

Expand Down

0 comments on commit fa8b80c

Please sign in to comment.