From 9ad3327a6023f52bd25cac3c6f109bc5242c2de4 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Mon, 9 Dec 2024 02:28:55 -0800
Subject: [PATCH] enable `force_recompute_fp8_weight_in_bwd`

ref: https://github.com/pytorch/ao/commit/e9195580013eab3195f982598d63ce27ad8bdc93

Signed-off-by: Masaki Kozuki
---
 thunder/benchmarks/benchmark_litgpt.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 4d6a7feb6b..d1559729d1 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -170,6 +170,7 @@ def __post_init__(self) -> None:
                 cast_config_grad_output=CastConfig(ScalingType.DYNAMIC),
                 enable_fsdp_float8_all_gather=self.use_fp8_allgather and self.is_fsdp2,
                 enable_pre_and_post_forward=False,
+                force_recompute_fp8_weight_in_bwd=self.is_fsdp2,
             )
             self.precompute_scale = (
                 self.is_fsdp2 and self.use_fp8_allgather and self.use_torchao_fp8_precompute_float8_dynamic_scale_for_fsdp