Skip to content

Commit

Permalink
add reduce dtype fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Oct 7, 2024
1 parent c7605c8 commit 9f10e05
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class TrainConfig(BaseConfig):
torch_compile: bool = True
sharding_strategy: str = "SHARD_GRAD_OP"
ac_ckpt: bool | int = False

reduce_fp32: bool = False # Should be True on SXM; kept False by default for backward compatibility.

log_model_hash: bool = False

memory_monitor: bool = False
Expand Down Expand Up @@ -151,7 +154,9 @@ def train(config: Config):

elastic_device_mesh = ElasticDeviceMesh("nccl")

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
)

for layer_id, transformer_block in model.layers.items():
reshard_after_forward = int(layer_id) < len(model.layers) - 1
Expand Down

0 comments on commit 9f10e05

Please sign in to comment.