diff --git a/src/zeroband/train.py b/src/zeroband/train.py index b99ff09b..8e97ebef 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -152,7 +152,7 @@ def train(config: Config): num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt apply_ac_ckpt(model, num) - elastic_device_mesh = ElasticDeviceMesh("nccl") + elastic_device_mesh = ElasticDeviceMesh() mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None