From 858bf1882c3bb99024efe7476df654265d931689 Mon Sep 17 00:00:00 2001
From: Jackmin801
Date: Wed, 9 Oct 2024 09:31:39 +0800
Subject: [PATCH] have flat backing for param and grad

---
 src/zeroband/diloco.py | 45 +++++++++++++++++++++---------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py
index 2dae8a55..5080d0d2 100644
--- a/src/zeroband/diloco.py
+++ b/src/zeroband/diloco.py
@@ -117,36 +117,37 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         Offload the model parameters to cpu
         """
         numels = sum(param.numel() for param in model.parameters() if param.requires_grad)
+        self.offloaded_data_flat_tensor = torch.empty((numels,), device="cpu", dtype=torch.float32)
         self.offloaded_grad_flat_tensor = torch.zeros((numels,), device="cpu", dtype=torch.float32)
         current_offset = 0
         offloaded_params = []
 
         for param in model.parameters():
-            if param.requires_grad:
-                # so here we copy the DTensor from gpu to cpu. The trick is that we need to recreate the DTensor with the correct
-                # cpu devise mesh, otherwise we have a cpu DTensor with a cuda device mesh which will fail to do any communication
-                offloaded_param = nn.Parameter(
-                    DTensor.from_local(
-                        param.data.to_local().detach().to("cpu"),
-                        device_mesh=self.elastic_device_mesh.cpu_local_mesh,
-                        placements=param.data.placements,
-                    )
-                )
-
-                grad_tensor = self.offloaded_grad_flat_tensor.as_strided(
-                    offloaded_param.to_local().size(),
-                    offloaded_param.to_local().stride(),
-                    current_offset,
-                )
-                current_offset += grad_tensor.numel()
-                offloaded_param.grad = DTensor.from_local(
-                    grad_tensor,
+            if not param.requires_grad:
+                continue
+            # so here we copy the DTensor from gpu to cpu. The trick is that we need to recreate the DTensor with the correct
+            # cpu device mesh, otherwise we have a cpu DTensor with a cuda device mesh which will fail to do any communication
+            target = param.data.to_local().detach()
+            data_tensor = self.offloaded_data_flat_tensor.as_strided(target.size(), target.stride(), current_offset)
+            grad_tensor = self.offloaded_grad_flat_tensor.as_strided(target.size(), target.stride(), current_offset)
+            current_offset += data_tensor.numel()
+
+            offloaded_param = nn.Parameter(
+                DTensor.from_local(
+                    data_tensor,
                     device_mesh=self.elastic_device_mesh.cpu_local_mesh,
                     placements=param.data.placements,
                 )
-                # here we pre-allocate the grad DTensor on cpu.
-                offloaded_param.requires_grad = True
-                offloaded_params.append(offloaded_param)
+            )
+
+            offloaded_param.grad = DTensor.from_local(
+                grad_tensor,
+                device_mesh=self.elastic_device_mesh.cpu_local_mesh,
+                placements=param.data.placements,
+            )
+            # here we pre-allocate the grad DTensor on cpu.
+            offloaded_param.requires_grad = True
+            offloaded_params.append(offloaded_param)
 
         return offloaded_params
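
---

For reviewers, a minimal self-contained sketch of the flat-backing technique this patch introduces, written with plain CPU tensors instead of DTensor and with the device-mesh handling left out. The function name make_flat_backed_copies and the explicit data_view.copy_(...) initialization are illustrative assumptions, not code from this repository; like the patch, the offset arithmetic assumes every trainable parameter is contiguous.

    import torch
    import torch.nn as nn


    def make_flat_backed_copies(model: nn.Module) -> list[nn.Parameter]:
        # One flat cpu buffer for parameter data and one for grads; every
        # offloaded parameter below is an as_strided view into these buffers.
        numels = sum(p.numel() for p in model.parameters() if p.requires_grad)
        data_flat = torch.empty(numels, device="cpu", dtype=torch.float32)
        grad_flat = torch.zeros(numels, device="cpu", dtype=torch.float32)

        offset = 0
        offloaded = []
        for p in model.parameters():
            if not p.requires_grad:
                continue
            # Carve out views with the source tensor's size and stride, starting
            # at the running offset; writing through a view writes the buffer.
            data_view = data_flat.as_strided(p.size(), p.stride(), offset)
            grad_view = grad_flat.as_strided(p.size(), p.stride(), offset)
            data_view.copy_(p.data.detach().to("cpu"))  # initialize the cpu copy (assumption, not in this patch)
            offset += data_view.numel()

            cpu_param = nn.Parameter(data_view)
            cpu_param.grad = grad_view  # pre-allocated grad, backed by grad_flat
            offloaded.append(cpu_param)
        return offloaded

The payoff is that downstream code can treat the whole parameter or gradient set as one contiguous tensor (data_flat / grad_flat), e.g. a single copy or collective, instead of iterating over per-parameter tensors.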