diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index 375e4a3c8885..26d664f16c54 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -33,7 +33,6 @@
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet
-from paddle.distributed.communication.reduce import ReduceOp
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -881,8 +880,8 @@ def __init__(self, config: DeepseekV2Config):
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=False)

-    def forward(self, hidden_states):
-        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
+    def forward(self, hidden_states, masked_tokens=None):
+        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states, masked_tokens=masked_tokens)
         if self.training and self.alpha > 0.0:
             l_aux = l_aux * self.alpha
             final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux)
@@ -1003,6 +1002,7 @@ def linear_dtype_gaurd():
         # fmt: on

         if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+
             def grad_allreduce_hook(param, accumulation_steps):
                 hcg = fleet.get_hybrid_communicate_group()
                 pg = hcg.get_model_parallel_group().process_group
@@ -1018,10 +1018,22 @@ def __impl__():
                         pg.allreduce(param.grad).wait()

                 return __impl__
+
             # kv_a_proj_with_mqa and q_a_proj grad need to be reduce between mp
-            self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps))
-            self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps))
-
+            # self.kv_a_proj_with_mqa.weight._register_backward_hook(
+            #     grad_allreduce_hook(
+            #         self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps
+            #     )
+            # )
+            # self.q_a_proj.weight._register_backward_hook(
+            #     grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps)
+            # )
+            mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.weight)
+            mark_as_sequence_parallel_parameter(self.q_a_proj.weight)
+            if config.attention_bias:
+                mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.bias)
+                mark_as_sequence_parallel_parameter(self.q_a_proj.bias)
+
         self._init_rope()

         self.softmax_scale = self.q_head_dim ** (-0.5)
@@ -1284,18 +1296,41 @@ def subbatch_recompute_forward(
         seq_len = hidden_states.shape[seq_axis]
         assert seq_len % sub_seq_len == 0
         num_chunks = seq_len // sub_seq_len
-        split_list = [sub_seq_len] * num_chunks
-        input_list = paddle.split(hidden_states, split_list, axis=seq_axis)
+        input_list = paddle.split(hidden_states, num_chunks, axis=seq_axis)
         output_list = []
-        for chunk in input_list:
-            out = recompute(
-                self.mlp.forward,
-                chunk,
-                **offload_kwargs,
-            )
-            output_list.append(out)
+        if isinstance(self.mlp, DeepseekV2MoEFlexToken):
+            if attn_mask_startend_row_indices is not None:
+                if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1:
+                    flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
+                    flat_mask = ScatterOp.apply(flat_mask)
+                flat_mask = paddle.flatten(flat_mask)
+                mask_list = paddle.split(flat_mask, num_chunks)
+            else:
+                mask_list = [None] * num_chunks
+
+            for chunk, mask_chunk in zip(input_list, mask_list):
+                masked_tokens = None
+                if mask_chunk is not None:
+                    masked_tokens = mask_chunk == 0
+                offload_kwargs["offload_indices"] = [0]
+                out = recompute(
+                    self.mlp.forward,
+                    chunk,
+                    masked_tokens=masked_tokens,
+                    **offload_kwargs,
+                )
+                output_list.append(out)
+        else:
+            for chunk in input_list:
+                out = recompute(
+                    self.mlp.forward,
+                    chunk,
+                    **offload_kwargs,
+                )
+                output_list.append(out)
         hidden_states = paddle.concat(output_list, axis=seq_axis)

+        offload_kwargs["offload_indices"] = [0]
         outputs = recompute(
             self.post_process,
             hidden_states,
@@ -1431,7 +1466,17 @@ def forward(
         self_attn_weights = attn_outputs[2] if output_attentions else None
         present_key_value = attn_outputs[3] if use_cache else None

-        hidden_states = self.mlp(hidden_states)
+        if attn_mask_startend_row_indices is not None and isinstance(self.mlp, DeepseekV2MoEFlexToken):
+            masked_tokens = None
+            if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1:
+                flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
+                flat_mask = ScatterOp.apply(flat_mask)
+                flat_mask = paddle.flatten(flat_mask)
+                masked_tokens = flat_mask == 0
+            hidden_states = self.mlp(hidden_states, masked_tokens=masked_tokens)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         outputs = self.post_process(
             hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value
         )
@@ -1547,6 +1592,10 @@ def __init__(
         self.hnorm = DeepseekV2RMSNorm(config)
         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size)

+        if config.sequence_parallel and config.tensor_parallel_degree > 1:
+            mark_as_sequence_parallel_parameter(self.eh_proj.weight)
+            mark_as_sequence_parallel_parameter(self.eh_proj.bias)
+
     def subbatch_recompute_forward(
         self,
         hidden_states: paddle.Tensor,
@@ -2226,10 +2275,6 @@ def __init__(self, config: DeepseekV2Config):
         else:
             self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index)

-        if self.config.sequence_parallel:
-            self.seq_para_scale = 1.0 / self.config.tensor_parallel_degree
-            self.mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
-
     def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None):

         if self.enable_parallel_cross_entropy:
@@ -2247,17 +2292,11 @@ def compute_loss(preds, labels):
             )
             count = paddle.sum(binary_sequence)

-            if self.config.sequence_parallel:
-                dist.all_reduce(count, op=ReduceOp.SUM, group=self.mp_group)
-
             if count == 0:
                 loss = paddle.sum(masked_lm_loss * binary_sequence)
             else:
                 loss = paddle.sum(masked_lm_loss * binary_sequence) / count

-            if self.config.sequence_parallel:
-                dist.all_reduce(loss, op=ReduceOp.SUM, group=self.mp_group)
-
             return loss

         def add_loss(main_loss, loss):
@@ -2269,28 +2308,15 @@ def add_loss(main_loss, loss):
             masked_lm_labels = masked_lm_labels[:, : -self.config.num_nextn_predict_layers]
             seq_length = masked_lm_labels.shape[1]

-            if self.config.sequence_parallel:
-                masked_lm_labels = masked_lm_labels.transpose([1, 0])  # [B, S] --> [S, B]
-                masked_lm_labels = ScatterOp.apply(masked_lm_labels)
-
             loss = compute_loss(prediction_scores, masked_lm_labels)

             mtp_loss_res = []
             for depth in range(self.config.num_nextn_predict_layers):
                 prediction_scores_cur_depth = mtp_logits[depth]
                 masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]
-
-                if self.config.sequence_parallel:
-                    masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0])  # [B, S] --> [S, B]
-                    masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth)
-
                 res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth)
-
-                if self.config.sequence_parallel:
-                    res_cur_depth = res_cur_depth * self.seq_para_scale
-                    dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group)
-
                 mtp_loss_res.append(res_cur_depth)
+
             loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res))  # fmt: skip

         else:
@@ -2336,9 +2362,9 @@ def __init__(self, config: DeepseekV2Config):

     def forward(self, hidden_states, tensor_parallel_output=None):
-        # if self.config.sequence_parallel:
-        #     hidden_states = GatherOp.apply(hidden_states)
-        #     hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
+        if self.config.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])  # [S, B, H] --> [B, S, H]
         #     hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])

         if tensor_parallel_output is None:
diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index 0d7e7882fe59..762e910d42cf 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -294,11 +294,12 @@ def forward(self, args):
         hidden_states = hidden_states_main_model
         for depth in range(self.config.num_nextn_predict_layers):
             inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth]
-
+            moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0
             if moelayer_use_subbatch_recompute:
                 hidden_states = super().subbatch_recompute_forward(
                     hidden_states,
+                    inputs_embeds_cur_depth,
                     position_ids=position_ids,
                     attention_mask=attention_mask,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py
index 723bf525d6df..7b423813e2f6 100644
--- a/paddlenlp/transformers/moe_layer.py
+++ b/paddlenlp/transformers/moe_layer.py
@@ -378,12 +378,12 @@ def expert_forward(self, dispatched_input, tokens_per_expert):

         return paddle.concat(outputs, axis=0)

-    def forward(self, hidden_states: paddle.Tensor):
+    def forward(self, hidden_states: paddle.Tensor, masked_tokens=None):
         _, _, d_model = hidden_states.shape
         # reshaped_input = hidden_states.reshape([-1, d_model])
         probs, routing_map, l_aux, l_zloss = self.router(hidden_states)
         (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
-            hidden_states, probs, routing_map
+            hidden_states, probs, routing_map, masked_tokens
         )
         expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
         output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)
diff --git a/paddlenlp/transformers/token_dispatcher.py b/paddlenlp/transformers/token_dispatcher.py
index 128f6e52f4d6..589a2acce9a5 100644
--- a/paddlenlp/transformers/token_dispatcher.py
+++ b/paddlenlp/transformers/token_dispatcher.py
@@ -261,12 +261,17 @@ def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts
         )

     def token_permutation(
-        self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor
+        self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor, masked_tokens=None
     ) -> Tuple[paddle.Tensor, paddle.Tensor]:
         self.hidden_shape = hidden_states.shape
         hidden_states = hidden_states.view([-1, self.hidden_shape[-1]])

         self._comm_manager.setup_metadata(routing_map, probs)
+        if masked_tokens is not None:
+            self._comm_manager.token_indices.stop_gradient = True
+            masked_tokens = masked_tokens.unsqueeze(axis=-1)
+            self._comm_manager.token_indices = paddle.masked_fill(self._comm_manager.token_indices, masked_tokens, -1)
+
         hidden_states = self._comm_manager.dispatch(hidden_states)
         global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
         tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
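
For reference, a minimal, self-contained sketch (not part of the patch) of the masking mechanism these hunks implement: padded positions become `masked_tokens`, and the dispatcher drops them by overwriting their routing indices with -1. The toy `token_indices` tensor, the shapes, and the "0 means padded" convention below are illustrative assumptions rather than code taken from this patch.

# Illustration only: how `masked_tokens` interacts with per-token routing indices.
import paddle

num_tokens, topk = 8, 2
# Suppose the last two tokens of the flattened sequence are padding.
flat_mask = paddle.to_tensor([1, 1, 1, 1, 1, 1, 0, 0], dtype="int32")
masked_tokens = flat_mask == 0  # True where the token should not be routed

# Stand-in for the per-token expert indices held by the communication manager.
token_indices = paddle.randint(0, 4, shape=[num_tokens, topk])

# Same trick as the token_dispatcher hunk: padded rows get index -1 and are skipped at dispatch.
token_indices = paddle.masked_fill(token_indices, masked_tokens.unsqueeze(axis=-1), -1)
print(token_indices)  # the last two rows are all -1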