implement FSDPTransform.transform_state_dict_for_submodule (#989)
t-vi authored Aug 20, 2024
1 parent 5cc3011 commit 42e79ba
Showing 3 changed files with 70 additions and 9 deletions.
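In short: the commit factors the core chunking logic out of `_shard_param` into a new `_shard_tensor` helper and uses it to implement `FSDPTransform.transform_state_dict_for_submodule`, so an unsharded ("original") state dict can be sharded on the fly when loaded into an fsdp-transformed module. Below is a minimal usage sketch mirroring the new test added in this commit; the `ToyModel` stand-in, the import path of `fsdp`, and the 2-rank launch are assumptions for illustration, not part of the commit.

```python
import torch
import torch.distributed
import thunder
from thunder.distributed import fsdp  # import path assumed; the new test calls fsdp(...) directly


class ToyModel(torch.nn.Module):  # hypothetical stand-in for the test suite's ToyModel
    N_IN = 16

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(self.N_IN, 8)

    def forward(self, x):
        return self.fc(x)


# Assumes torch.distributed is already initialized, e.g. via `torchrun --nproc-per-node=2 ...`
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank)
with torch.device("cuda"):
    model = ToyModel()

# Unsharded copy of the original parameters (this is what transform_state_dict_for_submodule consumes).
sd = {k: v.clone() for k, v in model.state_dict().items()}

jm = fsdp(thunder.jit(model), device=device)
# Each tensor whose fully qualified name is tracked as sharded gets narrowed to this rank's chunk.
jm.load_original_state_dict(sd)
y = jm(torch.randn((2, ToyModel.N_IN), device=device))
```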
37 changes: 30 additions & 7 deletions thunder/distributed/__init__.py
@@ -555,15 +555,16 @@ def _shard_params(
                 sharded_params[param] = True
 
 
-def _shard_param(
+@torch.no_grad()
+def _shard_tensor(
     param: torch.Tensor,
     rank: int,
     world_size: int,
     name: str,
     *,
     allow_padding_for_fsdp: bool = False,
     dim: int | None = None,
-) -> None:
+) -> tuple[torch.Tensor, int | None]:
 
     dim_to_shard = 0 if dim is None else dim
     if allow_padding_for_fsdp:
@@ -576,11 +577,11 @@ def _shard_param(
         if _thunder_fsdp_padding_size > 0:
             padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
             padded_param[:orig_0dim_size].copy_(param)
-            param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
-            param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size
+            shard = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
+            return shard, _thunder_fsdp_padding_size
         else:
-            param.data = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
-            param._thunder_fsdp_padding_size = None
+            shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
+            return shard, None
     else:
         utils.check(
             param.shape[dim_to_shard] % world_size == 0,
@@ -593,7 +594,29 @@ def _shard_param(
         # NOTE This could be a ShardTensor to indicate other parts of the code
         # that it's sharded and should be treated differently
         shard = param.data.narrow(dim_to_shard, chunk_size * rank, chunk_size).clone()
-        param.data = shard
+        return shard, None
+
+
+def _shard_param(
+    param: torch.Tensor,
+    rank: int,
+    world_size: int,
+    name: str,
+    *,
+    allow_padding_for_fsdp: bool = False,
+    dim: int | None = None,
+) -> None:
+    shard, padding_size = _shard_tensor(
+        param,
+        rank,
+        world_size,
+        name,
+        allow_padding_for_fsdp=allow_padding_for_fsdp,
+        dim=dim,
+    )
+    param.data = shard
+    if allow_padding_for_fsdp:
+        param._thunder_fsdp_padding_size = padding_size
 
 
 @torch.no_grad()
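For reference, here is a standalone sketch of the chunking arithmetic that the new `_shard_tensor` performs in the padded dim-0 case. It is plain PyTorch with no Thunder imports; the helper name and concrete shapes are illustrative, and the real helper also supports sharding along an arbitrary `dim` with an exact-divisibility check.

```python
import torch


def shard_dim0_with_padding(t: torch.Tensor, rank: int, world_size: int) -> tuple[torch.Tensor, int | None]:
    """Illustrative mirror of _shard_tensor's padded branch: pad dim 0 up to a
    multiple of world_size, then take this rank's contiguous chunk."""
    orig_0dim_size = t.shape[0]
    chunk_size = (orig_0dim_size + world_size - 1) // world_size
    padding = chunk_size * world_size - orig_0dim_size
    if padding > 0:
        padded = torch.empty((chunk_size * world_size, *t.shape[1:]), device=t.device, dtype=t.dtype)
        padded[:orig_0dim_size].copy_(t)
        return padded.narrow(0, chunk_size * rank, chunk_size).clone(), padding
    return t.narrow(0, chunk_size * rank, chunk_size).clone(), None


# A (5, 4) tensor sharded across 2 ranks: each rank holds a (3, 4) chunk, and a
# padding size of 1 is reported so the padding can be stripped again later.
w = torch.randn(5, 4)
shard, pad = shard_dim0_with_padding(w, rank=0, world_size=2)
assert shard.shape == (3, 4) and pad == 1
```

Note that `_shard_param` now simply calls `_shard_tensor` and attaches the resulting shard (and, when padding is allowed, `_thunder_fsdp_padding_size`) to the parameter in place, so existing callers are unaffected.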
22 changes: 20 additions & 2 deletions thunder/distributed/transforms/fsdp_v2.py
@@ -24,7 +24,14 @@
 from thunder.core.transforms import VISIT_TYPE
 from thunder.core.transforms import visitor_transform
 from thunder.core.transform_common import Transform
-from thunder.distributed import copy_default_process_group, FSDPType, FSDPBucketingStrategy, _shard_param, _materialize
+from thunder.distributed import (
+    copy_default_process_group,
+    FSDPType,
+    FSDPBucketingStrategy,
+    _materialize,
+    _shard_param,
+    _shard_tensor,
+)
 
 if TYPE_CHECKING:
     from typing import Any
@@ -111,6 +118,8 @@ def transform_module(
         utils.check(self.process_group is not None, lambda: "The default process group is None")
         global_rank = tdist.get_rank(group=self.process_group)
         world_size = tdist.get_world_size(group=self.process_group)
+        self.global_rank = global_rank
+        self.world_size = world_size
         if self.device is None:
             local_rank = int(os.environ["LOCAL_RANK"])
             self.device = torch.device("cuda", local_rank)
@@ -237,7 +246,16 @@ def transform_module(
     def transform_state_dict_for_submodule(
         self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
     ) -> dict:
-        raise NotImplementedError("cannot transform state dict yet")
+        prefix = ""
+        if submodule_name:
+            prefix = f"{submodule_name}."
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            full_k = prefix + k
+            if full_k in self.sharded_params:
+                v, _ = _shard_tensor(v, self.global_rank, self.world_size, full_k, allow_padding_for_fsdp=True)
+            new_state_dict[k] = v
+        return new_state_dict
 
     def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
         from thunder.distributed import prims as dist_prims
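One detail in the new `transform_state_dict_for_submodule` is key handling: the incoming state dict uses keys relative to the submodule, while the lookup above suggests `self.sharded_params` is keyed by fully qualified names, so the submodule name is re-prefixed before checking whether a tensor needs sharding. A tiny sketch of that key handling, with hypothetical parameter names and a plain set standing in for `sharded_params`:

```python
# Hypothetical: sharded_params tracks fully qualified parameter names.
sharded_params = {"layer1.weight", "layer1.bias"}


def needs_sharding(submodule_name: str, relative_key: str) -> bool:
    # Rebuild the fully qualified name the same way the transform does.
    prefix = f"{submodule_name}." if submodule_name else ""
    return (prefix + relative_key) in sharded_params


assert needs_sharding("layer1", "weight") is True
assert needs_sharding("", "layer1.bias") is True   # a root-level state dict already carries full keys
assert needs_sharding("layer2", "weight") is False
```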
20 changes: 20 additions & 0 deletions thunder/tests/distributed/test_fsdp.py
@@ -665,6 +665,26 @@ def test_memory_consumption(self, model_device):
         active_mem_fsdp_jit = torch.cuda.memory_stats()["active_bytes.all.current"]
         self.assertAlmostEqual(active_mem_fsdp_jit, active_mem_jit_fsdp)
 
+    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
+    def test_load_original_state_dict(self):
+        device = torch.device("cuda", self.rank)
+        with device:
+            x = torch.randn((2, ToyModel.N_IN))
+        with torch.device("cuda"):
+            model1 = ToyModel()
+            model2 = ToyModel()
+
+        sd = {k: v.clone() for k, v in model1.state_dict().items()}
+
+        jm1 = fsdp(thunder.jit(model1), device=device)
+        jm2 = fsdp(thunder.jit(model2), device=device)
+        jm2.load_original_state_dict(sd)
+
+        y_1 = jm1(x)
+        y_2 = jm2(x)
+
+        torch.testing.assert_close(y_1, y_2)
+
 
 common_utils.instantiate_parametrized_tests(FSDPTest)
 
