5 changes: 4 additions & 1 deletion paddlenlp/transformers/deepseek_v2/configuration.py
@@ -179,6 +179,8 @@ def __init__(
attention_dropout=0.0,
speculate_model_type=False,
using_flex_token=False,
deepep_fine_grained=False,
deepep_tokens_per_subbatch=1024,
**kwargs,
):
self.vocab_size = vocab_size
@@ -227,7 +229,8 @@ def __init__(
self.speculate_model_type = speculate_model_type
self.use_fp8 = False
self.using_flex_token = using_flex_token

self.deepep_fine_grained = deepep_fine_grained
self.deepep_tokens_per_subbatch = deepep_tokens_per_subbatch
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
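For context, a minimal sketch of how the two new fields would be set. The field names and defaults come from this diff; importing DeepseekV2Config from paddlenlp.transformers is an assumption about the package's usual exports:

from paddlenlp.transformers import DeepseekV2Config

# Hypothetical configuration enabling the fine-grained DeepEP path added below.
config = DeepseekV2Config(
    deepep_fine_grained=True,  # route the MoE forward through fine_grained_forward_experts
    deepep_tokens_per_subbatch=1024,  # per-expert sub-batch size; a value <= 0 disables splitting
)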
122 changes: 117 additions & 5 deletions paddlenlp/transformers/moe_layer.py
@@ -24,7 +24,13 @@
from paddle import Tensor, nn
from paddle.distributed.communication.group import Group

try:
from paddle import scatter_add_
except ImportError:
scatter_add_ = None

from .moe_gate import PretrainedMoEGate
from .moe_utils import FakeGather, topk_to_permuted_indices_single
from .token_dispatcher import MoEFlexTokenDispatcher


@@ -366,13 +372,119 @@ def expert_forward(self, dispatched_input, tokens_per_expert):

return paddle.concat(outputs, axis=0)

def maybe_split_subbatch_data(self, permuted_tokens, token_permuted_indices, prob_permuted_indices):
"""maybe_split_subbatch_data"""

def split_subbatch_data(data, tokens_per_subbatch):
total_token_num = data.shape[0]

full_batch_num, remainder = divmod(total_token_num, tokens_per_subbatch)
num_or_sections = [tokens_per_subbatch] * full_batch_num
if remainder:
num_or_sections.append(remainder)

assert (
sum(num_or_sections) == total_token_num
), f"get_subbatch_data fail, {sum(num_or_sections)}, {total_token_num}"
# When data is a 0-size tensor, return it as a single chunk so the backward graph is still constructed correctly.
if total_token_num == 0:
return [data]
return paddle.split(data, num_or_sections=num_or_sections, axis=0)

if self.config.deepep_tokens_per_subbatch > 0:
assert (
permuted_tokens.shape[0] == token_permuted_indices.shape[0]
), f"Shape mismatch between {permuted_tokens.shape[0]} and {token_permuted_indices.shape[0]}"
assert (
permuted_tokens.shape[0] == prob_permuted_indices.shape[0]
), f"Shape mismatch between {permuted_tokens.shape[0]} and {prob_permuted_indices.shape[0]}"
permuted_tokens_list = split_subbatch_data(permuted_tokens, self.config.deepep_tokens_per_subbatch)
token_permuted_indices_list = split_subbatch_data(
token_permuted_indices, self.config.deepep_tokens_per_subbatch
)
prob_permuted_indices_list = split_subbatch_data(
prob_permuted_indices, self.config.deepep_tokens_per_subbatch
)
else:
permuted_tokens_list = [permuted_tokens]
token_permuted_indices_list = [token_permuted_indices]
prob_permuted_indices_list = [prob_permuted_indices]
return permuted_tokens_list, token_permuted_indices_list, prob_permuted_indices_list
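A worked example of the splitting arithmetic above, as plain Python with illustrative numbers:

# Sketch: 2500 permuted tokens with deepep_tokens_per_subbatch=1024.
total_token_num = 2500
tokens_per_subbatch = 1024
full_batch_num, remainder = divmod(total_token_num, tokens_per_subbatch)  # (2, 452)
num_or_sections = [tokens_per_subbatch] * full_batch_num
if remainder:
    num_or_sections.append(remainder)
assert num_or_sections == [1024, 1024, 452]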

def fine_grained_forward_experts(self, dispatched_input, dispatched_probs, dispatched_indices, dispatch_topk):
"""fine_grained_forward_experts"""
print("moe layer input shape ", self.hidden_shape, " moe layer dispatch output shape ", dispatched_input.shape)
output_tokens = paddle.zeros(dispatched_input.shape, dispatched_input.dtype)

tokens_per_expert = self.token_dispatcher._comm_manager.tokens_per_expert

for expert_id, num_tokens in enumerate(tokens_per_expert):

token_permuted_indices, prob_permuted_indices = topk_to_permuted_indices_single(
dispatched_indices, num_tokens, expert_id, dispatch_topk
)
permuted_tokens = FakeGather.apply(dispatched_input, token_permuted_indices)
# If deepep_tokens_per_subbatch > 0, the data is split into multiple subbatches.
(
permuted_tokens_list,
token_permuted_indices_list,
prob_permuted_indices_list,
) = self.maybe_split_subbatch_data(permuted_tokens, token_permuted_indices, prob_permuted_indices)

for permuted_tokens_, token_permuted_indices_, prob_permuted_indices_ in zip(
permuted_tokens_list, token_permuted_indices_list, prob_permuted_indices_list
):
# ffn
permuted_tokens_ = self.experts[expert_id](permuted_tokens_)
# local unpermute
if dispatched_probs is not None:
permuted_probs = FakeGather.apply(dispatched_probs.flatten(), prob_permuted_indices_)
if permuted_tokens_.dtype != permuted_probs.dtype:
new_permuted_tokens = permuted_tokens_.astype(permuted_probs.dtype)
else:
new_permuted_tokens = permuted_tokens_
permuted_tokens_ = new_permuted_tokens * permuted_probs.unsqueeze(-1)
if scatter_add_ is not None:
scatter_add_(output_tokens, token_permuted_indices_, permuted_tokens_.astype(output_tokens.dtype))
else:
output_tokens.scatter_(
index=token_permuted_indices_,
updates=permuted_tokens_.astype(output_tokens.dtype),
overwrite=False,
)

dispatched_input._clear_to_zero_allocation()

return output_tokens
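A minimal sketch of the scatter_ fallback used for the local combine above, with illustrative values; with overwrite=False, rows hit by duplicate indices are accumulated within a single call:

import paddle

output_tokens = paddle.zeros([4, 2])
token_permuted_indices_ = paddle.to_tensor([1, 3, 1])
permuted_tokens_ = paddle.to_tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
# Rows 1 and 3 receive the expert outputs; the duplicate index 1 accumulates.
output_tokens.scatter_(index=token_permuted_indices_, updates=permuted_tokens_, overwrite=False)
# output_tokens[1] == [4., 4.], output_tokens[3] == [2., 2.]; rows 0 and 2 stay zero.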

def forward(self, hidden_states: paddle.Tensor):
_, _, d_model = hidden_states.shape
# reshaped_input = hidden_states.reshape([-1, d_model])
probs, routing_map, l_aux, l_zloss = self.router(hidden_states)
- (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
-     hidden_states, probs, routing_map
- )
- expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
- output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)

if self.config.deepep_fine_grained:
# global dispatch: the DeepEP comm manager replaces token_dispatcher.token_permutation on this path
self.hidden_shape = hidden_states.shape
hidden_states = hidden_states.view([-1, self.hidden_shape[-1]])

self.token_dispatcher._comm_manager.setup_metadata(routing_map, probs)
dispatched_input = self.token_dispatcher._comm_manager.dispatch(hidden_states)

dispatched_indices = self.token_dispatcher._comm_manager.dispatched_indices
dispatched_probs = self.token_dispatcher._comm_manager.dispatched_probs
# local dispatch & forward_experts & local combine
output_tokens = self.fine_grained_forward_experts(
dispatched_input, dispatched_probs, dispatched_indices, self.moe_router_topk
)
# global combine
output = self.token_dispatcher._comm_manager.combine(output_tokens)
else:
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
expert_output = self.expert_forward(dispatched_input, tokens_per_expert)
output, _ = self.token_dispatcher.token_unpermutation(expert_output, None)
return output, l_aux, l_zloss
82 changes: 82 additions & 0 deletions paddlenlp/transformers/moe_utils.py
@@ -101,3 +101,85 @@ def unpermute(
else:
output_tokens.scatter_(index=sorted_indices, updates=permuted_tokens, overwrite=False)
return output_tokens


def topk_to_permuted_indices_single(x, num_tokens, expert_id, topk):
"""
Convert the top-k routing indices into permuted token/prob indices for a single expert.
"""
x = paddle.flatten(x)
prob_permuted_indices = paddle.tensor.search._restrict_nonzero(x == expert_id, num_tokens).flatten()
token_permuted_indices = prob_permuted_indices // topk
return token_permuted_indices, prob_permuted_indices
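For intuition, _restrict_nonzero plays the role of a fixed-size nonzero here; the same mapping sketched with the public paddle.nonzero and illustrative values (topk=2):

import paddle

# Routing matrix for 3 tokens with topk=2; flattened: [0, 1, 2, 0, 1, 2]
x = paddle.to_tensor([[0, 1], [2, 0], [1, 2]]).flatten()
expert_id, topk = 0, 2
prob_permuted_indices = paddle.nonzero(x == expert_id).flatten()  # positions [0, 3]
token_permuted_indices = prob_permuted_indices // topk  # owning tokens [0, 1]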


def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk):
"""
Convert the top-k routing indices into permuted token/prob indices for all experts.
"""
x = paddle.flatten(x)
prob_permuted_indices = paddle.concat(
[
paddle.tensor.search._restrict_nonzero(x == i, total_true_num)
for i, total_true_num in enumerate(num_tokens_per_expert_list)
]
).flatten()
token_permuted_indices = prob_permuted_indices // topk
return token_permuted_indices, prob_permuted_indices
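The multi-expert variant is the same mapping concatenated over experts; a sketch with the same illustrative routing matrix:

import paddle

x = paddle.to_tensor([[0, 1], [2, 0], [1, 2]]).flatten()
topk = 2
num_tokens_per_expert_list = [2, 2, 2]  # each expert is matched twice above
prob_permuted_indices = paddle.concat(
    [paddle.nonzero(x == i).flatten() for i in range(len(num_tokens_per_expert_list))]
)  # [0, 3, 1, 4, 2, 5]
token_permuted_indices = prob_permuted_indices // topk  # [0, 1, 0, 2, 1, 2]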


class FakeClone(paddle.autograd.PyLayer):
"""
In manual_backward, the output of manual_backward has to be cloned so that a
local computation graph can be kept for a temporary backward pass. This clone
does not actually need the output's values; it only needs the computation
graph attached to the output.

Calling paddle.clone, however, performs an extra data copy, which is
unnecessary. FakeClone avoids that copy while still capturing the
computation graph.
"""

@staticmethod
def forward(ctx, input):
"""forward"""
if input.is_contiguous():
fake_output = paddle.empty_like(input)
input._share_buffer_to(fake_output)
else:
fake_output = input.clone()
return fake_output

@staticmethod
def backward(ctx, grad_output):
"""backward"""
return grad_output
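A hedged usage sketch: FakeClone returns a tensor that shares the input's buffer (no data copy for contiguous inputs) while still adding a node to the autograd graph:

import paddle

x = paddle.randn([4, 8])
x.stop_gradient = False
y = FakeClone.apply(x)  # no copy; y aliases x's storage when x is contiguous
assert bool(paddle.equal_all(x, y))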


class FakeGather(paddle.autograd.PyLayer):
"""
Temporary workaround for a core dump when gather is given a 0-size index tensor.
"""

@staticmethod
def forward(ctx, input, indices):
"""forward"""
assert len(indices.shape) == 1
ctx.save_for_backward(indices)
ctx.input_shape = input.shape
if indices.shape[0] == 0:
out_shape = input.shape
out_shape[0] = 0
return paddle.zeros(out_shape, dtype=input.dtype)
return paddle.index_select(input, axis=0, index=indices)

@staticmethod
def backward(ctx, grad_output):
"""backward"""
indices = ctx.saved_tensor()
input_shape = ctx.input_shape
grad_input = paddle.zeros(input_shape, dtype=grad_output.dtype)
if indices.shape[0] != 0:
if scatter_add_ is not None:
scatter_add_(grad_input, indices.unsqueeze(-1), grad_output)
else:
paddle.scatter_(grad_input, indices.unsqueeze(-1), grad_output, overwrite=False)
return grad_input, None
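A hedged usage sketch: FakeGather behaves like index_select on axis 0 but tolerates an empty index tensor instead of crashing, and its backward scatters gradients back additively:

import paddle

x = paddle.randn([5, 8])
empty = paddle.to_tensor([], dtype="int64")
out = FakeGather.apply(x, empty)  # 0-row result instead of a core dump
assert out.shape == [0, 8]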