@@ -109,6 +109,39 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         dist.all_reduce(grad_output.contiguous(), group=ctx.group)
         return None, grad_output
 
+class GatherTensor(torch.autograd.Function):
+    """Gather the input from the model parallel region and concatenate."""
+
+    # @staticmethod
+    # def symbolic(graph, input_):
+    #     """Symbolic function for tracing."""
+    #     return _gather_along_last_dim(input_)
+
+    @staticmethod
+    def forward(ctx, group, input_):
+        """Forward function."""
+        # gather along the last dim
+        world_size = dist.get_world_size(group)
+        ctx.group = group
+        ctx.world_size = world_size
+        if world_size == 1:
+            return input_
+
+        gather_shape = (world_size,) + input_.shape
+        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name())
+        dist.all_gather_into_tensor(output, input_.contiguous(), group)
+        tensor_list = output.chunk(world_size, dim=0)
+        output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Backward function."""
+        # split along the last dim
+        rank = dist.get_rank(ctx.group)
+        input_list = torch.chunk(grad_output, ctx.world_size, -1)
+        grad_output = input_list[rank].contiguous()
+        return None, grad_output
 
 class TensorParallel_Layer(nn.Module, ABC):
     """
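
For orientation, a minimal usage sketch of the new GatherTensor function (not part of the patch). It assumes torch.distributed is already initialized (e.g. under torchrun), that mp_group is the tensor-parallel process group, and that GatherTensor is importable from the patched module; the helper name and shapes are illustrative only.

import torch
import torch.distributed as dist
from deepspeed.accelerator import get_accelerator

def gather_demo(mp_group):
    # Each rank holds a (batch, seq, hidden // world_size) activation shard.
    world_size = dist.get_world_size(mp_group)
    local = torch.randn(4, 16, 512, requires_grad=True,
                        device=get_accelerator().current_device_name())
    # Forward: all-gather the shards and concatenate along the last dim,
    # giving shape (4, 16, 512 * world_size) on every rank.
    full = GatherTensor.apply(mp_group, local)
    assert full.shape[-1] == 512 * world_size
    # Backward: each rank keeps only its own last-dim slice of the gradient.
    full.sum().backward()
    assert local.grad.shape == local.shape
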
@@ -394,23 +427,31 @@ def uneven_partition(self, params_list):
 # remove kwargs from partition.
 class LinearLayer(TensorParallel_Layer):
 
-    def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
+    def __init__(self, module, mp_group=None, skip_partition=False, gather_output=False, **kwargs):
         super(LinearLayer, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
         if not skip_partition:
             self._tp_partition([self.weight, self.bias])
         self.support_training = True
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
+        self.gather_output = gather_output
 
     def forward(self, input):
         if getattr(self, 'mp_group', None) is not None:
             input = ColumnParallel.apply(self.mp_group, input)
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
+
+        if self.gather_output:
+            output = GatherTensor.apply(self.mp_group, output)
+
         return output
 
     @torch.no_grad()
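
A hypothetical sketch of the new gather_output flag in use. The helper name, group handle mp_group, and sizes are illustrative, and it assumes LinearLayer shards the output features (as the column-parallel forward above suggests); with gather_output=True the per-rank output shards are concatenated back so every rank sees the full feature dimension.

import torch
import torch.nn as nn
from deepspeed.accelerator import get_accelerator

def linear_demo(mp_group):
    # Wrap a dense layer; each rank keeps its shard of the 4096 output features.
    dense = nn.Linear(1024, 4096, bias=True).to(get_accelerator().current_device_name())
    tp_linear = LinearLayer(dense, mp_group=mp_group, gather_output=True)
    x = torch.randn(2, 8, 1024, device=get_accelerator().current_device_name())
    # With gather_output=True the shards are gathered along the last dim,
    # so y has all 4096 features on every rank instead of 4096 // world_size.
    y = tp_linear(x)
    return y
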
@@ -598,6 +639,8 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
-        output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
-                              self.weight.transpose(-1, -2))
+        input = input[:, :, input_shard_offset:input_shard_offset + input_shard_size]
+
+        output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.mp_group is not None:
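
For reference, a rough worked example of the shard arithmetic above, assuming an even split of the hidden dimension; the concrete numbers are illustrative only.

# Worked example: hidden = 4096, tp_world_size = 4, tp_index = 1 (rank 1).
hidden, tp_world_size, tp_index = 4096, 4, 1
input_shard_size = hidden // tp_world_size                    # 1024
input_shard_offset = input_shard_size * tp_index              # 1024
# Rank 1 therefore reads input[:, :, 1024:2048] and multiplies it by its own
# 1024-wide weight shard; the partial logits from all ranks then need to be
# summed across the group (row-parallel matmul), presumably inside the
# `self.mp_group` branch above.
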