
Commit f8fbf6c

correctly get parameters
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Jan 14, 2025
1 parent a477a7c commit f8fbf6c
Showing 1 changed file with 16 additions and 3 deletions.
thunder/distributed/__init__.py (16 additions, 3 deletions)

@@ -143,6 +143,10 @@ def _sync_grads(module: torch.nn.Module) -> None:
                 tdist.distributed_c10d.all_reduce(g, group=process_group)
         cm.wait()
     elif getattr(module, "use_fsdp", False):
+        from typing import cast
+        from thunder.core.module import ThunderModule
+
+        module = cast(ThunderModule, module)

         def prep_shard(
             g: torch.Tensor,
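A note on the first hunk: `cast(ThunderModule, module)` is a typing-only narrowing so that the ThunderModule-specific lookups used below (`module._model`, `module.get_parameter`) type-check; it does not change the object at runtime. A minimal illustration of that behavior (the names here are illustrative, not Thunder's API):

from typing import cast

x: object = "hello"
s = cast(str, x)   # runtime no-op: cast() returns its argument unchanged
assert s is x
print(s.upper())   # but the type checker now permits str methods on s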
@@ -154,7 +158,16 @@ def prep_shard(

         rank: int = tdist.distributed_c10d.get_rank(process_group)
         world_size: int = tdist.distributed_c10d.get_world_size(process_group)
-        params_with_grad = tuple(filter(lambda p: hasattr(p, "_thunder_fsdp_unsharded_grad"), module.parameters()))
+
+        params_of_orig_module = list(module._model.get_parameter(name) for name, _ in module.named_parameters())
+        params_of_fsdp_module = list(module.get_parameter(name) for name, _ in module.named_parameters())
+
+        params_with_grad = []
+        params_to_attach_grad = []
+        for o, f in zip(params_of_orig_module, params_of_fsdp_module):
+            if hasattr(o, "_thunder_fsdp_unsharded_grad"):
+                params_with_grad.append(o)
+                params_to_attach_grad.append(f)
         if not params_with_grad:
             return
         unsharded_grads = [p._thunder_fsdp_unsharded_grad for p in params_with_grad]
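This hunk is the substantive fix: `_thunder_fsdp_unsharded_grad` is stashed on the original module's parameters (reached via `module._model`), while `.grad` must be attached to the ThunderModule's own parameters, so the old one-liner over `module.parameters()` inspected the wrong objects. A self-contained toy of the pairing logic, assuming (as the diff implies) that wrapper and wrapped module share parameter names but hold distinct Parameter objects; `Wrapper` is a hypothetical stand-in, not Thunder's ThunderModule:

import torch

class Wrapper(torch.nn.Module):
    """Toy stand-in: same parameter names as the wrapped module, distinct Parameters."""

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        # stored without submodule registration so named_parameters() only
        # yields this wrapper's own parameters (mirrors the two lookups above)
        object.__setattr__(self, "_model", model)
        self.weight = torch.nn.Parameter(model.weight.detach().clone())

inner = torch.nn.Linear(4, 4, bias=False)
module = Wrapper(inner)

# the unsharded grad lives on the ORIGINAL module's parameter
inner.weight._thunder_fsdp_unsharded_grad = torch.ones(4, 4)

params_of_orig_module = list(module._model.get_parameter(name) for name, _ in module.named_parameters())
params_of_fsdp_module = list(module.get_parameter(name) for name, _ in module.named_parameters())

params_with_grad = []        # original params carrying the stashed grad
params_to_attach_grad = []   # wrapper params that should receive .grad
for o, f in zip(params_of_orig_module, params_of_fsdp_module):
    if hasattr(o, "_thunder_fsdp_unsharded_grad"):
        params_with_grad.append(o)
        params_to_attach_grad.append(f)

assert params_with_grad[0] is inner.weight         # found via the original module
assert params_to_attach_grad[0] is module.weight   # .grad goes on the wrapper's param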
@@ -165,9 +178,9 @@ def prep_shard(
                     s, u, op=tdist.distributed_c10d.ReduceOp.AVG, group=process_group
                 )
         cm.wait()
-        for p, g in zip(params_with_grad, sharded_grads):
+        for orig, p, g in zip(params_with_grad, params_to_attach_grad, sharded_grads):
             p.grad = g
-            del p._thunder_fsdp_unsharded_grad
+            del orig._thunder_fsdp_unsharded_grad
     else:
         import warnings

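For context on the last hunk: the collective (per the visible arguments, an averaging reduce-scatter over `process_group`) leaves each rank with the elementwise mean of every rank's unsharded gradient, restricted to that rank's shard; the `cm.wait()` above blocks until the asynchronously issued collectives complete. A single-process simulation of that arithmetic, with a toy `world_size` and gradient values (no process group involved; contiguous-chunk shard layout is an assumption):

import torch

world_size = 2
# per-rank unsharded gradients for one parameter (toy values)
unsharded = [torch.full((4,), float(r + 1)) for r in range(world_size)]  # rank 0: 1s, rank 1: 2s

# reduce step with ReduceOp.AVG: elementwise mean across ranks
averaged = torch.stack(unsharded).mean(dim=0)  # every element becomes 1.5

# scatter step: rank r keeps only its contiguous chunk
sharded = [averaged.chunk(world_size)[r] for r in range(world_size)]
for r, shard in enumerate(sharded):
    print(f"rank {r} shard: {shard.tolist()}")  # [1.5, 1.5] on both ranks

Each sharded result is then attached as `p.grad` on the ThunderModule parameter, while the stashed unsharded grad is deleted from the original parameter, matching the `for orig, p, g` loop above.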
