@@ -232,33 +232,32 @@ def post_construct(self, module, output):
         return output[0] if is_tensor else tuple(output)


-class AllGatherFunction(ms.nn.Cell):
-    def __init__(self, dim, group):
-        super().__init__()
-        self.dim = dim
-        self.group = group
-        self.world_size = mint.distributed.get_world_size(group)
-        self.rank = mint.distributed.get_rank(group)
+class AllGatherFunction(ms.common._Function):
+    @staticmethod
+    def forward(ctx, tensor, dim, group):
+        ctx.dim = dim
+        ctx.group = group
+        ctx.world_size = mint.distributed.get_world_size(group)
+        ctx.rank = mint.distributed.get_rank(group)

-    def construct(self, tensor):
-        # return funcol.all_gather_tensor(tensor, dim, group=group)
         # mint.distributed.all_gather_into_tensor only supports dim=0
-        tensor_t = tensor.transpose(self.dim, 0) if self.dim != 0 else tensor
+        tensor_t = tensor.transpose(dim, 0) if dim != 0 else tensor

         out_shape = list(tensor_t.shape)
-        out_shape[0] *= self.world_size
+        out_shape[0] *= ctx.world_size
         output = mint.zeros(out_shape, dtype=tensor_t.dtype)

-        mint.distributed.all_gather_into_tensor(output, tensor_t.contiguous(), group=self.group)
+        mint.distributed.all_gather_into_tensor(output, tensor_t.contiguous(), group=group)

-        if self.dim != 0:
-            output = output.transpose(0, self.dim)
+        if dim != 0:
+            output = output.transpose(0, dim)

         return output

-    def bprop(self, tensor, out, dout):
-        grad_chunks = mint.chunk(dout, self.world_size, dim=self.dim)
-        return (grad_chunks[self.rank],)
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_chunks = mint.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+        return grad_chunks[ctx.rank], None, None


 class EquipartitionSharder:
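
Note on the hunk above: the rewrite moves from a `Cell` with a `bprop` to a `torch.autograd.Function`-style pair of static methods, so `backward` must return one gradient per `forward` input, hence the trailing `None, None` for the non-tensor arguments `dim` and `group`. The gradient of an all-gather along `dim` is simply each rank's own slice of the upstream gradient. A minimal single-process sketch of that slicing (hypothetical shapes, with `world_size` and `rank` hard-coded so no process group is needed):

```python
import mindspore as ms
from mindspore import mint

# Pretend we are rank 1 of 4 and the forward all-gather ran along dim=1.
world_size, rank, dim = 4, 1, 1

# Upstream gradient has the *gathered* shape: 4 shards of width 2 along dim 1.
grad_output = mint.arange(0, 16, dtype=ms.float32).reshape(2, 8)

# backward keeps only this rank's slice, mirroring the mint.chunk call above.
grad_chunks = mint.chunk(grad_output, world_size, dim=dim)
grad_input = grad_chunks[rank]
print(grad_input.shape)  # (2, 2): the shape of the local shard given to forward
```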
@@ -278,7 +277,7 @@ def shard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor:
     @classmethod
     def unshard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor:
         tensor = tensor.contiguous()
-        tensor = AllGatherFunction(dim, mesh)(tensor)
+        tensor = AllGatherFunction.apply(tensor, dim, mesh)
         return tensor

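
Call-site note: with a `_Function`, the op is invoked through `.apply(...)` instead of instantiating a `Cell`, so `dim` and `mesh` become `forward` arguments. A hedged usage sketch (hypothetical shard shape; assumes the script runs under a MindSpore distributed launcher such as `msrun`, and that passing `None` as the mesh selects the default group):

```python
import mindspore as ms
from mindspore import mint

mint.distributed.init_process_group()  # default communication group
rank = mint.distributed.get_rank()

# Each rank holds one equal shard; unshard reassembles the full tensor.
local_shard = mint.ones((2, 4), dtype=ms.float32) * rank
full = EquipartitionSharder.unshard(local_shard, dim=0, mesh=None)
# full has shape (2 * world_size, 4); on the backward pass,
# AllGatherFunction.backward routes only this rank's slice of the
# gradient back to local_shard.
```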