implement FSDPTransform.transform_state_dict_for_submodule #989

Merged: 11 commits, Aug 20, 2024
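In short, this change lets an FSDP-transformed ThunderModule load an unsharded ("original") state dict: each parameter tracked by the FSDP transform is re-sharded for the local rank on the way in. Below is a minimal usage sketch adapted from the test added in this PR; the model class, process-group setup, and 2-GPU torchrun launch are assumptions for illustration, not part of the change:

```python
import os

import torch
import torch.distributed as dist

import thunder
from thunder.distributed import fsdp

# Assumes a torchrun launch on 2+ GPUs; MyModel is a placeholder for any torch.nn.Module.
dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

with torch.device("cuda"):
    model = MyModel()

# Keep an unsharded copy of the weights before FSDP shards the parameters.
unsharded_sd = {k: v.clone() for k, v in model.state_dict().items()}

jm = fsdp(thunder.jit(model), device=device)
# New in this PR: the unsharded state dict can be loaded directly; each parameter
# is re-sharded for the local rank via transform_state_dict_for_submodule.
jm.load_original_state_dict(unsharded_sd)
```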
37 changes: 30 additions & 7 deletions thunder/distributed/__init__.py
@@ -555,15 +555,16 @@ def _shard_params(
sharded_params[param] = True


def _shard_param(
@torch.no_grad()
def _shard_tensor(
param: torch.Tensor,
rank: int,
world_size: int,
name: str,
*,
allow_padding_for_fsdp: bool = False,
dim: int | None = None,
) -> None:
) -> tuple[torch.Tensor, int | None]:

dim_to_shard = 0 if dim is None else dim
if allow_padding_for_fsdp:
@@ -576,11 +577,11 @@ def _shard_param(
if _thunder_fsdp_padding_size > 0:
padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
padded_param[:orig_0dim_size].copy_(param)
param.data = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = _thunder_fsdp_padding_size
shard = padded_param.data.narrow(0, chunk_size * rank, chunk_size).clone()
return shard, _thunder_fsdp_padding_size
else:
param.data = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
param._thunder_fsdp_padding_size = None
shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
return shard, None
else:
utils.check(
param.shape[dim_to_shard] % world_size == 0,
@@ -593,7 +594,29 @@ def _shard_param(
# NOTE This could be a ShardTensor to indicate other parts of the code
# that it's sharded and should be treated differently
shard = param.data.narrow(dim_to_shard, chunk_size * rank, chunk_size).clone()
param.data = shard
return shard, None


def _shard_param(
param: torch.Tensor,
rank: int,
world_size: int,
name: str,
*,
allow_padding_for_fsdp: bool = False,
dim: int | None = None,
) -> None:
shard, padding_size = _shard_tensor(
param,
rank,
world_size,
name,
allow_padding_for_fsdp=allow_padding_for_fsdp,
dim=dim,
)
param.data = shard
if allow_padding_for_fsdp:
param._thunder_fsdp_padding_size = padding_size
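For readers skimming the diff, here is a self-contained sketch of the chunk-and-pad logic that `_shard_tensor` performs when `allow_padding_for_fsdp=True` (the chunk-size computation is collapsed in the hunk above, so this reconstructs only the visible behavior; `shard_with_padding` is a hypothetical name, not a Thunder API):

```python
import math

import torch


def shard_with_padding(t: torch.Tensor, rank: int, world_size: int) -> tuple[torch.Tensor, int | None]:
    """Return this rank's dim-0 chunk of ``t``, padding dim 0 so world_size divides it."""
    size0 = t.shape[0]
    chunk_size = math.ceil(size0 / world_size)    # rows per rank after padding
    padding = chunk_size * world_size - size0     # rows appended so the split is even
    if padding > 0:
        # As in the code above, only the original rows are copied; the padded tail
        # is left uninitialized.
        padded = torch.empty((chunk_size * world_size, *t.shape[1:]), dtype=t.dtype, device=t.device)
        padded[:size0].copy_(t)
        return padded.narrow(0, chunk_size * rank, chunk_size).clone(), padding
    return t.narrow(0, chunk_size * rank, chunk_size).clone(), None


# A (5, 4) weight sharded across world_size=2 yields two (3, 4) shards, with one
# row of padding recorded so the original size can be restored later.
w = torch.randn(5, 4)
shard0, pad = shard_with_padding(w, rank=0, world_size=2)
assert shard0.shape == (3, 4) and pad == 1
```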


@torch.no_grad()
22 changes: 20 additions & 2 deletions thunder/distributed/transforms/fsdp_v2.py
@@ -24,7 +24,14 @@
from thunder.core.transforms import VISIT_TYPE
from thunder.core.transforms import visitor_transform
from thunder.core.transform_common import Transform
from thunder.distributed import copy_default_process_group, FSDPType, FSDPBucketingStrategy, _shard_param, _materialize
from thunder.distributed import (
copy_default_process_group,
FSDPType,
FSDPBucketingStrategy,
_materialize,
_shard_param,
_shard_tensor,
)

if TYPE_CHECKING:
from typing import Any
@@ -111,6 +118,8 @@ def transform_module(
utils.check(self.process_group is not None, lambda: "The default process group is None")
global_rank = tdist.get_rank(group=self.process_group)
world_size = tdist.get_world_size(group=self.process_group)
self.global_rank = global_rank
self.world_size = world_size
if self.device is None:
local_rank = int(os.environ["LOCAL_RANK"])
self.device = torch.device("cuda", local_rank)
@@ -237,7 +246,16 @@ def transform_module(
def transform_state_dict_for_submodule(
self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
) -> dict:
raise NotImplementedError("cannot transform state dict yet")
prefix = ""
if submodule_name:
prefix = f"{submodule_name}."
new_state_dict = {}
for k, v in state_dict.items():
full_k = prefix + k
if full_k in self.sharded_params:
v, _ = _shard_tensor(v, self.global_rank, self.world_size, full_k, allow_padding_for_fsdp=True)
new_state_dict[k] = v
return new_state_dict
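A small, self-contained illustration of the key handling above (the names and values below are hypothetical stand-ins, and `shard_dim0` is a simplified substitute for `_shard_tensor`): keys arrive without the submodule prefix, are matched against the transform's sharding table by their full name, and only matching entries are sharded, while the returned dict keeps the original local keys.

```python
import torch

# Toy stand-ins for the transform's bookkeeping.
sharded_params = {"fc1.weight": True, "fc1.bias": True}
global_rank, world_size = 0, 2


def shard_dim0(t: torch.Tensor, rank: int, ws: int) -> torch.Tensor:
    # Simplified substitute for _shard_tensor: even split along dim 0, no padding.
    chunk = t.shape[0] // ws
    return t.narrow(0, chunk * rank, chunk).clone()


state_dict = {"weight": torch.randn(8, 4), "bias": torch.randn(8)}
submodule_name = "fc1"
prefix = f"{submodule_name}." if submodule_name else ""

new_state_dict = {}
for k, v in state_dict.items():
    full_k = prefix + k              # e.g. "fc1.weight"
    if full_k in sharded_params:     # only parameters the transform tracks are sharded
        v = shard_dim0(v, global_rank, world_size)
    new_state_dict[k] = v            # the returned dict keeps the local key

assert new_state_dict["weight"].shape == (4, 4)
assert new_state_dict["bias"].shape == (4,)
```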

def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
from thunder.distributed import prims as dist_prims
20 changes: 20 additions & 0 deletions thunder/tests/distributed/test_fsdp.py
@@ -665,6 +665,26 @@ def test_memory_consumption(self, model_device):
active_mem_fsdp_jit = torch.cuda.memory_stats()["active_bytes.all.current"]
self.assertAlmostEqual(active_mem_fsdp_jit, active_mem_jit_fsdp)

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
def test_load_original_state_dict(self):
device = torch.device("cuda", self.rank)
with device:
x = torch.randn((2, ToyModel.N_IN))
with torch.device("cuda"):
model1 = ToyModel()
model2 = ToyModel()

sd = {k: v.clone() for k, v in model1.state_dict().items()}

jm1 = fsdp(thunder.jit(model1), device=device)
jm2 = fsdp(thunder.jit(model2), device=device)
jm2.load_original_state_dict(sd)

y_1 = jm1(x)
y_2 = jm2(x)

torch.testing.assert_close(y_1, y_2)


common_utils.instantiate_parametrized_tests(FSDPTest)
