@@ -31,18 +31,35 @@
np.random.seed(0)
paddle.seed(0)

import os
from contextlib import contextmanager


@contextmanager
def temp_env(key, value):
old_val = os.environ.get(key)
os.environ[key] = value
try:
yield
finally:
if old_val is None:
os.environ.pop(key, None)
else:
os.environ[key] = old_val
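# Usage sketch: scope an env flag to a block and restore the prior state on
# exit (even if the block raises), e.g.:
#   with temp_env("FLAGS_max_partition_size", "131072"):
#       ...  # run kernels that read the flag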


class TestTreeMask(unittest.TestCase):
def setUp(self):
# TODO(liuzichang): If q_head=32 or bsz=128 is set, some cases will fail.
paddle.seed(0)
self.max_seq_len = 32768
self.encoder_max_partition_size = self.max_seq_len
self.max_partition_size = self.max_seq_len

self.max_dec_len = 1024
self.bsz = 64
self.run_time = 3
self.warm_up = 1
self.block_size = 64
self.head_dim = 128
self.num_q_head = 20
Expand Down Expand Up @@ -157,7 +174,7 @@ def ref_attention(self, q, k, v, mask, use_qknorm=False):
)

def run_append_c16_attention(
self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False, mask_offset=None, qkv=None
):
if prefill:
seq_lens_enc = [
Expand All @@ -184,19 +201,18 @@ def run_append_c16_attention(
batch_id_per_token, cu_seqlens_q, cu_seqlens_k = self.get_padding_offset(
self.bsz, seq_lens_this_time, seq_lens_decoder
)
if qkv is None:
qkv_varlen_shape = [token_num, (self.num_q_head + 2 * self.num_kv_head) * self.head_dim]
qkv = paddle.randn(shape=qkv_varlen_shape).astype(self.dtype)
self.split_qkv(qkv, self.bsz, q_len)
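# When the caller supplies qkv (as the consistency test below does), the
# random regeneration above is skipped so identical activations can be replayed.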

rotary_embs_shape = [
2,
1,
self.max_seq_len,
1,
self.head_dim if self.use_neox_rotary_style else self.head_dim // 2,
]

rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
rotary_embs[0, :, :, :, :] = 1
rotary_embs[1, :, :, :, :] = 0
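# Setting the two halves to 1 and 0 (presumably the cos and sin components)
# makes the rotary embedding an identity rotation, leaving q/k unchanged.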
@@ -320,6 +336,9 @@ def run_append_c16_attention(
return out.reshape([token_num, self.num_q_head, self.head_dim])

def test_naive_speculative_decoding(self):
"""
在 speculative mode 下,测试 Attention 在 causal_mask 下的功能
"""
prefill_len = 8192
dec_len_q = 5
total_len = prefill_len + dec_len_q
@@ -334,6 +353,9 @@ def test_naive_speculative_decoding(self):
)

def test_mask(self):
"""
在 speculative mode 下,测试 Attention 在传入 mask 下的功能
"""
prefill_len = 8192
dec_len_q = 5
total_len = prefill_len + dec_len_q
@@ -357,6 +379,9 @@ def test_mask(self):
)

def test_tree_mask(self):
"""
在 speculative mode 下,测试 Attention 在传入 tree mask 下的功能
"""
prefill_len = 8192
dec_len_q = 5
total_len = prefill_len + dec_len_q
@@ -383,6 +408,9 @@ def test_tree_mask(self):
)

def test_mask_offset(self):
"""
在 speculative mode 下,测试 Attention 在传入 mask_offset 下的功能
"""
prefill_len = 8192
dec_len_q = 5
total_len = prefill_len + dec_len_q
@@ -406,6 +434,107 @@ def test_mask_offset(self):
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
)

def test_consistency_with_multi_tokens(self):
"""
严格测试投机解码多token功能,包含 qkv_norm/mask_offset,通过对比:
(A) 完整 5 步推理中取出第 1~3 步的结果
(B) 单独跑第 1~3 步推理的结果
若一致则投机解码 Attention 计算正确。
"""
# -----------------------------
# Basic parameters
# -----------------------------

prefill_len = 8192
dec_len_q = 5
dec_start, dec_end = 1, 3
dec_len = dec_end - dec_start

head_num = self.num_q_head + 2 * self.num_kv_head
qkv_dim = head_num * self.head_dim

# -----------------------------
# Prefill phase (builds the KV cache)
# -----------------------------
with temp_env("FLAGS_max_partition_size", "131072"):
self.run_append_c16_attention(
prefill_len, # q_len
0, # kv_len
True, # prefill
use_qknorm=self.use_qknorm,
)

# -----------------------------
# Build mask_offset: two entries per step, [0, prefill_len + i + 1]
# -----------------------------
pattern = []
for i in range(dec_len_q):
pattern.extend([0, prefill_len + i + 1])
mask_offset = paddle.tile(paddle.to_tensor(pattern, dtype="int32"), [self.bsz])
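# E.g. with prefill_len = 8192 and dec_len_q = 5 this yields
# [0, 8193, 0, 8194, 0, 8195, 0, 8196, 0, 8197], i.e. each decode step i
# appears to carry a [start, end] pair spanning the prefill plus its own token.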

# Slice out the offset entries for steps dec_start ~ dec_end
mask_offset_slice = (
mask_offset.reshape([self.bsz, -1])[:, 2 * dec_start : 2 * dec_end].reshape([-1]).astype("int32")
)

# -----------------------------
# Build the full qkv tensor (all 5 steps)
# -----------------------------
qkv_full = paddle.randn([self.bsz * dec_len_q, qkv_dim], dtype=self.dtype)

# -----------------------------
# Run attention over all 5 steps
# (the first three arguments are still passed positionally)
# -----------------------------
with temp_env("FLAGS_max_partition_size", "131072"):
dec_out_full = self.run_append_c16_attention(
dec_len_q, # q_len
prefill_len, # kv_len
False, # prefill
use_qknorm=self.use_qknorm,
mask_offset=mask_offset,
qkv=qkv_full,
)

# -----------------------------
# Build the sliced qkv: keep only steps 1~3
# -----------------------------
qkv_slice = qkv_full.reshape([self.bsz, dec_len_q, qkv_dim])[:, dec_start:dec_end, :].reshape([-1, qkv_dim])

# -----------------------------
# Run attention for steps 1~3 on their own
# -----------------------------
with temp_env("FLAGS_max_partition_size", "131072"):
dec_out_slice = self.run_append_c16_attention(
dec_len, # q_len
prefill_len + dec_start, # kv_len
False, # prefill
use_qknorm=self.use_qknorm,
mask_offset=mask_offset_slice,
qkv=qkv_slice,
)

# -----------------------------
# Extract the step 1~3 outputs from the full run and align shapes
# -----------------------------
dec_out_full_range = (
dec_out_full.reshape([self.bsz, dec_len_q, -1, self.head_dim])[:, dec_start:dec_end, :, :]
.reshape([self.bsz * dec_len, -1, self.head_dim])
.astype("float32")
.numpy()
)

# -----------------------------
# The two modes must produce identical outputs
# -----------------------------
np.testing.assert_array_equal(
dec_out_full_range,
dec_out_slice.astype("float32").numpy(),
)
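# The test asserts exact (bitwise) equality: both runs replay the same values
# against the same KV cache, so the outputs are expected to match exactly
# rather than merely within a floating-point tolerance.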


if __name__ == "__main__":
unittest.main()