Skip to content

Commit 270bd95

Browse files
committed
init adamw inner before diloco
1 parent a143cab commit 270bd95

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/zeroband/train.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,6 @@ def train(config: Config):
124124
model = torch.compile(model)
125125
logger.debug("model compiled and fsdped")
126126

127-
if config.diloco is not None:
128-
if world_info.local_world_size == 1:
129-
raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
130-
131-
diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)
132-
133127
# Setup optimizers
134128
inner_optimizer = torch.optim.AdamW(
135129
model.parameters(),
@@ -138,6 +132,12 @@ def train(config: Config):
138132
betas=(config.optim.adam_betas1, config.optim.adam_betas2),
139133
)
140134

135+
if config.diloco is not None:
136+
if world_info.local_world_size == 1:
137+
raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
138+
139+
diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)
140+
141141
scheduler = get_cosine_schedule_with_warmup(
142142
inner_optimizer,
143143
num_warmup_steps=config.optim.warmup_steps,

0 commit comments

Comments
 (0)