diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index f17cfa883cc6..e809fe1118b5 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -12,48 +12,76 @@ from deepspeed.accelerator import get_accelerator -def post_all2all(transpose, res_shape): +def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim): def post_func(input): - if transpose: - input = input.transpose(0, 2).contiguous() - input = input.reshape(res_shape) - return input + if batch_dim_idx == 0: + # b, s, n, h + if scatter_idx < 2: + output = input.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head, + head_dim).contiguous() + else: + output = input.permute(1, 0, 2, 3, 4).contiguous() + output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size, + head_dim).contiguous() + else: + # s, b, n, h + if scatter_idx < 2: + output = input.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head, + head_dim).contiguous() + else: + output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous() + return output return post_func -def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None): +def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) - inp_shape = list(input.shape) - inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size + if batch_dim_idx == 0: + # b, s, n, h + if scatter_idx < 2: + bs, global_seq_len, num_local_head, head_dim = input.shape + input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, + head_dim]).contiguous() + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() + else: + bs, local_seq_len, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" + input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, + head_dim]).contiguous() + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + else: + # s, b, n, h + if scatter_idx < 2: + global_seq_len, bs, num_local_head, head_dim = input.shape + input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, + head_dim]).contiguous() + else: + local_seq_len, bs, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" + input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, + head_dim]).contiguous() + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + if scatter_idx < 2: - input_t = input.reshape( - [seq_world_size, inp_shape[scatter_idx]] + \ - inp_shape[scatter_idx + 1:] - ).contiguous() + post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head, + head_dim) else: - # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! - input_t = input.reshape( - [-1, seq_world_size, inp_shape[scatter_idx]] + \ - inp_shape[scatter_idx + 1:] - ).transpose(0, 1).contiguous() + post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head, + head_dim) output = torch.empty_like(input_t) work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) - res_shape=( inp_shape[: gather_idx] + \ - [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]) - transpose = True if scatter_idx < 2 else False - post_all2all_fun = post_all2all(transpose, res_shape) - if async_op: if type in ('dq', 'dk'): handle[type + '_work'] = work handle[type + '_grad'] = output handle[type + '_post_all2all_func'] = post_all2all_fun - return output.view(res_shape) + return output res = post_all2all_fun(output) return res @@ -67,6 +95,7 @@ def forward(ctx: Any, input: Tensor, scatter_idx: int, gather_idx: int, + batch_dim_idx: int, stream=None, handle=None, type=None, @@ -77,14 +106,15 @@ def forward(ctx: Any, ctx.stream = stream ctx.handle = handle ctx.type = type + ctx.batch_dim_idx = batch_dim_idx if ctx.handle is None: - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) else: # overlap communication path if not is_fwd and type == 'o': assert ctx.stream != None - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) get_accelerator().current_stream().wait_stream(ctx.stream) del ctx.stream.activation_buffer_list # The computation of d o_weight can overlap with the communication of d o_input @@ -92,15 +122,15 @@ def forward(ctx: Any, elif not is_fwd and type in ('q', 'k'): # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv type = 'd' + type - res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type) elif is_fwd and type in ('q', 'k'): # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v type = 'fwd_' + type - res = single_all_to_all(input, scatter_idx, gather_idx, group, False, handle, type) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type) else: - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) return res @@ -108,8 +138,8 @@ def forward(ctx: Any, def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: return (None, - _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream, ctx.handle, - ctx.type, False), None, None, None, None, None, None) + _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx, + ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None) class DistributedAttention(torch.nn.Module): @@ -148,13 +178,14 @@ def layer_sync(self, layer): if self.sp_overlap_comm and hasattr(layer, 'done_event'): self.dafult_stream.wait_event(layer.done_event) - def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor: + def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, *args: Any, **kwargs) -> Tensor: """ forward Arguments: query (Tensor): query input to the layer key (Tensor): key input to the layer value (Tensor): value input to the layer + batch_dim_idx (int): indicating which dim is batch args: other args Returns: @@ -179,15 +210,15 @@ def pre_hook_fun(grad): return pre_hook_fun self.layer_sync(query) - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, None, + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'q') self.layer_sync(key) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, None, self.overlap_handles, - 'k') + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None, + self.overlap_handles, 'k') if self.sp_overlap_comm: self.dafult_stream.wait_stream(self.sp_stream) - value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None, + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'v') if self.sp_overlap_comm: @@ -205,8 +236,8 @@ def pre_hook_fun(grad): context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream, - self.overlap_handles, 'o') + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx, + self.sp_stream, self.overlap_handles, 'o') #out e.g., [s/p::h] return output