From fa8b80c4a4837345d88c9d7c63ef2c56774f420b Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Tue, 14 Jan 2025 16:28:31 +0900
Subject: [PATCH] Correctly get parameters with unsharded grads and parameters
 to register sharded and sync'ed grads (#1643)

---
 thunder/distributed/__init__.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index a4c25235b..f1f3f77fc 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -143,6 +143,10 @@ def _sync_grads(module: torch.nn.Module) -> None:
                 tdist.distributed_c10d.all_reduce(g, group=process_group)
         cm.wait()
     elif getattr(module, "use_fsdp", False):
+        from typing import cast
+        from thunder.core.module import ThunderModule
+
+        module = cast(ThunderModule, module)
 
         def prep_shard(
             g: torch.Tensor,
@@ -154,7 +158,16 @@ def prep_shard(
 
         rank: int = tdist.distributed_c10d.get_rank(process_group)
         world_size: int = tdist.distributed_c10d.get_world_size(process_group)
-        params_with_grad = tuple(filter(lambda p: hasattr(p, "_thunder_fsdp_unsharded_grad"), module.parameters()))
+
+        params_of_orig_module = list(module._model.get_parameter(name) for name, _ in module.named_parameters())
+        params_of_fsdp_module = list(module.get_parameter(name) for name, _ in module.named_parameters())
+
+        params_with_grad = []
+        params_to_attach_grad = []
+        for o, f in zip(params_of_orig_module, params_of_fsdp_module):
+            if hasattr(o, "_thunder_fsdp_unsharded_grad"):
+                params_with_grad.append(o)
+                params_to_attach_grad.append(f)
         if not params_with_grad:
             return
         unsharded_grads = [p._thunder_fsdp_unsharded_grad for p in params_with_grad]
@@ -165,9 +178,9 @@ def prep_shard(
                     s, u, op=tdist.distributed_c10d.ReduceOp.AVG, group=process_group
                 )
         cm.wait()
-        for p, g in zip(params_with_grad, sharded_grads):
+        for orig, p, g in zip(params_with_grad, params_to_attach_grad, sharded_grads):
             p.grad = g
-            del p._thunder_fsdp_unsharded_grad
+            del orig._thunder_fsdp_unsharded_grad
     else:
         import warnings
 
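
For context on the code path this patch touches: _sync_grads averages each FSDP parameter's unsharded gradient across ranks with reduce_scatter_tensor(..., op=ReduceOp.AVG) and keeps only the local shard, which it then attaches as the parameter's .grad. The sketch below illustrates that pattern in isolation with plain torch.distributed; it is not Thunder's implementation, and names such as demo_sync_grad and the torchrun/NCCL launch assumptions are illustrative only.

# Standalone sketch of the reduce-scatter gradient sync that _sync_grads
# performs for FSDP. Illustrative only; this is not Thunder's code.
import os

import torch
import torch.distributed as dist


def demo_sync_grad(
    sharded_param: torch.Tensor,
    unsharded_grad: torch.Tensor,
    group: dist.ProcessGroup | None = None,
) -> None:
    """Average `unsharded_grad` across ranks and keep only this rank's shard."""
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    # Same sharding rule as prep_shard in the patch: split dim 0 into
    # world_size equal chunks and take the chunk owned by this rank.
    chunk_size = unsharded_grad.size(0) // world_size
    sharded_grad = unsharded_grad.narrow(0, rank * chunk_size, chunk_size).contiguous()

    # Average the full gradient across ranks while scattering it, writing this
    # rank's averaged chunk into `sharded_grad`.
    dist.reduce_scatter_tensor(sharded_grad, unsharded_grad, op=dist.ReduceOp.AVG, group=group)

    # Attach the synced shard, mirroring `p.grad = g` in the patched loop.
    sharded_param.grad = sharded_grad


if __name__ == "__main__":
    # Assumes one GPU per rank and launch via `torchrun --nproc-per-node=N demo.py`;
    # reduce_scatter_tensor with ReduceOp.AVG requires the NCCL backend.
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)
    world = dist.get_world_size()

    full_grad = torch.ones(4 * world, device=device)                # stand-in unsharded grad
    param = torch.zeros(4, device=device, requires_grad=True)       # this rank's shard
    demo_sync_grad(param, full_grad)
    print(f"rank {dist.get_rank()}: grad shard = {param.grad}")
    dist.destroy_process_group()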