Skip to content

Commit 1ce792f

Browse files
committed
modify bprop to _Function
1 parent 7e667f3 commit 1ce792f

File tree

4 files changed

+235
-273
lines changed

4 files changed

+235
-273
lines changed

mindone/diffusers/hooks/context_parallel.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -232,33 +232,32 @@ def post_construct(self, module, output):
232232
return output[0] if is_tensor else tuple(output)
233233

234234

235-
class AllGatherFunction(ms.nn.Cell):
236-
def __init__(self, dim, group):
237-
super().__init__()
238-
self.dim = dim
239-
self.group = group
240-
self.world_size = mint.distributed.get_world_size(group)
241-
self.rank = mint.distributed.get_rank(group)
235+
class AllGatherFunction(ms.common._Function):
236+
@staticmethod
237+
def forward(ctx, tensor, dim, group):
238+
ctx.dim = dim
239+
ctx.group = group
240+
ctx.world_size = mint.distributed.get_world_size(group)
241+
ctx.rank = mint.distributed.get_rank(group)
242242

243-
def construct(self, tensor):
244-
# return funcol.all_gather_tensor(tensor, dim, group=group)
245243
# mint.distributed.all_gather_into_tensor only support dim=0
246-
tensor_t = tensor.transpose(self.dim, 0) if self.dim != 0 else tensor
244+
tensor_t = tensor.transpose(dim, 0) if dim != 0 else tensor
247245

248246
out_shape = list(tensor_t.shape)
249-
out_shape[0] *= self.world_size
247+
out_shape[0] *= ctx.world_size
250248
output = mint.zeros(out_shape, dtype=tensor_t.dtype)
251249

252-
mint.distributed.all_gather_into_tensor(output, tensor_t.contiguous(), group=self.group)
250+
mint.distributed.all_gather_into_tensor(output, tensor_t.contiguous(), group=group)
253251

254-
if self.dim != 0:
255-
output = output.transpose(0, self.dim)
252+
if dim != 0:
253+
output = output.transpose(0, dim)
256254

257255
return output
258256

259-
def bprop(self, tensor, out, dout):
260-
grad_chunks = mint.chunk(dout, self.world_size, dim=self.dim)
261-
return (grad_chunks[self.rank],)
257+
@staticmethod
258+
def backward(ctx, grad_output):
259+
grad_chunks = mint.chunk(grad_output, ctx.world_size, dim=ctx.dim)
260+
return grad_chunks[ctx.rank], None, None
262261

263262

264263
class EquipartitionSharder:
@@ -278,7 +277,7 @@ def shard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor:
278277
@classmethod
279278
def unshard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor:
280279
tensor = tensor.contiguous()
281-
tensor = AllGatherFunction(dim, mesh)(tensor)
280+
tensor = AllGatherFunction.apply(tensor, dim, mesh)
282281
return tensor
283282

284283

0 commit comments

Comments (0)