
Commit 6ab42aa

update

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Parent: 8ec1af5

File tree (3 files changed: +53 −5 lines)

  deepspeed/module_inject/auto_tp.py
  deepspeed/module_inject/layers.py
  deepspeed/module_inject/replace_module.py


deepspeed/module_inject/auto_tp.py

Lines changed: 6 additions & 1 deletion
@@ -382,7 +382,12 @@ def _replace(self, child, name, conv_linear_layer):
             if self.conv_linear_layer:
                 return Conv_LinearALlreduce(child, self.mp_group, name=name)
             elif name == "lm_head" or name == 'embed_out':
-                return LmHeadLinearAllreduce(child, self.mp_group)
+                if is_autotp_training_mode():
+                    # pass
+                    # return child
+                    return LinearLayer(child, self.mp_group, name=name, gather_output=True)
+                else:
+                    return LmHeadLinearAllreduce(child, self.mp_group)
 
             return LinearAllreduce(child, self.mp_group, name=name)
         else:
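With this change, in autoTP training mode the lm_head / embed_out module is wrapped in a column-parallel LinearLayer with gather_output=True instead of the row-parallel LmHeadLinearAllreduce. A minimal single-process sketch (not DeepSpeed code; tp_size, hidden and vocab are made-up sizes) of why splitting the weight along its output dimension and concatenating the per-shard results along the last dimension reproduces the full matmul:

import torch

torch.manual_seed(0)
tp_size = 2                           # pretend tensor-parallel degree
hidden, vocab = 8, 6
x = torch.randn(4, hidden)            # [tokens, hidden]
w = torch.randn(vocab, hidden)        # full lm_head weight, [vocab, hidden]

full = x @ w.t()                                           # reference logits
shards = torch.chunk(w, tp_size, dim=0)                    # column-parallel weight shards
gathered = torch.cat([x @ s.t() for s in shards], dim=-1)  # what gather_output=True reassembles
assert torch.allclose(full, gathered)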

deepspeed/module_inject/layers.py

Lines changed: 44 additions & 1 deletion
@@ -109,6 +109,39 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         dist.all_reduce(grad_output.contiguous(), group=ctx.group)
         return None, grad_output
 
+class GatherTensor(torch.autograd.Function):
+    """Gather the input from the model parallel region and concatenate."""
+
+    # @staticmethod
+    # def symbolic(graph, input_):
+    #     """Symbolic function for tracing."""
+    #     return _gather_along_last_dim(input_)
+
+    @staticmethod
+    def forward(ctx, group, input_):
+        """Forward function."""
+        # gather along the last dim
+        world_size = dist.get_world_size(group)
+        if world_size == 1:
+            return
+        ctx.group = group
+        ctx.world_size = world_size
+
+        gather_shape = (world_size, ) + input_.shape
+        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name())
+        dist.all_gather_into_tensor(output, input_.contiguous(), group)
+        tensor_list = output.chunk(world_size, dim=0)
+        output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # split along the last dim
+        """Backward function."""
+        rank = dist.get_rank(ctx.group)
+        input_list = torch.chunk(grad_output, ctx.world_size, -1)
+        grad_output = input_list[rank].contiguous()
+        return None, grad_output
 
 class TensorParallel_Layer(nn.Module, ABC):
     """
@@ -394,23 +427,31 @@ def uneven_partition(self, params_list):
 #remove kwargs from partition.
 class LinearLayer(TensorParallel_Layer):
 
-    def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
+    def __init__(self, module, mp_group=None, skip_partition=False, gather_output=False, **kwargs):
         super(LinearLayer, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
+        if gather_output:
+            b = 0
         if not skip_partition:
             self._tp_partition([self.weight, self.bias])
         self.support_training = True
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
+        self.gather_output = gather_output
+
 
     def forward(self, input):
         if getattr(self, 'mp_group', None) is not None:
             input = ColumnParallel.apply(self.mp_group, input)
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
+
+        if self.gather_output:
+            output = GatherTensor.apply(self.mp_group, output)
+
         return output
 
     @torch.no_grad()
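The GatherTensor.apply call added to LinearLayer.forward changes only the output shape; gradients still land on the shard that produced each slice of the logits, which is why GatherTensor.backward can simply keep this rank's chunk of grad_output. A single-process sketch with two emulated ranks:

import torch

shard0 = torch.randn(2, 3, requires_grad=True)    # "rank 0" local output
shard1 = torch.randn(2, 3, requires_grad=True)    # "rank 1" local output
out = torch.cat([shard0, shard1], dim=-1)         # emulates the gathered output
grad = torch.randn_like(out)
out.backward(grad)

assert torch.equal(shard0.grad, grad[:, :3])      # rank 0's slice of grad_output
assert torch.equal(shard1.grad, grad[:, 3:])      # rank 1's slice of grad_output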
@@ -598,6 +639,8 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
+        input = input[:, :, input_shard_offset:input_shard_offset + input_shard_size]
+
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
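For reference, LmHeadLinearAllreduce stays row-parallel: each rank multiplies its slice of the hidden dimension by the matching slice of the weight, and the partial products are summed by the later all-reduce. A single-process sketch with even shard sizes (the real code uses get_shard_size / get_shard_size_list, which may produce uneven shards):

import torch

tp_size, hidden, vocab = 2, 8, 6
x = torch.randn(1, 3, hidden)           # [batch, seq, hidden]
w = torch.randn(vocab, hidden)          # full lm_head weight

full = torch.matmul(x, w.transpose(-1, -2))
shard = hidden // tp_size               # even shards for simplicity
partials = [
    torch.matmul(x[:, :, r * shard:(r + 1) * shard],
                 w[:, r * shard:(r + 1) * shard].transpose(-1, -2))
    for r in range(tp_size)
]
assert torch.allclose(full, sum(partials), atol=1e-5)   # the sum is what all_reduce computes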

deepspeed/module_inject/replace_module.py

Lines changed: 3 additions & 3 deletions
@@ -335,9 +335,9 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module
 
 def set_lm_head(module):
-    if is_autotp_training_mode():
-        # we need to handle autoTP training mode separately.
-        return
+    # if is_autotp_training_mode():
+    #     # we need to handle autoTP training mode separately.
+    #     return
 
     embedding_weight = None
     for n, p in module.named_parameters():
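The rest of set_lm_head lies outside this hunk; the visible lines are the start of the usual weight-tying flow, which locates the input embedding's weight so it can be reused for lm_head when the model ties embeddings. The sketch below only illustrates that general pattern and is not the DeepSpeed implementation; the tie_lm_head name and the parameter-name suffixes are assumptions:

import torch.nn as nn

def tie_lm_head(model: nn.Module) -> None:
    # Hypothetical sketch: find the input embedding's weight and reuse it for lm_head.
    embedding_weight = None
    for name, param in model.named_parameters():
        if name.endswith("embed_tokens.weight") or name.endswith("wte.weight"):
            embedding_weight = param
    if embedding_weight is not None and hasattr(model, "lm_head"):
        model.lm_head.weight = embedding_weight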
