Skip to content

Commit

Permalink
enable force_recompute_fp8_weight_in_bwd
Browse files Browse the repository at this point in the history
ref: pytorch/ao@e919558
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Dec 17, 2024
1 parent f7f4f3b commit 9ad3327
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions thunder/benchmarks/benchmark_litgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __post_init__(self) -> None:
cast_config_grad_output=CastConfig(ScalingType.DYNAMIC),
enable_fsdp_float8_all_gather=self.use_fp8_allgather and self.is_fsdp2,
enable_pre_and_post_forward=False,
force_recompute_fp8_weight_in_bwd=self.is_fsdp2,
)
self.precompute_scale = (
self.is_fsdp2 and self.use_fp8_allgather and self.use_torchao_fp8_precompute_float8_dynamic_scale_for_fsdp
Expand Down

0 comments on commit 9ad3327

Please sign in to comment.