Commit

free orig parameters after sharding
t-vi committed Aug 19, 2024
1 parent 98f4692 commit f05768e
Showing 3 changed files with 39 additions and 0 deletions.
1 change: 1 addition & 0 deletions thunder/distributed/__init__.py
@@ -482,6 +482,7 @@ def fsdp(
broadcast_from=broadcast_from,
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
release_original_parameters=True,
),
)
return new_model
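With this change, fsdp() constructs the fsdp_v2 sharding transform with release_original_parameters=True, so the full, unsharded parameters can be dropped once each rank holds its shard. A minimal usage sketch, assuming a launcher such as torchrun has already initialized a multi-GPU process group, and using a hypothetical ToyModel in place of a real network (both call orders below mirror the new test further down):

    import torch
    import torch.distributed as dist

    import thunder
    from thunder.distributed import fsdp


    class ToyModel(torch.nn.Module):
        # Hypothetical stand-in; any torch.nn.Module works the same way.
        N_IN = 16

        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(self.N_IN, self.N_IN)

        def forward(self, x):
            return self.fc(x)


    # Assumes torchrun (or an equivalent launcher) has already initialized the
    # default process group, one rank per GPU.
    device = torch.device("cuda", dist.get_rank())

    # Sharding an already-jitted module: the transform added here is now built
    # with release_original_parameters=True.
    fsdp_jit_model = fsdp(thunder.jit(ToyModel()), device=device)

    # The reverse composition shards the plain module before jitting; the new
    # test below checks that both orders end up with comparable GPU memory.
    jit_fsdp_model = thunder.jit(fsdp(ToyModel(), device=device))

    y = fsdp_jit_model(torch.randn(2, ToyModel.N_IN, device=device))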
11 changes: 11 additions & 0 deletions thunder/distributed/transforms/fsdp_v2.py
@@ -88,11 +88,13 @@ def __init__(
broadcast_from: int | None = None,
sharding_strategy: FSDPType = FSDPType.ZERO2,
bucketing_strategy: FSDPBucketingStrategy = FSDPBucketingStrategy.NONE,
release_original_parameters: bool = False,
):
self.device = device
self.broadcast_from = broadcast_from
self.sharding_strategy = sharding_strategy
self.bucketing_strategy = bucketing_strategy
self.release_original_parameters = release_original_parameters
self.sharded_params: dict[str, Any] = {}
self.process_group: ProcessGroup | None = None
self.shared_params_name: dict[str, str] = {}
@@ -215,6 +217,15 @@ def transform_module(
)
new_shape = thunder_model._overrides_parameters[pn].shape
self.sharded_params[pn] = (old_shape, new_shape, thunder_model._overrides_parameters[pn].device)
if self.release_original_parameters:
base_pn = pn.rsplit(".", 1)[-1]
p_orig = getattr(submodule, base_pn)
if p_orig.device.type != "meta":
p_meta = torch.nn.Parameter(p.to(device="meta"), requires_grad=p.requires_grad)
p_meta._thunder_device = p_orig.device
setattr(submodule, base_pn, p_meta)
else:
p_orig._thunder_device = self.device

# Track the original param and its corresponding copied shard and metadata.
shared_params[p] = {
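The heart of the change sits in transform_module: after a parameter's shard has been copied into thunder_model._overrides_parameters, the full parameter still held by the original submodule is replaced with a meta-device torch.nn.Parameter, and the device it came from is stashed on a _thunder_device attribute; a parameter that already lived on the meta device merely gets _thunder_device set to the sharding target. A standalone sketch of that swap, with the hypothetical helper release_original_parameter standing in for the inlined logic above:

    import torch


    def release_original_parameter(module: torch.nn.Module, name: str, target_device: torch.device) -> None:
        # Minimal sketch of the release step outside thunder: `module` holds
        # the full (unsharded) parameter under attribute `name`.
        p_orig = getattr(module, name)
        if p_orig.device.type != "meta":
            # Swap in a meta-device placeholder so the full storage can be
            # freed, and remember where the parameter used to live.
            p_meta = torch.nn.Parameter(p_orig.to(device="meta"), requires_grad=p_orig.requires_grad)
            p_meta._thunder_device = p_orig.device
            setattr(module, name, p_meta)
        else:
            # Parameter was created on the meta device; just record where the
            # shard is meant to be materialized.
            p_orig._thunder_device = target_device


    # Usage: free a Linear layer's weight after its shard has been copied out.
    lin = torch.nn.Linear(4, 4)
    shard = lin.weight.detach().clone()  # stand-in for the real per-rank shard
    release_original_parameter(lin, "weight", torch.device("cpu"))
    assert lin.weight.device.type == "meta"
    assert lin.weight._thunder_device == torch.device("cpu")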
27 changes: 27 additions & 0 deletions thunder/tests/distributed/test_fsdp.py
@@ -638,6 +638,33 @@ def _test_model_output_and_gradients(model, x, duplicate_all_gather):

_test_model_output_and_gradients(fsdp_jit_model, x, duplicate_all_gather=False)

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
@common_utils.parametrize("model_device", ["cuda", "meta"])
def test_memory_consumption(self, model_device):
import gc

device = torch.device("cuda", self.rank)
with device:
x_1 = torch.randn((2, ToyModel.N_IN))
with torch.device(model_device):
model = ToyModel()
jit_fsdp_model = thunder.jit(fsdp(model, device=device))
y_1 = jit_fsdp_model(x_1)
active_mem_jit_fsdp = torch.cuda.memory_stats()["active_bytes.all.current"]

del x_1, y_1, jit_fsdp_model, model
gc.collect()
torch.cuda.empty_cache()

with device:
x_2 = torch.randn((2, ToyModel.N_IN))
with torch.device(model_device):
model = ToyModel()
fsdp_jit_model = fsdp(thunder.jit(model), device=device)
y_2 = fsdp_jit_model(x_2)
active_mem_fsdp_jit = torch.cuda.memory_stats()["active_bytes.all.current"]
self.assertAlmostEqual(active_mem_fsdp_jit, active_mem_jit_fsdp)


common_utils.instantiate_parametrized_tests(FSDPTest)

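The new test builds the model on either "cuda" or "meta", runs the jit(fsdp(...)) and fsdp(jit(...)) compositions back to back, and asserts that their active CUDA memory is roughly equal, which only holds if the second composition no longer keeps the full parameters alive. A quick way to eyeball the release behavior itself, assuming the same 2-rank setup and the ToyModel sketched above (not part of the commit):

    with torch.device("cuda"):
        model = ToyModel()

    fsdp_jit_model = fsdp(thunder.jit(model), device=device)

    for name, p in model.named_parameters():
        # Expected after this commit: each original parameter reports device
        # "meta", with _thunder_device recording where it used to live.
        print(name, p.device, getattr(p, "_thunder_device", None))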
