@@ -109,23 +109,23 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         dist.all_reduce(grad_output.contiguous(), group=ctx.group)
         return None, grad_output
 
+
 class GatherTensor(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""
 
-
     @staticmethod
     def forward(ctx, group, input_):
         """Forward function."""
         # gather along last dim
-        world_size = dist.get_world_size(group)
-        if world_size == 1:
-            return
-        ctx.group = group
-        ctx.world_size = world_size
-
-        gather_shape = (world_size,) + input_.shape
-        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name() )
-        dist.all_gather_into_tensor(output, input_.contiguous(), group)
+        world_size = dist.get_world_size(group)
+        if world_size == 1:
+            return
+        ctx.group = group
+        ctx.world_size = world_size
+
+        gather_shape = (world_size, ) + input_.shape
+        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name())
+        dist.all_gather_into_tensor(output, input_.contiguous(), group)
         tensor_list = output.chunk(world_size, dim=0)
         output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
         return output
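For context, the forward pass above stacks every rank's shard into one (world_size, ...) buffer with all_gather_into_tensor and then re-joins the chunks along the last dimension. A minimal standalone sketch of that pattern, assuming torch.distributed is already initialized and using the input's own device rather than DeepSpeed's get_accelerator() helper (the gather_last_dim name is ours, not from the PR):

```python
import torch
import torch.distributed as dist


def gather_last_dim(input_: torch.Tensor, group=None) -> torch.Tensor:
    """Gather each rank's shard and concatenate them along the last dim."""
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return input_
    # Stack all shards into a (world_size, *input_.shape) buffer ...
    gather_shape = (world_size, ) + input_.shape
    output = torch.empty(gather_shape, dtype=input_.dtype, device=input_.device)
    dist.all_gather_into_tensor(output, input_.contiguous(), group)
    # ... then split along dim 0 and re-join the pieces on the last dim.
    tensor_list = output.chunk(world_size, dim=0)
    return torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
```

In the autograd.Function above, backward reverses this by chunking the incoming gradient along the last dimension and keeping only this rank's slice (the tail of that method appears in the next hunk), which is presumably why ctx.group and ctx.world_size are saved in forward.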
@@ -139,6 +139,7 @@ def backward(ctx, grad_output):
         grad_output = input_list[rank].contiguous()
         return None, grad_output
 
+
 class TensorParallel_Layer(nn.Module, ABC):
     """
     A base class for model layers with tensor parallelism support.
@@ -434,19 +435,18 @@ def __init__(self, module, mp_group=None, skip_partition=False, gather_output=Fa
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
-        self.gather_output = gather_output
-
+        self.gather_output = gather_output
 
     def forward(self, input):
         if getattr(self, 'mp_group', None) is not None:
             input = ColumnParallel.apply(self.mp_group, input)
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
-
+
         if self.gather_output:
-            output = GatherTensor.apply(self.mp_group ,output)
-
+            output = GatherTensor.apply(self.mp_group, output)
+
         return output
 
     @torch.no_grad()
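The gather_output flag above only changes the shape of what the caller sees: with the default False, the column-parallel output stays sharded on the last dimension; with True, GatherTensor concatenates the per-rank shards back to the full output width. A shape walk-through under an assumed even split (all sizes below are hypothetical, not from the PR):

```python
# Hypothetical sizes for illustration only.
world_size = 4
batch, seq, in_features, out_features = 2, 8, 1024, 4096

local_out = out_features // world_size            # 1024 output features per rank
local_weight_shape = (local_out, in_features)     # each rank's weight shard
local_output_shape = (batch, seq, local_out)      # gather_output=False: stays sharded
full_output_shape = (batch, seq, out_features)    # gather_output=True: GatherTensor
                                                  # concatenates shards on the last dim
print(local_weight_shape, local_output_shape, full_output_shape)
```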
@@ -634,7 +634,7 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
-
+
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
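Here each rank multiplies only its slice of the hidden dimension against its local lm_head weight shard, producing a partial product that the mp_group branch (cut off above) is presumably left to reduce across ranks. A sketch of the offset arithmetic with a simplified, evenly-split stand-in for DeepSpeed's get_shard_size / get_shard_size_list helpers (even_shard_sizes and the sizes below are our assumptions):

```python
from typing import List


def even_shard_sizes(dim: int, tp_world_size: int) -> List[int]:
    """Hypothetical stand-in for get_shard_size_list: split `dim` as evenly as possible."""
    base, rem = divmod(dim, tp_world_size)
    return [base + (1 if r < rem else 0) for r in range(tp_world_size)]


hidden_size = 4096                  # assumed hidden size
tp_world_size, tp_index = 4, 2      # assumed 4-way tensor parallelism, rank 2

sizes = even_shard_sizes(hidden_size, tp_world_size)
input_shard_size = sizes[tp_index]             # 1024
input_shard_offset = sum(sizes[0:tp_index])    # 2048
# This rank would slice input[..., 2048:3072] and matmul it against its local
# weight shard; the partial products from all ranks are then summed together.
print(input_shard_offset, input_shard_size)
```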