have flat backing for param and grad
Jackmin801 committed Oct 9, 2024
1 parent 42698ef commit 858bf18
Showing 1 changed file with 23 additions and 22 deletions.
src/zeroband/diloco.py: 23 additions & 22 deletions
@@ -117,36 +117,37 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         Offload the model parameters to cpu
         """
         numels = sum(param.numel() for param in model.parameters() if param.requires_grad)
+        self.offloaded_data_flat_tensor = torch.empty((numels,), device="cpu", dtype=torch.float32)
         self.offloaded_grad_flat_tensor = torch.zeros((numels,), device="cpu", dtype=torch.float32)
         current_offset = 0
         offloaded_params = []
 
         for param in model.parameters():
-            if param.requires_grad:
-                # so here we copy the DTensor from gpu to cpu. The trick is that we need to recreate the DTensor with the correct
-                # cpu devise mesh, otherwise we have a cpu DTensor with a cuda device mesh which will fail to do any communication
-                offloaded_param = nn.Parameter(
-                    DTensor.from_local(
-                        param.data.to_local().detach().to("cpu"),
-                        device_mesh=self.elastic_device_mesh.cpu_local_mesh,
-                        placements=param.data.placements,
-                    )
-                )
-
-                grad_tensor = self.offloaded_grad_flat_tensor.as_strided(
-                    offloaded_param.to_local().size(),
-                    offloaded_param.to_local().stride(),
-                    current_offset,
-                )
-                current_offset += grad_tensor.numel()
-                offloaded_param.grad = DTensor.from_local(
-                    grad_tensor,
+            if not param.requires_grad:
+                continue
+            # so here we copy the DTensor from gpu to cpu. The trick is that we need to recreate the DTensor with the correct
+            # cpu devise mesh, otherwise we have a cpu DTensor with a cuda device mesh which will fail to do any communication
+            target = param.data.to_local().detach()
+            data_tensor = self.offloaded_data_flat_tensor.as_strided(target.size(), target.stride(), current_offset)
+            grad_tensor = self.offloaded_grad_flat_tensor.as_strided(target.size(), target.stride(), current_offset)
+            current_offset += data_tensor.numel()
+
+            offloaded_param = nn.Parameter(
+                DTensor.from_local(
+                    data_tensor,
                     device_mesh=self.elastic_device_mesh.cpu_local_mesh,
                     placements=param.data.placements,
                 )
-                # here we pre-allocate the grad DTensor on cpu.
-                offloaded_param.requires_grad = True
-                offloaded_params.append(offloaded_param)
+            )
+
+            offloaded_param.grad = DTensor.from_local(
+                grad_tensor,
+                device_mesh=self.elastic_device_mesh.cpu_local_mesh,
+                placements=param.data.placements,
+            )
+            # here we pre-allocate the grad DTensor on cpu.
+            offloaded_param.requires_grad = True
+            offloaded_params.append(offloaded_param)
 
         return offloaded_params
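To make the change above concrete: after this commit, every offloaded parameter and its gradient are as_strided views into two flat CPU tensors, so the offloaded copy of the model lives in two contiguous allocations instead of one small tensor per parameter. Below is a minimal sketch of that flat-backing idea, not the repository's code: it uses a made-up toy model and plain CPU tensors, and deliberately omits the DTensor / device-mesh wrapping shown in the diff so that it runs without any distributed setup.

import torch
import torch.nn as nn

# Hypothetical toy model standing in for the real sharded model.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# One flat CPU buffer for all parameter data, one for all gradients.
numels = sum(p.numel() for p in model.parameters() if p.requires_grad)
data_flat = torch.empty((numels,), device="cpu", dtype=torch.float32)
grad_flat = torch.zeros((numels,), device="cpu", dtype=torch.float32)

current_offset = 0
offloaded_params = []
for p in model.parameters():
    if not p.requires_grad:
        continue
    # Carve out views with this parameter's shape/stride from each flat buffer.
    data_view = data_flat.as_strided(p.size(), p.stride(), current_offset)
    grad_view = grad_flat.as_strided(p.size(), p.stride(), current_offset)
    current_offset += data_view.numel()

    data_view.copy_(p.detach().to("cpu"))  # populate the offloaded copy
    cpu_param = nn.Parameter(data_view)    # shares storage with data_flat
    cpu_param.grad = grad_view             # pre-allocated grad, backed by grad_flat
    offloaded_params.append(cpu_param)

# Because every view aliases the flat buffers, all offloaded params/grads can be
# touched as one contiguous tensor, e.g. zeroing every .grad in one call:
grad_flat.zero_()

A likely motivation for this layout is that anything operating on the offloaded copy (copying parameters back and forth, or reducing the pseudo-gradient) can act on a single contiguous tensor rather than looping over per-parameter tensors; the commit itself does not state this, so treat it as an assumption.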
