diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index fec2b23370..d1f2d414be 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -223,7 +223,7 @@ def transform_module( if p_orig.device.type != "meta": p_meta = torch.nn.Parameter(p.to(device="meta"), requires_grad=p.requires_grad) p_meta._thunder_device = p_orig.device - setattr(submodule, base_pn, p_meta) + submodule.register_parameter(base_pn, p_meta) else: p_orig._thunder_device = self.device