112 changes: 69 additions & 43 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -33,7 +33,6 @@
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.distributed.communication.reduce import ReduceOp
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -881,8 +880,8 @@ def __init__(self, config: DeepseekV2Config):
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=False)

def forward(self, hidden_states):
final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
def forward(self, hidden_states, masked_tokens=None):
final_hidden_states, l_aux, l_zloss = super().forward(hidden_states, masked_tokens=masked_tokens)
if self.training and self.alpha > 0.0:
l_aux = l_aux * self.alpha
final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux)
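To make the new masked_tokens argument concrete, here is a minimal, self-contained sketch of the pattern this hunk introduces: the model-specific MoE wrapper accepts an optional boolean mask and forwards it to the generic MoE layer, which can then exclude those tokens from expert dispatch. The classes and the zeroing shortcut below are illustrative only, not the PaddleNLP implementation.

    import paddle
    from paddle import nn

    class BaseMoE(nn.Layer):
        def forward(self, hidden_states, masked_tokens=None):
            # Stand-in for routing + expert dispatch: simply zero out masked tokens.
            if masked_tokens is not None:
                hidden_states = paddle.where(
                    masked_tokens.unsqueeze(-1), paddle.zeros_like(hidden_states), hidden_states
                )
            l_aux = paddle.zeros([1])
            return hidden_states, l_aux

    class ModelMoE(BaseMoE):
        def forward(self, hidden_states, masked_tokens=None):
            # Thread the optional mask through, as DeepseekV2MoE.forward does above.
            out, l_aux = super().forward(hidden_states, masked_tokens=masked_tokens)
            return out, l_aux

    x = paddle.randn([6, 4])                                          # [tokens, hidden]
    mask = paddle.to_tensor([False, False, True, False, True, False])  # True == padded token
    y, _ = ModelMoE()(x, masked_tokens=mask)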
@@ -1003,6 +1002,7 @@ def linear_dtype_gaurd():
# fmt: on

if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:

def grad_allreduce_hook(param, accumulation_steps):
hcg = fleet.get_hybrid_communicate_group()
pg = hcg.get_model_parallel_group().process_group
@@ -1018,10 +1018,22 @@ def __impl__():
pg.allreduce(param.grad).wait()

return __impl__

# kv_a_proj_with_mqa and q_a_proj grad need to be reduce between mp
self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps))
self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps))

# self.kv_a_proj_with_mqa.weight._register_backward_hook(
# grad_allreduce_hook(
# self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps
# )
# )
# self.q_a_proj.weight._register_backward_hook(
# grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps)
# )
mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.weight)
mark_as_sequence_parallel_parameter(self.q_a_proj.weight)
if config.attention_bias:
mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.bias)
mark_as_sequence_parallel_parameter(self.q_a_proj.bias)

self._init_rope()

self.softmax_scale = self.q_head_dim ** (-0.5)
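An aside on the block above: under sequence parallelism these low-rank projections are replicated across the tensor-parallel ranks while each rank only sees its local sequence shard, so their gradients still have to be summed over the model-parallel group. My reading of the change is that mark_as_sequence_parallel_parameter delegates that reduction to the framework's sequence-parallel utilities, replacing the per-parameter backward hooks that are now commented out. A rough sketch of the manual alternative, illustrative only; hook registration and timing details differ in practice:

    import paddle.distributed as dist
    from paddle.distributed import fleet

    def allreduce_grad_across_mp(param):
        """Return a callback that sums `param`'s gradient over the model-parallel group."""
        hcg = fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()

        def hook():
            if param.grad is not None:
                dist.all_reduce(param.grad, group=mp_group)

        return hook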
@@ -1284,18 +1296,41 @@ def subbatch_recompute_forward(
seq_len = hidden_states.shape[seq_axis]
assert seq_len % sub_seq_len == 0
num_chunks = seq_len // sub_seq_len
split_list = [sub_seq_len] * num_chunks
input_list = paddle.split(hidden_states, split_list, axis=seq_axis)
input_list = paddle.split(hidden_states, num_chunks, axis=seq_axis)
output_list = []

for chunk in input_list:
out = recompute(
self.mlp.forward,
chunk,
**offload_kwargs,
)
output_list.append(out)
if isinstance(self.mlp, DeepseekV2MoEFlexToken):
if attn_mask_startend_row_indices is not None:
if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1:
flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
flat_mask = ScatterOp.apply(flat_mask)
flat_mask = paddle.flatten(flat_mask)
mask_list = paddle.split(flat_mask, num_chunks)
else:
mask_list = [None] * num_chunks

for chunk, mask_chunk in zip(input_list, mask_list):
masked_tokens = None
if mask_chunk is not None:
masked_tokens = mask_chunk == 0
offload_kwargs["offload_indices"] = [0]
out = recompute(
self.mlp.forward,
chunk,
masked_tokens=masked_tokens,
**offload_kwargs,
)
output_list.append(out)
else:
for chunk in input_list:
out = recompute(
self.mlp.forward,
chunk,
**offload_kwargs,
)
output_list.append(out)
hidden_states = paddle.concat(output_list, axis=seq_axis)
offload_kwargs["offload_indices"] = [0]
outputs = recompute(
self.post_process,
hidden_states,
@@ -1431,7 +1466,17 @@ def forward(
self_attn_weights = attn_outputs[2] if output_attentions else None
present_key_value = attn_outputs[3] if use_cache else None

hidden_states = self.mlp(hidden_states)
if attn_mask_startend_row_indices is not None and isinstance(self.mlp, DeepseekV2MoEFlexToken):
masked_tokens = None
if self.config.sequence_parallel and self.config.tensor_parallel_degree > 1:
flat_mask = paddle.transpose(attn_mask_startend_row_indices, [2, 0, 1])
flat_mask = ScatterOp.apply(flat_mask)
flat_mask = paddle.flatten(flat_mask)
masked_tokens = flat_mask == 0
hidden_states = self.mlp(hidden_states, masked_tokens=masked_tokens)
else:
hidden_states = self.mlp(hidden_states)

outputs = self.post_process(
hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value
)
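For context, the FlexToken branch above (and the subbatch path before it, which chops the sequence into num_chunks sub-sequences and recomputes the MLP chunk by chunk) has to bring the padding mask into the same sequence-parallel, flattened layout as the hidden states before it can be split or compared. A minimal single-process sketch of that mask transform, with assumed shapes ([batch, 1, seq_len] indices, 0 taken to mean a fully masked token) and ScatterOp.apply replaced by an explicit slice of the local shard:

    import paddle

    B, S, mp_degree, rank = 2, 8, 2, 0                       # toy sizes (assumptions)
    indices = paddle.randint(0, 2, shape=[B, 1, S])          # 0 == token to drop (assumed encoding)

    flat = paddle.transpose(indices, [2, 0, 1])              # [S, B, 1], sequence-major like the hidden states
    shard = paddle.split(flat, mp_degree, axis=0)[rank]      # what ScatterOp.apply would keep on this rank
    masked_tokens = paddle.flatten(shard) == 0               # boolean mask over the local, flattened tokens
    chunks = paddle.split(masked_tokens, 2)                  # split the same way as the hidden-state chunks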
@@ -1547,6 +1592,10 @@ def __init__(
self.hnorm = DeepseekV2RMSNorm(config)
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size)

if config.sequence_parallel and config.tensor_parallel_degree > 1:
mark_as_sequence_parallel_parameter(self.eh_proj.weight)
mark_as_sequence_parallel_parameter(self.eh_proj.bias)

def subbatch_recompute_forward(
self,
hidden_states: paddle.Tensor,
@@ -2226,10 +2275,6 @@ def __init__(self, config: DeepseekV2Config):
else:
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index)

if self.config.sequence_parallel:
self.seq_para_scale = 1.0 / self.config.tensor_parallel_degree
self.mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()

def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None):

if self.enable_parallel_cross_entropy:
@@ -2247,17 +2292,11 @@ def compute_loss(preds, labels):
)
count = paddle.sum(binary_sequence)

if self.config.sequence_parallel:
dist.all_reduce(count, op=ReduceOp.SUM, group=self.mp_group)

if count == 0:
loss = paddle.sum(masked_lm_loss * binary_sequence)
else:
loss = paddle.sum(masked_lm_loss * binary_sequence) / count

if self.config.sequence_parallel:
dist.all_reduce(loss, op=ReduceOp.SUM, group=self.mp_group)

return loss

def add_loss(main_loss, loss):
@@ -2269,28 +2308,15 @@ def add_loss(main_loss, loss):
masked_lm_labels = masked_lm_labels[:, : -self.config.num_nextn_predict_layers]
seq_length = masked_lm_labels.shape[1]

if self.config.sequence_parallel:
masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B]
masked_lm_labels = ScatterOp.apply(masked_lm_labels)

loss = compute_loss(prediction_scores, masked_lm_labels)

mtp_loss_res = []
for depth in range(self.config.num_nextn_predict_layers):
prediction_scores_cur_depth = mtp_logits[depth]
masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]

if self.config.sequence_parallel:
masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B]
masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth)

res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth)

if self.config.sequence_parallel:
res_cur_depth = res_cur_depth * self.seq_para_scale
dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group)

mtp_loss_res.append(res_cur_depth)

loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res)) # fmt: skip

else:
@@ -2336,9 +2362,9 @@ def __init__(self, config: DeepseekV2Config):

def forward(self, hidden_states, tensor_parallel_output=None):

# if self.config.sequence_parallel:
# hidden_states = GatherOp.apply(hidden_states)
# hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
if self.config.sequence_parallel:
hidden_states = GatherOp.apply(hidden_states)
hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H]
# hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size])

if tensor_parallel_output is None:
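Stepping back from this file: the sequence-parallel special-casing that used to live in the pretraining criterion (scattering labels, rescaling and all-reducing the per-depth losses) is removed, which matches the GatherOp the LM head's forward now applies in the hunk just above, so compute_loss sees full sequences again. Condensed, compute_loss is a count-normalized masked mean; a sketch, where the ignore_index value and the exact masking call are assumptions:

    import paddle

    def masked_mean_loss(per_token_loss, labels, ignore_index=-100):
        valid = paddle.cast(labels != ignore_index, per_token_loss.dtype)
        count = paddle.sum(valid)
        if count == 0:
            # No valid labels: return a zero that stays connected to the graph.
            return paddle.sum(per_token_loss * valid)
        return paddle.sum(per_token_loss * valid) / count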
3 changes: 2 additions & 1 deletion paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -294,11 +294,12 @@ def forward(self, args):
hidden_states = hidden_states_main_model
for depth in range(self.config.num_nextn_predict_layers):
inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth]

moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0
if moelayer_use_subbatch_recompute:
hidden_states = super().subbatch_recompute_forward(
hidden_states,
inputs_embeds_cur_depth,
position_ids=position_ids,
attention_mask=attention_mask,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
4 changes: 2 additions & 2 deletions paddlenlp/transformers/moe_layer.py
@@ -378,12 +378,12 @@ def expert_forward(self, dispatched_input, tokens_per_expert):

return paddle.concat(outputs, axis=0)

def forward(self, hidden_states: paddle.Tensor):
def forward(self, hidden_states: paddle.Tensor, masked_tokens=None):
_, _, d_model = hidden_states.shape
# reshaped_input = hidden_states.reshape([-1, d_model])
probs, routing_map, l_aux, l_zloss = self.router(hidden_states)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
hidden_states, probs, routing_map, masked_tokens
)
expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)
7 changes: 6 additions & 1 deletion paddlenlp/transformers/token_dispatcher.py
@@ -261,12 +261,17 @@ def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts
)

def token_permutation(
self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor
self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor, masked_tokens=None
) -> Tuple[paddle.Tensor, paddle.Tensor]:
self.hidden_shape = hidden_states.shape
hidden_states = hidden_states.view([-1, self.hidden_shape[-1]])

self._comm_manager.setup_metadata(routing_map, probs)
if masked_tokens is not None:
self._comm_manager.token_indices.stop_gradient = True
masked_tokens = masked_tokens.unsqueeze(axis=-1)
self._comm_manager.token_indices = paddle.masked_fill(self._comm_manager.token_indices, masked_tokens, -1)

hidden_states = self._comm_manager.dispatch(hidden_states)
global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
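To illustrate the dispatcher change above: when masked_tokens is given, the routing indices of those tokens are overwritten with -1 (with gradients stopped on the index tensor), which, as I read it, makes the communication manager skip them during dispatch. A tiny standalone example of the masked_fill call, with an assumed [num_tokens, topk] index layout:

    import paddle

    token_indices = paddle.to_tensor([[0, 3], [1, 2], [2, 0]])    # assumed [num_tokens, topk] expert targets
    masked_tokens = paddle.to_tensor([False, True, False])        # second token is padding

    dropped = paddle.masked_fill(token_indices, masked_tokens.unsqueeze(-1), -1)
    print(dropped.numpy())  # the masked token's row becomes [-1, -1]; the others are untouched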