
Commit 486a4d1

merge main

2 parents d264f75 + c15f74d

5 files changed: +39 -9 lines

src/zeroband/comms.py
Lines changed: 11 additions & 0 deletions

@@ -273,6 +273,11 @@ def _resolve_world(self):

     def maybe_reinit_global_pg(self):
         """Reinitialize the global_pg if there are joiners or dead nodes."""
+
+        if self.world_info.global_world_size == 1:
+            # no op if we only have one node
+            return
+
         time_start = time.perf_counter()
         self._logger.debug("Resolving world")

@@ -334,6 +339,12 @@ def maybe_reinit_global_pg(self):

         self.live_recovery.init_background_loop()

+    def get_global_pg(self, maybe_reinit: bool = False) -> dist.ProcessGroup:
+        """Get the global process group. If maybe_reinit is True, reinitialize the global process group if needed."""
+        if maybe_reinit:
+            self.maybe_reinit_global_pg()
+        return self.global_pg
+

 class LiveRecoveryModel(BaseModel):
     dest_rank: int
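
The new `get_global_pg` accessor lets callers re-resolve the world (admitting joiners, dropping dead nodes) right before they use the group, instead of holding a possibly stale reference. A minimal sketch of the intended call pattern; the `average_across_nodes` helper and the `mesh` variable are illustrative, not part of the commit:

```python
import torch
import torch.distributed as dist

def average_across_nodes(mesh, tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical caller: `mesh` is an initialized ElasticDeviceMesh.
    # maybe_reinit=True re-resolves the world before the collective,
    # so the group reflects the current set of live nodes.
    global_pg = mesh.get_global_pg(maybe_reinit=True)
    tensor = tensor / global_pg.size()  # gloo has no AVG; divide, then SUM
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=global_pg)
    return tensor
```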

src/zeroband/diloco.py
Lines changed: 6 additions & 4 deletions

@@ -2,6 +2,7 @@
 from pydantic_config import BaseConfig
 import torch
 from torch import nn
+from zeroband.comms import ElasticDeviceMesh
 from zeroband.utils.world_info import get_world_info
 from zeroband.utils.logging import get_logger
 from torch.distributed.fsdp import ShardingStrategy

@@ -44,11 +45,11 @@ def __init__(
         config: DilocoConfig,
         model: nn.Module,
         fsdp_sharding_strategy: ShardingStrategy,
-        global_pg: dist.ProcessGroup,
+        elastic_device_mesh: ElasticDeviceMesh,
     ):
         self.config = config
         self.fsdp_sharding_strategy = fsdp_sharding_strategy
-        self.global_pg = global_pg
+        self.elastic_device_mesh = elastic_device_mesh

         self._logger = get_logger()
         self.world_info = get_world_info()

@@ -70,14 +71,15 @@ def sync_pseudo_gradient(self, model: nn.Module):
         Sync the pseudo gradient from the local process group to the global process group
         """
         self._logger.debug("sync pseudo gradient")
+        global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True)
         for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
             if param.shape[0] == 0:
                 continue
             param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)

             # gloo does not support AVG
-            param_offloaded.grad = param_offloaded.grad / self.global_pg.size()
-            dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg)
+            param_offloaded.grad = param_offloaded.grad / global_pg.size()
+            dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg)
             # todo async here

     def sync_inner_model(self, model: nn.Module):
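
Passing the whole mesh instead of a frozen `dist.ProcessGroup` means each outer synchronization can pick up nodes that joined or died during the inner steps. A sketch of the resulting training-loop shape; `num_outer_steps` and `inner_train_step` are illustrative names, not code from this commit:

```python
# Sketch only: `diloco`, `model`, and `diloco_config` are assumed to be
# constructed as in train.py below; the helpers here are hypothetical.
for outer_step in range(num_outer_steps):
    for _ in range(diloco_config.inner_steps):
        inner_train_step(model)  # local updates on this node's shard
    # sync_pseudo_gradient() calls get_global_pg(maybe_reinit=True) first,
    # so the all-reduce runs over the current set of live nodes.
    diloco.sync_pseudo_gradient(model)
```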

src/zeroband/train.py
Lines changed: 4 additions & 1 deletion

@@ -175,7 +175,7 @@ def train(config: Config):
     )

     if config.diloco is not None:
-        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)
+        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)

     scheduler = get_cosine_schedule_with_warmup(
         inner_optimizer,

@@ -364,6 +364,9 @@ def train(config: Config):
     metric_logger.finish()

     ckpt_manager.wait_async_save_process()
+
+    del elastic_device_mesh  # allow to clean up for smoother tests transition
+
     logger.info("Training finished, exiting ...")
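The explicit `del elastic_device_mesh` drops the last reference so the mesh can tear down its distributed state before the next test in the same session initializes new groups. This relies on the mesh cleaning up when garbage-collected, e.g. via a destructor along these lines (a hypothetical sketch, not code from this commit):

```python
import torch.distributed as dist

class ElasticDeviceMesh:
    # ... initialization elided ...

    def __del__(self):
        # Hypothetical teardown: destroy the global group so its ports
        # and stores are released before a new mesh is created.
        dist.destroy_process_group(self.global_pg)
```
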
tests/test_dist/test_comms.py
Lines changed: 10 additions & 2 deletions

@@ -21,6 +21,8 @@ def foo(**kwargs):
         dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
         assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))

+        del edm
+
     processes = []
     for rank in range(world_size):
         processes.append(

@@ -64,6 +66,8 @@ def foo(**kwargs):
         sum_ints = global_world_size * (global_world_size + 1) // 2
         assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]) + rank * global_world_size)

+        del edm
+
     global_ports = [i for i in range(21970, 21970 + world_size)]
     master_ports = [i for i in range(31000, 31000 + global_world_size)]
     processes = []

@@ -96,8 +100,8 @@ def foo(**kwargs):
             pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")


-@pytest.mark.parametrize("world_size", [1, 2, 8])
-@pytest.mark.parametrize("global_world_size", [2, 8])
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("global_world_size", [2, 4])
 def test_elastic_device_mesh_on_off_ramp(world_size: int, global_world_size: int, mock_env):
     ready_event = mp.Event()

@@ -136,6 +140,8 @@ def foo(**kwargs):

         dist.barrier(edm.global_pg)

+        del edm
+
     def bar(**kwargs):
         with mock_env(**kwargs):
             test_value = int(kwargs["TEST_VALUE"])

@@ -163,6 +169,8 @@ def bar(**kwargs):

         dist.barrier(edm.global_pg)

+        del edm
+
     global_ports = [i for i in range(21970, 21970 + world_size)]
     master_ports = [i for i in range(31000, 31000 + global_world_size + 1)]
     processes = []
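
For orientation, these tests spawn one OS process per rank and fail the test if any rank exits non-zero. A sketch of that launcher pattern, reconstructed from the fragments above; the `run_ranks` helper and the exact environment keys are assumptions:

```python
import multiprocessing as mp
import pytest

def run_ranks(target, world_size: int, common_kwargs: dict):
    # One process per rank, mirroring the tests' launcher fragments above.
    processes = [
        mp.Process(target=target, kwargs={**common_kwargs, "RANK": str(rank)})
        for rank in range(world_size)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
        if p.exitcode != 0:
            pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")
```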

tests/test_dist/test_diloco.py
Lines changed: 8 additions & 2 deletions

@@ -21,6 +21,13 @@ def test_diloco_all_reduce(world_size, random_available_port, dist_environment):
     if it is done correclty.
     """

+    class FakeElasticDeviceMesh:
+        def __init__(self):
+            self.global_pg = dist.new_group(backend="gloo")
+
+        def get_global_pg(self, maybe_reinit: bool = False) -> dist.ProcessGroup:
+            return self.global_pg
+
     def all_reduce(rank: int, world_size: int):
         with dist_environment(random_available_port, rank=rank, world_size=world_size, global_unique_id=str(rank)):
             diloco_config = DilocoConfig(inner_steps=10)

@@ -31,8 +38,7 @@ def all_reduce(rank: int, world_size: int):
             for param in model.parameters():
                 param.data = (rank + 1) * torch.ones_like(param.data).to("cuda")

-            global_pg = dist.new_group(backend="gloo")
-            diloco = Diloco(diloco_config, model, ShardingStrategy.FULL_SHARD, global_pg)
+            diloco = Diloco(diloco_config, model, ShardingStrategy.FULL_SHARD, FakeElasticDeviceMesh())

             # simulate inner model updates
             for param in model.parameters():
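
The fake mesh works by duck typing: `Diloco` only touches the one entry point `get_global_pg`, so the test double can return a plain static gloo group and simply ignore `maybe_reinit`, since a fixed group never needs re-resolution in a single-shot test. This keeps the all-reduce path under test without standing up a real elastic mesh.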
