From e0fd16caa7da66e6f25c0c6a4a2f51d62195d132 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Fri, 16 Aug 2024 05:33:18 -0700
Subject: [PATCH] use bfloat16 in grad reduction (= reduce-scatter / all-reduce)

ref: https://github.com/pytorch/pytorch/blob/762b1b4/torch/distributed/_composable/fsdp/_fsdp_api.py#L9

Signed-off-by: Masaki Kozuki
---
 thunder/benchmarks/benchmark_litgpt.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 41c5dabaa2..2f4ead75b1 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -319,14 +319,17 @@ def setup_distributed(self, model):
                     transformer_block,
                     mesh=mesh,
                     reshard_after_forward=reshard_after_forward,
-                    mp_policy=MixedPrecisionPolicy(),
+                    mp_policy=MixedPrecisionPolicy(
+                        param_dtype=torch.bfloat16,
+                        reduce_dtype=torch.bfloat16,
+                    ),
                 )
 
             fully_shard(
                 model,
                 mesh=mesh,
                 reshard_after_forward=reshard_after_forward,
-                mp_policy=MixedPrecisionPolicy(),
+                mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
             )
             model.to_empty(device=self.device)
             model.apply(model._init_weights)
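
Note (illustrative, not part of the patch): below is a minimal sketch of how FSDP2's fully_shard and MixedPrecisionPolicy fit together when both the parameter dtype and the gradient-reduction dtype are set to bfloat16, as this patch does. It assumes torch >= 2.4 with torch.distributed._composable.fsdp available, an already-initialized process group, and a litgpt-style model exposing model.transformer.h; the helper name shard_with_bf16_reduction and the one-dimensional mesh shape are made up for illustration.

# Illustrative sketch only: FSDP2's MixedPrecisionPolicy controls the dtype
# used for parameter all-gathers (param_dtype) and for gradient
# reduce-scatter / all-reduce (reduce_dtype).
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


def shard_with_bf16_reduction(model: nn.Module, world_size: int) -> nn.Module:
    # Hypothetical helper; a flat 1-D mesh over all ranks is assumed.
    mesh = init_device_mesh("cuda", (world_size,))
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # parameters are all-gathered and used in bf16
        reduce_dtype=torch.bfloat16,  # gradients are reduce-scattered / all-reduced in bf16
    )
    # Shard each transformer block first, then the root module, mirroring the
    # order used in benchmark_litgpt.py's setup_distributed above.
    for block in model.transformer.h:
        fully_shard(block, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)
    return model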