Commit d41de80

revert shm offload

1 parent: e64eb2d

src/zeroband/diloco.py

Lines changed: 17 additions & 47 deletions
@@ -1,24 +1,17 @@
-import shutil
-import os
 from pydantic_config import BaseConfig
 import torch
 from torch import nn
-from zeroband.utils import get_module_signature
 from zeroband.utils.world_info import get_world_info
 from zeroband.utils.logging import get_logger
 from torch.distributed.fsdp import ShardingStrategy
 import torch.distributed as dist
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 
 class DilocoConfig(BaseConfig):
     outer_lr: float = 0.7
     inner_steps: int
 
 
-SHARED_MEMORY_PATH = "/dev/shm/zeroband"
-
-
 class Diloco:
     """
     This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852.
@@ -64,26 +57,25 @@ def __init__(
 
         self._init_offloaded_optimizer(model=model)
 
-    def _init_offloaded_optimizer(self, model: nn.Module):
-        with FSDP.summon_full_params(model):
-            self.param_list_cpu = self.get_offloaded_param(model)
-            self.outer_optimizer = torch.optim.SGD(
-                self.param_list_cpu, lr=self.config.outer_lr, momentum=0.9, nesterov=True
-            )
-            self._logger.debug("offload model to cpu")
+    def _init_offloaded_optimizer(self, model):
+        self.param_list_cpu = self.get_offloaded_param(model)
+        self.outer_optimizer = torch.optim.SGD(
+            self.param_list_cpu, lr=self.config.outer_lr, momentum=0.9, nesterov=True
+        )
+        self._logger.debug("offload model to cpu")
 
     def sync_pseudo_gradient(self, model: nn.Module):
         """
         Sync the pseudo gradient from the local process group to the global process group
         """
         self._logger.debug("sync pseudo gradient")
-        # TODO: This assumes all params require grad, which is used by the offload
         for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
             param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
 
             # gloo does not support AVG
             param_offloaded.grad = param_offloaded.grad / self.global_pg.size()
             dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg)
+            # todo async here
 
     def sync_inner_model(self, model: nn.Module):
         """
@@ -92,51 +84,29 @@ def sync_inner_model(self, model: nn.Module):
 
         self._logger.debug("sync inner model")
         for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
-            param.data.copy_(param_offloaded.data)
+            param.data.copy_(param_offloaded.data)  # todo: use copy_ here
 
     def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         """
         Offload the model parameters to cpu
         """
-        # The change here makes processes which are part of the same FSDP replica group (which are assumed to be on the same node with the same /dev/shm) use the same underlying storage for the offloaded_param.
-        # All the processes use the same shared memory file to create a storage for each parameter. Only rank 0 will do the copy.
-        # A barrier is added to ensure that after the function completes, the parameters are all offloaded. Otherwise, the non 0 ranks might access uninitialized memory.
         offloaded_params = []
-        os.makedirs(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", exist_ok=True)
 
-        for param_name, param in model.named_parameters():
+        for param in model.parameters():
             if param.requires_grad:
-                storage = torch.UntypedStorage.from_file(
-                    f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}/{param_name}",
-                    True,
-                    param.data.untyped_storage().size(),
-                )
-                offloaded_param = torch.tensor(storage, dtype=param.dtype, device="cpu")
-                offloaded_param.as_strided_(size=param.data.size(), stride=param.data.stride())
-                if self.world_info.local_rank == 0:
-                    # TODO: Can we async or split the copy among gpus probs overkill?
-                    offloaded_param.copy_(param.data)
-                offloaded_param.requires_grad = False  # TODO: check if we need to set this to True
+                offloaded_param = param.data.detach().clone().to("cpu")
+                offloaded_param.requires_grad = True
                 offloaded_params.append(offloaded_param)
 
-        dist.barrier()
         return offloaded_params
 
     def step(self, model: nn.Module):
         """
         Step the optimizer
         """
-        with FSDP.summon_full_params(model):
-            self._logger.debug("Pre diloco step %s", get_module_signature(model))
-            if self.world_info.rank == 0:
-                self.sync_pseudo_gradient(model)
-            if self.outer_optimizer is not None:
-                self.outer_optimizer.step()
-                self.outer_optimizer.zero_grad()  # todo(sami): check if we can remove this
-
-            dist.barrier()
-            self.sync_inner_model(model)
-            self._logger.debug("Post meow diloco step %s", get_module_signature(model))
-
-    def __del__(self):
-        shutil.rmtree(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", ignore_errors=True)
+        self.sync_pseudo_gradient(model)
+        if self.outer_optimizer is not None:
+            self.outer_optimizer.step()
+            self.outer_optimizer.zero_grad()  # todo(sami): check if we can remove this
+
+        self.sync_inner_model(model)
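Taken together, the revert leaves the plain DiLoCo loop: clone the weights to CPU once, run inner steps locally, form the pseudo gradient, take an outer SGD-with-Nesterov step on the CPU copy, and copy the result back into the model. A toy single-process sketch of that flow, with no FSDP and no process group; the model, hyperparameters, and variable names below are illustrative only:

import torch
from torch import nn

model = nn.Linear(8, 8)
inner_opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# offload: plain per-rank CPU copies, as in get_offloaded_param after the revert
params_cpu = [p.data.detach().clone().to("cpu") for p in model.parameters() if p.requires_grad]
for p in params_cpu:
    p.requires_grad = True
outer_opt = torch.optim.SGD(params_cpu, lr=0.7, momentum=0.9, nesterov=True)

# a few inner steps on the local model
for _ in range(4):
    loss = model(torch.randn(2, 8)).pow(2).mean()
    loss.backward()
    inner_opt.step()
    inner_opt.zero_grad()

# outer step: pseudo gradient, outer SGD on the CPU copy, then copy back
for p_cpu, p in zip(params_cpu, model.parameters()):
    p_cpu.grad = p_cpu.data - p.data.to("cpu")
outer_opt.step()
outer_opt.zero_grad()
for p_cpu, p in zip(params_cpu, model.parameters()):
    p.data.copy_(p_cpu.data)

Keeping the offloaded weights as ordinary per-rank CPU clones trades the shared-memory savings of the reverted version for simplicity, which is what this commit opts for.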
