[benchmark_litgpt] apply fully_shard to wte and ln_f (#1156)
crcrpar authored Sep 26, 2024
1 parent 090af7c commit 7cfdf9c
Showing 1 changed file with 16 additions and 15 deletions.
thunder/benchmarks/benchmark_litgpt.py
@@ -453,6 +453,7 @@ def setup_distributed(self, model):
             )
         elif self.distributed_mode == "fsdp2":
             # reference: https://github.com/pytorch/torchtitan/blob/6e7a183/docs/fsdp.md
+            from functools import partial
             from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

             if self.bucketing_mode != "none":
@@ -467,25 +468,25 @@ def setup_distributed(self, model):

             reshard_after_forward: bool = self.shard_mode == "zero3"

+            _apply_fully_shard = partial(
+                fully_shard,
+                mesh=mesh,
+                reshard_after_forward=reshard_after_forward,
+                mp_policy=MixedPrecisionPolicy(
+                    param_dtype=torch.bfloat16,
+                    reduce_dtype=torch.bfloat16,
+                ),
+            )
+
             # for transformer_block in model.transformer.modules():
             for transformer_block in model.modules():
                 if isinstance(transformer_block, Block):
-                    fully_shard(
-                        transformer_block,
-                        mesh=mesh,
-                        reshard_after_forward=reshard_after_forward,
-                        mp_policy=MixedPrecisionPolicy(
-                            param_dtype=torch.bfloat16,
-                            reduce_dtype=torch.bfloat16,
-                        ),
-                    )
+                    _apply_fully_shard(transformer_block)

-            fully_shard(
-                model,
-                mesh=mesh,
-                reshard_after_forward=reshard_after_forward,
-                mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
-            )
+            _apply_fully_shard(model.lm_head)
+            _apply_fully_shard(model.transformer["wte"])
+            _apply_fully_shard(model.transformer["ln_f"])
+            _apply_fully_shard(model)
             model.to_empty(device=self.device)
             model.apply(model._init_weights)

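The net effect of the commit is easier to read as one block than as a diff: every fully_shard call now shares a single bf16 MixedPrecisionPolicy via functools.partial, and the token embedding (wte), final norm (ln_f), and lm_head are sharded explicitly instead of being left replicated under the root call. Below is a minimal sketch of that pattern; the wrapper function name shard_with_fsdp2 and its parameters are illustrative, and names such as mesh, Block, shard_mode, and litgpt's GPT layout (transformer["wte"], transformer["ln_f"], lm_head, _init_weights) come from the surrounding benchmark and litgpt, not from this diff.

    # A minimal sketch, assuming torch.distributed is already initialized,
    # `mesh` is a DeviceMesh, `block_cls` is litgpt's Block class, and `model`
    # is a litgpt GPT constructed on the meta device. Not the exact benchmark code.
    from functools import partial

    import torch
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


    def shard_with_fsdp2(model, mesh, device, shard_mode="zero3", block_cls=None):
        reshard_after_forward = shard_mode == "zero3"

        # One partial so every module is sharded with the same mesh,
        # resharding behavior, and bf16 mixed-precision policy.
        _apply_fully_shard = partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.bfloat16,
            ),
        )

        # Shard each transformer block as its own FSDP unit.
        for module in model.modules():
            if block_cls is not None and isinstance(module, block_cls):
                _apply_fully_shard(module)

        # Also shard the output head, token embedding, and final norm, so their
        # parameters are not replicated on every rank; the root call wraps the rest.
        _apply_fully_shard(model.lm_head)
        _apply_fully_shard(model.transformer["wte"])
        _apply_fully_shard(model.transformer["ln_f"])
        _apply_fully_shard(model)

        # Materialize meta-device parameters on the target device, then
        # re-initialize the weights.
        model.to_empty(device=device)
        model.apply(model._init_weights)
        return model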