use bfloat16 in grad reduction (= reduce-scatter / all-reduce)
crcrpar committed Aug 16, 2024
1 parent 43b6345 commit e0fd16c
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -319,14 +319,17 @@ def setup_distributed(self, model):
                 transformer_block,
                 mesh=mesh,
                 reshard_after_forward=reshard_after_forward,
-                mp_policy=MixedPrecisionPolicy(),
+                mp_policy=MixedPrecisionPolicy(
+                    param_dtype=torch.bfloat16,
+                    reduce_dtype=torch.bfloat16,
+                ),
             )
 
         fully_shard(
             model,
             mesh=mesh,
             reshard_after_forward=reshard_after_forward,
-            mp_policy=MixedPrecisionPolicy(),
+            mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
         )
         model.to_empty(device=self.device)
         model.apply(model._init_weights)
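For readers less familiar with FSDP2, below is a minimal, self-contained sketch (not part of this commit) of the pattern the diff configures: MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16) makes fully_shard all-gather parameters in bfloat16 for compute and reduce-scatter / all-reduce gradients in bfloat16 as well. The toy two-layer model, the script name, and the torchrun launch are illustrative assumptions, not taken from the repository; the import path is the PyTorch 2.4-era torch.distributed._composable.fsdp.

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh


def main():
    # Assumes a distributed launch, e.g. `torchrun --nproc_per_node=2 sketch.py`.
    torch.distributed.init_process_group("nccl")
    torch.cuda.set_device(torch.distributed.get_rank())
    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

    # Hypothetical stand-in for the benchmark's transformer blocks.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

    # Same policy as the commit: bf16 compute and bf16 gradient reduction.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
    )
    # Shard each block first, then the root, mirroring the diff's structure.
    for block in model:
        fully_shard(block, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)

    out = model(torch.randn(8, 1024, device="cuda"))
    out.sum().backward()  # gradients are reduce-scattered in bfloat16 here

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

With the default MixedPrecisionPolicy(), both dtypes are None, so no casting is applied and gradients are reduced in the parameters' own dtype; pinning reduce_dtype to bfloat16 makes the reduction precision explicit regardless of how the parameters are stored.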
