diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 49974c276a..637c881a2c 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -453,6 +453,7 @@ def setup_distributed(self, model):
             )
         elif self.distributed_mode == "fsdp2":
             # reference: https://github.com/pytorch/torchtitan/blob/6e7a183/docs/fsdp.md
+            from functools import partial
             from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 
             if self.bucketing_mode != "none":
@@ -467,25 +468,25 @@ def setup_distributed(self, model):
 
             reshard_after_forward: bool = self.shard_mode == "zero3"
 
+            _apply_fully_shard = partial(
+                fully_shard,
+                mesh=mesh,
+                reshard_after_forward=reshard_after_forward,
+                mp_policy=MixedPrecisionPolicy(
+                    param_dtype=torch.bfloat16,
+                    reduce_dtype=torch.bfloat16,
+                ),
+            )
+
+            # for transformer_block in model.transformer.modules():
             for transformer_block in model.modules():
                 if isinstance(transformer_block, Block):
-                    fully_shard(
-                        transformer_block,
-                        mesh=mesh,
-                        reshard_after_forward=reshard_after_forward,
-                        mp_policy=MixedPrecisionPolicy(
-                            param_dtype=torch.bfloat16,
-                            reduce_dtype=torch.bfloat16,
-                        ),
-                    )
+                    _apply_fully_shard(transformer_block)
 
-            fully_shard(
-                model,
-                mesh=mesh,
-                reshard_after_forward=reshard_after_forward,
-                mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
-            )
+            _apply_fully_shard(model.lm_head)
+            _apply_fully_shard(model.transformer["wte"])
+            _apply_fully_shard(model.transformer["ln_f"])
+            _apply_fully_shard(model)
 
             model.to_empty(device=self.device)
             model.apply(model._init_weights)
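
For reference, a minimal sketch (not taken from the patch itself) of the per-module sharding pattern the new code uses: bind the shared fully_shard arguments once with functools.partial, apply it to every transformer block, then to the remaining large leaf modules, and finally to the root. TinyBlock, TinyModel, and shard_model are hypothetical stand-ins for litgpt's GPT; the sketch assumes a recent PyTorch with the composable FSDP2 API and a process group launched via torchrun.

from functools import partial

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy


class TinyBlock(nn.Module):
    """Stand-in for a litgpt transformer Block."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class TinyModel(nn.Module):
    """Stand-in for the litgpt GPT model (wte / blocks / ln_f / lm_head)."""

    def __init__(self, dim: int = 64, n_blocks: int = 2, vocab: int = 128) -> None:
        super().__init__()
        self.wte = nn.Embedding(vocab, dim)
        self.blocks = nn.ModuleList(TinyBlock(dim) for _ in range(n_blocks))
        self.ln_f = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab)


def shard_model(model: TinyModel, world_size: int, reshard_after_forward: bool = True) -> TinyModel:
    # One 1-D mesh over all ranks; requires torch.distributed to be initialized (e.g. via torchrun).
    mesh = init_device_mesh("cuda", (world_size,))

    # Bind the arguments shared by every fully_shard call exactly once.
    apply_fully_shard = partial(
        fully_shard,
        mesh=mesh,
        reshard_after_forward=reshard_after_forward,
        mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
    )

    # Each transformer block becomes its own FSDP unit ...
    for block in model.blocks:
        apply_fully_shard(block)

    # ... followed by the large leaf modules and, last, the root module,
    # which picks up any parameters not already covered by an inner unit.
    apply_fully_shard(model.lm_head)
    apply_fully_shard(model.wte)
    apply_fully_shard(model.ln_f)
    apply_fully_shard(model)
    return model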