From a29d659507a1102ec8189496c34014417a043197 Mon Sep 17 00:00:00 2001 From: Yichen Zhang Date: Thu, 25 Sep 2025 19:24:33 +0800 Subject: [PATCH 1/4] drop padded tokens in MoE dispatching --- .../transformers/deepseek_v2/modeling.py | 27 ++++++++++++++----- paddlenlp/transformers/moe_layer.py | 4 +-- paddlenlp/transformers/token_dispatcher.py | 9 ++++++- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 375e4a3c8885..2bf07008818d 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -881,8 +881,8 @@ def __init__(self, config: DeepseekV2Config): intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=False) - def forward(self, hidden_states): - final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) + def forward(self, hidden_states, masked_token_indices=None): + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states, masked_token_indices=masked_token_indices) if self.training and self.alpha > 0.0: l_aux = l_aux * self.alpha final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) @@ -1003,6 +1003,7 @@ def linear_dtype_gaurd(): # fmt: on if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel: + def grad_allreduce_hook(param, accumulation_steps): hcg = fleet.get_hybrid_communicate_group() pg = hcg.get_model_parallel_group().process_group @@ -1018,10 +1019,17 @@ def __impl__(): pg.allreduce(param.grad).wait() return __impl__ + # kv_a_proj_with_mqa and q_a_proj grad need to be reduce between mp - self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps)) - self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps)) - + self.kv_a_proj_with_mqa.weight._register_backward_hook( + grad_allreduce_hook( + self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps + ) + ) + self.q_a_proj.weight._register_backward_hook( + grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps) + ) + self._init_rope() self.softmax_scale = self.q_head_dim ** (-0.5) @@ -1431,7 +1439,14 @@ def forward( self_attn_weights = attn_outputs[2] if output_attentions else None present_key_value = attn_outputs[3] if use_cache else None - hidden_states = self.mlp(hidden_states) + masked_token_indices = None + if attn_mask_startend_row_indices is not None and isinstance(self.mlp, DeepseekV2MoEFlexToken): + flat_mask = paddle.flatten(attn_mask_startend_row_indices) + masked_token_indices = flat_mask == 0 + hidden_states = self.mlp(hidden_states, masked_token_indices=masked_token_indices) + else: + hidden_states = self.mlp(hidden_states) + outputs = self.post_process( hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value ) diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 723bf525d6df..6ed7e9e1daa6 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -378,12 +378,12 @@ def expert_forward(self, dispatched_input, tokens_per_expert): return paddle.concat(outputs, axis=0) - def forward(self, hidden_states: paddle.Tensor): + 
def forward(self, hidden_states: paddle.Tensor, masked_token_indices=None): _, _, d_model = hidden_states.shape # reshaped_input = hidden_states.reshape([-1, d_model]) probs, routing_map, l_aux, l_zloss = self.router(hidden_states) (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( - hidden_states, probs, routing_map + hidden_states, probs, routing_map, masked_token_indices ) expert_output = self.expert_forward(dispatched_input, tokens_per_expert) output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) diff --git a/paddlenlp/transformers/token_dispatcher.py b/paddlenlp/transformers/token_dispatcher.py index 128f6e52f4d6..7eecbb699545 100644 --- a/paddlenlp/transformers/token_dispatcher.py +++ b/paddlenlp/transformers/token_dispatcher.py @@ -261,12 +261,19 @@ def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts ) def token_permutation( - self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor + self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor, masked_token_indices=None ) -> Tuple[paddle.Tensor, paddle.Tensor]: self.hidden_shape = hidden_states.shape hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) self._comm_manager.setup_metadata(routing_map, probs) + if masked_token_indices is not None: + self._comm_manager.token_indices.stop_gradient = True + masked_token_indices = masked_token_indices.unsqueeze(axis=-1) + self._comm_manager.token_indices = paddle.masked_fill( + self._comm_manager.token_indices, masked_token_indices, -1 + ) + hidden_states = self._comm_manager.dispatch(hidden_states) global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states) tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert() From 09ec3591c1f5c3c815450d6bc70d2ae3d2e463f5 Mon Sep 17 00:00:00 2001 From: Yichen Zhang Date: Thu, 25 Sep 2025 20:01:54 +0800 Subject: [PATCH 2/4] fix sequence parallel bugs in deepseek-v3 --- .../transformers/deepseek_v2/modeling.py | 57 +++++++------------ 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 2bf07008818d..ed74b0be08a5 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -33,7 +33,6 @@ import paddle.nn.functional as F from paddle import Tensor, nn from paddle.distributed import fleet -from paddle.distributed.communication.reduce import ReduceOp from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -1021,14 +1020,19 @@ def __impl__(): return __impl__ # kv_a_proj_with_mqa and q_a_proj grad need to be reduce between mp - self.kv_a_proj_with_mqa.weight._register_backward_hook( - grad_allreduce_hook( - self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps - ) - ) - self.q_a_proj.weight._register_backward_hook( - grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps) - ) + # self.kv_a_proj_with_mqa.weight._register_backward_hook( + # grad_allreduce_hook( + # self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps + # ) + # ) + # self.q_a_proj.weight._register_backward_hook( + # grad_allreduce_hook(self.q_a_proj.weight, 
accumulation_steps=config.gradient_accumulation_steps) + # ) + mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.weight) + mark_as_sequence_parallel_parameter(self.q_a_proj.weight) + if config.attention_bias: + mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.bias) + mark_as_sequence_parallel_parameter(self.q_a_proj.bias) self._init_rope() @@ -1562,6 +1566,10 @@ def __init__( self.hnorm = DeepseekV2RMSNorm(config) self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size) + if config.sequence_parallel and config.tensor_parallel_degree > 1: + mark_as_sequence_parallel_parameter(self.eh_proj.weight) + mark_as_sequence_parallel_parameter(self.eh_proj.bias) + def subbatch_recompute_forward( self, hidden_states: paddle.Tensor, @@ -2241,10 +2249,6 @@ def __init__(self, config: DeepseekV2Config): else: self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) - if self.config.sequence_parallel: - self.seq_para_scale = 1.0 / self.config.tensor_parallel_degree - self.mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() - def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None): if self.enable_parallel_cross_entropy: @@ -2262,17 +2266,11 @@ def compute_loss(preds, labels): ) count = paddle.sum(binary_sequence) - if self.config.sequence_parallel: - dist.all_reduce(count, op=ReduceOp.SUM, group=self.mp_group) - if count == 0: loss = paddle.sum(masked_lm_loss * binary_sequence) else: loss = paddle.sum(masked_lm_loss * binary_sequence) / count - if self.config.sequence_parallel: - dist.all_reduce(loss, op=ReduceOp.SUM, group=self.mp_group) - return loss def add_loss(main_loss, loss): @@ -2284,28 +2282,15 @@ def add_loss(main_loss, loss): masked_lm_labels = masked_lm_labels[:, : -self.config.num_nextn_predict_layers] seq_length = masked_lm_labels.shape[1] - if self.config.sequence_parallel: - masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B] - masked_lm_labels = ScatterOp.apply(masked_lm_labels) - loss = compute_loss(prediction_scores, masked_lm_labels) mtp_loss_res = [] for depth in range(self.config.num_nextn_predict_layers): prediction_scores_cur_depth = mtp_logits[depth] masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)] - - if self.config.sequence_parallel: - masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B] - masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth) - res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth) - - if self.config.sequence_parallel: - res_cur_depth = res_cur_depth * self.seq_para_scale - dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group) - mtp_loss_res.append(res_cur_depth) + loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res)) # fmt: skip else: @@ -2351,9 +2336,9 @@ def __init__(self, config: DeepseekV2Config): def forward(self, hidden_states, tensor_parallel_output=None): - # if self.config.sequence_parallel: - # hidden_states = GatherOp.apply(hidden_states) - # hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H] # hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size]) if tensor_parallel_output is None: 
From f0d794a0d7fc09de83165863d1491f08373e74fc Mon Sep 17 00:00:00 2001 From: Yichen Zhang Date: Thu, 25 Sep 2025 21:15:40 +0800 Subject: [PATCH 3/4] drop padded token subbatch --- .../transformers/deepseek_v2/modeling.py | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index ed74b0be08a5..5e4dfed223e2 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -1296,17 +1296,38 @@ def subbatch_recompute_forward( seq_len = hidden_states.shape[seq_axis] assert seq_len % sub_seq_len == 0 num_chunks = seq_len // sub_seq_len - split_list = [sub_seq_len] * num_chunks - input_list = paddle.split(hidden_states, split_list, axis=seq_axis) + input_list = paddle.split(hidden_states, num_chunks, axis=seq_axis) output_list = [] - for chunk in input_list: - out = recompute( - self.mlp.forward, - chunk, - **offload_kwargs, - ) - output_list.append(out) + if isinstance(self.mlp, DeepseekV2MoEFlexToken): + if attn_mask_startend_row_indices is not None: + if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1: + flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1]) + flat_mask = ScatterOp.apply(flat_mask) + flat_mask = paddle.flatten(flat_mask) + mask_list = paddle.split(flat_mask, num_chunks) + else: + mask_list = [None] * num_chunks + + for chunk, mask_chunk in zip(input_list, mask_list): + masked_token_indices = None + if mask_chunk is not None: + masked_token_indices = mask_chunk == 0 + out = recompute( + self.mlp.forward, + chunk, + masked_token_indices=masked_token_indices, + **offload_kwargs, + ) + output_list.append(out) + else: + for chunk in input_list: + out = recompute( + self.mlp.forward, + chunk, + **offload_kwargs, + ) + output_list.append(out) hidden_states = paddle.concat(output_list, axis=seq_axis) outputs = recompute( self.post_process, @@ -1443,9 +1464,12 @@ def forward( self_attn_weights = attn_outputs[2] if output_attentions else None present_key_value = attn_outputs[3] if use_cache else None - masked_token_indices = None if attn_mask_startend_row_indices is not None and isinstance(self.mlp, DeepseekV2MoEFlexToken): - flat_mask = paddle.flatten(attn_mask_startend_row_indices) + masked_token_indices = None + if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1: + flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1]) + flat_mask = ScatterOp.apply(flat_mask) + flat_mask = paddle.flatten(flat_mask) masked_token_indices = flat_mask == 0 hidden_states = self.mlp(hidden_states, masked_token_indices=masked_token_indices) else: From c70ed4fdf9132c56b51a31a4f5a6c803498cd277 Mon Sep 17 00:00:00 2001 From: Yichen Zhang Date: Fri, 26 Sep 2025 14:27:52 +0800 Subject: [PATCH 4/4] fix offload setting in subbatch --- paddlenlp/transformers/deepseek_v2/modeling.py | 18 ++++++++++-------- .../transformers/deepseek_v2/modeling_pp.py | 3 ++- paddlenlp/transformers/moe_layer.py | 4 ++-- paddlenlp/transformers/token_dispatcher.py | 10 ++++------ 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 5e4dfed223e2..26d664f16c54 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -880,8 +880,8 @@ def __init__(self, config: DeepseekV2Config): intermediate_size = 
config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=False) - def forward(self, hidden_states, masked_token_indices=None): - final_hidden_states, l_aux, l_zloss = super().forward(hidden_states, masked_token_indices=masked_token_indices) + def forward(self, hidden_states, masked_tokens=None): + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states, masked_tokens=masked_tokens) if self.training and self.alpha > 0.0: l_aux = l_aux * self.alpha final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) @@ -1310,13 +1310,14 @@ def subbatch_recompute_forward( mask_list = [None] * num_chunks for chunk, mask_chunk in zip(input_list, mask_list): - masked_token_indices = None + masked_tokens = None if mask_chunk is not None: - masked_token_indices = mask_chunk == 0 + masked_tokens = mask_chunk == 0 + offload_kwargs["offload_indices"] = [0] out = recompute( self.mlp.forward, chunk, - masked_token_indices=masked_token_indices, + masked_tokens=masked_tokens, **offload_kwargs, ) output_list.append(out) @@ -1329,6 +1330,7 @@ def subbatch_recompute_forward( ) output_list.append(out) hidden_states = paddle.concat(output_list, axis=seq_axis) + offload_kwargs["offload_indices"] = [0] outputs = recompute( self.post_process, hidden_states, @@ -1465,13 +1467,13 @@ def forward( present_key_value = attn_outputs[3] if use_cache else None if attn_mask_startend_row_indices is not None and isinstance(self.mlp, DeepseekV2MoEFlexToken): - masked_token_indices = None + masked_tokens = None if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1: flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1]) flat_mask = ScatterOp.apply(flat_mask) flat_mask = paddle.flatten(flat_mask) - masked_token_indices = flat_mask == 0 - hidden_states = self.mlp(hidden_states, masked_token_indices=masked_token_indices) + masked_tokens = flat_mask == 0 + hidden_states = self.mlp(hidden_states, masked_tokens=masked_tokens) else: hidden_states = self.mlp(hidden_states) diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 0d7e7882fe59..762e910d42cf 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -294,11 +294,12 @@ def forward(self, args): hidden_states = hidden_states_main_model for depth in range(self.config.num_nextn_predict_layers): inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth] - + moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0 if moelayer_use_subbatch_recompute: hidden_states = super().subbatch_recompute_forward( hidden_states, + inputs_embeds_cur_depth, position_ids=position_ids, attention_mask=attention_mask, attn_mask_startend_row_indices=attn_mask_startend_row_indices, diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 6ed7e9e1daa6..7b423813e2f6 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -378,12 +378,12 @@ def expert_forward(self, dispatched_input, tokens_per_expert): return paddle.concat(outputs, axis=0) - def forward(self, hidden_states: paddle.Tensor, masked_token_indices=None): + def forward(self, hidden_states: paddle.Tensor, masked_tokens=None): _, _, d_model = hidden_states.shape # reshaped_input = hidden_states.reshape([-1, d_model]) probs, routing_map, l_aux, l_zloss = self.router(hidden_states) 
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( - hidden_states, probs, routing_map, masked_token_indices + hidden_states, probs, routing_map, masked_tokens ) expert_output = self.expert_forward(dispatched_input, tokens_per_expert) output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) diff --git a/paddlenlp/transformers/token_dispatcher.py b/paddlenlp/transformers/token_dispatcher.py index 7eecbb699545..589a2acce9a5 100644 --- a/paddlenlp/transformers/token_dispatcher.py +++ b/paddlenlp/transformers/token_dispatcher.py @@ -261,18 +261,16 @@ def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts ) def token_permutation( - self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor, masked_token_indices=None + self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor, masked_tokens=None ) -> Tuple[paddle.Tensor, paddle.Tensor]: self.hidden_shape = hidden_states.shape hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) self._comm_manager.setup_metadata(routing_map, probs) - if masked_token_indices is not None: + if masked_tokens is not None: self._comm_manager.token_indices.stop_gradient = True - masked_token_indices = masked_token_indices.unsqueeze(axis=-1) - self._comm_manager.token_indices = paddle.masked_fill( - self._comm_manager.token_indices, masked_token_indices, -1 - ) + masked_tokens = masked_tokens.unsqueeze(axis=-1) + self._comm_manager.token_indices = paddle.masked_fill(self._comm_manager.token_indices, masked_tokens, -1) hidden_states = self._comm_manager.dispatch(hidden_states) global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
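
Note (editor's sketch, not part of the patches): the common idea across this series is that padded tokens are detected from the attention-mask row indices (a value of 0 marks padding) and their dispatch slots are set to -1 so the MoE all-to-all skips them. The snippet below is a minimal, self-contained illustration of that masking step using only public Paddle ops; the names `token_indices`, `attn_mask_row_indices`, and the [num_tokens, topk] shape are illustrative assumptions and are not the exact internals of the MoEFlexToken dispatcher or its comm manager.

    # Hypothetical standalone sketch of the padded-token drop used in these patches.
    import paddle

    def drop_padded_tokens(token_indices: paddle.Tensor, attn_mask_row_indices: paddle.Tensor) -> paddle.Tensor:
        """Mark routing slots of padded tokens as -1 ("do not dispatch").

        token_indices:         int tensor [num_tokens, topk], expert slot per token.
        attn_mask_row_indices: int tensor [num_tokens]; 0 means the token is padding.
        """
        masked_tokens = paddle.flatten(attn_mask_row_indices) == 0   # [num_tokens], True for padding
        masked_tokens = masked_tokens.unsqueeze(axis=-1)             # [num_tokens, 1], broadcasts over topk
        return paddle.masked_fill(token_indices, masked_tokens, -1)  # padded rows become -1

    if __name__ == "__main__":
        # 4 tokens routed to top-2 experts; tokens 2 and 3 are padding (row index 0).
        token_indices = paddle.to_tensor([[0, 1], [2, 3], [1, 2], [0, 3]], dtype="int64")
        attn_mask_row_indices = paddle.to_tensor([5, 7, 0, 0], dtype="int32")
        print(drop_padded_tokens(token_indices, attn_mask_row_indices))
        # Rows for the two padded tokens come back as [-1, -1].

In the actual patches this masking is applied to self._comm_manager.token_indices inside token_permutation, and under sequence parallelism the mask is first transposed and split with ScatterOp so each rank only masks its own shard.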