diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 12e90038e0..b11e19d2da 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -351,6 +351,7 @@ def init_model(self): init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device with init_device: model = GPT(self.config) + # TODO(crcrpar): Remove this guard once https://github.com/pytorch/ao/pull/713 is merged if ( self.distributed_mode == "fsdp2" and self._torchao_fp8_handler._enabled