The training script relies on FSDP's `MixedPrecisionPolicy` to handle dtypes.
But when data parallelism is not used (for example, when running on a single node with TP=8), that policy is never applied and training runs entirely in float32.
This is a bit unintuitive, especially when comparing against runs with DP enabled.
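To illustrate the asymmetry, here is a minimal sketch (hypothetical, not the repo's actual code): the bf16 policy is only ever consumed by `fully_shard()` on the data-parallel path, so a TP-only model is never cast and its parameters stay in float32.

```python
# Minimal sketch of the asymmetry (assumes torch >= 2.6, where FSDP2's
# MixedPrecisionPolicy is exposed under torch.distributed.fsdp; the helper
# below is hypothetical, not the repo's actual code).
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

def build_mp_policy() -> MixedPrecisionPolicy:
    # Only ever passed to fully_shard(..., mp_policy=...) on the DP/FSDP path.
    return MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

model = torch.nn.Linear(1024, 1024)
print(build_mp_policy())   # bf16 compute policy -- unused without FSDP wrapping
print(model.weight.dtype)  # torch.float32: nothing casts the TP-only model
```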
If I'm not mistaken, the default training script does not even call `torch.set_float32_matmul_precision()`, so it is currently missing out on the TF32 matmul speedups that call enables.
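For reference, a sketch of what that would look like (where exactly it belongs in the training script is an assumption; typically it goes near the top of the entry point, and it only pays off on Ampere-or-newer GPUs):

```python
import torch

# "high" lets float32 matmuls use TF32 tensor cores; "highest" keeps full fp32.
torch.set_float32_matmul_precision("high")

# Optional, equivalent knob for cuDNN convolutions (not needed for pure matmuls).
torch.backends.cudnn.allow_tf32 = True
```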
Do you agree that this should be changed? Thanks!