From df47da54b37d180b8f88d8610980183144f5aec2 Mon Sep 17 00:00:00 2001
From: freeliuzc
Date: Wed, 3 Dec 2025 13:39:45 +0800
Subject: [PATCH 1/2] update unit test of append_attention in spec_mode

---
 ..._tree_mask.py => test_append_attention.py} | 144 +++++++++++++++++-
 1 file changed, 136 insertions(+), 8 deletions(-)
 rename tests/operators/{test_tree_mask.py => test_append_attention.py} (76%)

diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_append_attention.py
similarity index 76%
rename from tests/operators/test_tree_mask.py
rename to tests/operators/test_append_attention.py
index 57a62044814..3e833b01f47 100644
--- a/tests/operators/test_tree_mask.py
+++ b/tests/operators/test_append_attention.py
@@ -31,6 +31,22 @@
 np.random.seed(0)
 paddle.seed(0)
 
+import os
+from contextlib import contextmanager
+
+
+@contextmanager
+def temp_env(key, value):
+    old_val = os.environ.get(key)
+    os.environ[key] = value
+    try:
+        yield
+    finally:
+        if old_val is None:
+            os.environ.pop(key, None)
+        else:
+            os.environ[key] = old_val
+
 
 class TestTreeMask(unittest.TestCase):
     def setUp(self):
@@ -40,12 +56,12 @@ def setUp(self):
         self.max_seq_len = 32768
         self.encoder_max_partition_size = self.max_seq_len
         self.max_partition_size = self.max_seq_len
         self.max_dec_len = 1024
-        self.bsz = 64
+        self.bsz = 2
         self.run_time = 10
         self.warm_up = 2
         self.block_size = 64
         self.head_dim = 128
-        self.num_q_head = 20
+        self.num_q_head = 32
         self.num_kv_head = 4
         self.use_qknorm = True
         self.dtype = "bfloat16"
@@ -157,7 +173,7 @@ def ref_attention(self, q, k, v, mask, use_qknorm=False):
         )
 
     def run_append_c16_attention(
-        self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False, mask_offset=None
+        self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False, mask_offset=None, qkv=None
    ):
         if prefill:
             seq_lens_enc = [
@@ -184,8 +200,11 @@ def run_append_c16_attention(
         batch_id_per_token, cu_seqlens_q, cu_seqlens_k = self.get_padding_offset(
             self.bsz, seq_lens_this_time, seq_lens_decoder
         )
+        if qkv is None:
+            qkv_varlen_shape = [token_num, (self.num_q_head + 2 * self.num_kv_head) * self.head_dim]
+            qkv = paddle.randn(shape=qkv_varlen_shape).astype(self.dtype)
+        self.split_qkv(qkv, self.bsz, q_len)
 
-        qkv_varlen_shape = [token_num, (self.num_q_head + 2 * self.num_kv_head) * self.head_dim]
         rotary_embs_shape = [
             2,
             1,
@@ -193,10 +212,6 @@ def run_append_c16_attention(
             1,
             self.head_dim if self.use_neox_rotary_style else self.head_dim // 2,
         ]
-
-        qkv = paddle.randn(shape=qkv_varlen_shape).astype(self.dtype)
-        self.split_qkv(qkv, self.bsz, q_len)
-
         rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
         rotary_embs[0, :, :, :, :] = 1
         rotary_embs[1, :, :, :, :] = 0
@@ -320,6 +335,9 @@ def run_append_c16_attention(
         return out.reshape([token_num, self.num_q_head, self.head_dim])
 
     def test_naive_speculative_decoding(self):
+        """
+        Test Attention with a causal mask in speculative mode.
+        """
         prefill_len = 8192
         dec_len_q = 5
         total_len = prefill_len + dec_len_q
@@ -334,6 +352,9 @@ def test_naive_speculative_decoding(self):
         )
 
     def test_mask(self):
+        """
+        Test Attention with an explicitly passed mask in speculative mode.
+        """
         prefill_len = 8192
         dec_len_q = 5
         total_len = prefill_len + dec_len_q
@@ -357,6 +378,9 @@ def test_mask(self):
         )
 
     def test_tree_mask(self):
+        """
+        Test Attention with an explicitly passed tree mask in speculative mode.
+        """
         prefill_len = 8192
         dec_len_q = 5
         total_len = prefill_len + dec_len_q
@@ -383,6 +407,9 @@ def test_tree_mask(self):
         )
 
     def test_mask_offset(self):
+        """
+        Test Attention with an explicitly passed mask_offset in speculative mode.
+        """
         prefill_len = 8192
         dec_len_q = 5
         total_len = prefill_len + dec_len_q
@@ -406,6 +433,107 @@ def test_mask_offset(self):
             ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
         )
 
+    def test_consistency_with_multi_tokens(self):
+        """
+        Strict test of the multi-token speculative-decoding path (covering qkv_norm/mask_offset), comparing:
+        (A) the outputs for the dec_start:dec_end decode steps taken from a full 5-step run
+        (B) the outputs from running only those decode steps on their own
+        If the two match exactly, the speculative-decoding Attention computation is correct.
+        """
+        # -----------------------------
+        # Basic parameters
+        # -----------------------------
+        import os
+
+        os.environ["FLAGS_max_partition_size"] = "131072"
+
+        prefill_len = 8192
+        dec_len_q = 5
+        dec_start, dec_end = 1, 3
+        dec_len = dec_end - dec_start
+
+        head_num = self.num_q_head + 2 * self.num_kv_head
+        qkv_dim = head_num * self.head_dim
+
+        # -----------------------------
+        # Prefill stage (builds the KV cache)
+        # -----------------------------
+        with temp_env("FLAGS_max_partition_size", "131072"):
+            self.run_append_c16_attention(
+                prefill_len,  # q_len
+                0,  # kv_len
+                True,  # prefill
+                use_qknorm=self.use_qknorm,
+            )
+
+        # -----------------------------
+        # Build mask_offset: two entries per step, [0, prefill_len + i + 1]
+        # -----------------------------
+        pattern = []
+        for i in range(dec_len_q):
+            pattern.extend([0, prefill_len + i + 1])
+        mask_offset = paddle.tile(paddle.to_tensor(pattern, dtype="int32"), [self.bsz])
+
+        # Slice out the offset entries for steps dec_start ~ dec_end
+        mask_offset_slice = (
+            mask_offset.reshape([self.bsz, -1])[:, 2 * dec_start : 2 * dec_end].reshape([-1]).astype("int32")
+        )
+
+        # -----------------------------
+        # Build the full qkv (all 5 steps)
+        # -----------------------------
+        qkv_full = paddle.randn([self.bsz * dec_len_q, qkv_dim], dtype=self.dtype)
+
+        # -----------------------------
+        # Run attention for the full 5 steps
+        # (the first three arguments must still be passed positionally)
+        # -----------------------------
+        with temp_env("FLAGS_max_partition_size", "131072"):
+            dec_out_full = self.run_append_c16_attention(
+                dec_len_q,  # q_len
+                prefill_len,  # kv_len
+                False,  # prefill
+                use_qknorm=self.use_qknorm,
+                mask_offset=mask_offset,
+                qkv=qkv_full,
+            )
+
+        # -----------------------------
+        # Build the sliced qkv: keep only steps dec_start ~ dec_end
+        # -----------------------------
+        qkv_slice = qkv_full.reshape([self.bsz, dec_len_q, qkv_dim])[:, dec_start:dec_end, :].reshape([-1, qkv_dim])
+
+        # -----------------------------
+        # Run attention for steps dec_start ~ dec_end only
+        # -----------------------------
+        with temp_env("FLAGS_max_partition_size", "131072"):
+            dec_out_slice = self.run_append_c16_attention(
+                dec_len,  # q_len
+                prefill_len + dec_start,  # kv_len
+                False,  # prefill
+                use_qknorm=self.use_qknorm,
+                mask_offset=mask_offset_slice,
+                qkv=qkv_slice,
+            )
+
+        # -----------------------------
+        # Extract the outputs for steps dec_start ~ dec_end from the full run and align shapes
+        # -----------------------------
+        dec_out_full_range = (
+            dec_out_full.reshape([self.bsz, dec_len_q, -1, self.head_dim])[:, dec_start:dec_end, :, :]
+            .reshape([self.bsz * dec_len, -1, self.head_dim])
+            .astype("float32")
+            .numpy()
+        )
+
+        # -----------------------------
+        # The two modes must produce identical outputs
+        # -----------------------------
+        np.testing.assert_array_equal(
+            dec_out_full_range,
+            dec_out_slice.astype("float32").numpy(),
+        )
+
 
 if __name__ == "__main__":
     unittest.main()

From 0aad8ae17fcd32640be3474085eaa91359618af3 Mon Sep 17 00:00:00 2001
From: freeliuzc
Date: Wed, 3 Dec 2025 14:31:25 +0800
Subject: [PATCH 2/2] update unit test

---
 tests/operators/test_append_attention.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/operators/test_append_attention.py b/tests/operators/test_append_attention.py
index 3e833b01f47..b05f173fce2 100644
--- a/tests/operators/test_append_attention.py
+++ b/tests/operators/test_append_attention.py
@@ -50,18 +50,19 @@ def temp_env(key, value):
 
 class TestTreeMask(unittest.TestCase):
     def setUp(self):
+        # TODO(liuzichang): If q_head=32 or bsz=128 is set, some cases will fail.
         paddle.seed(0)
         self.max_seq_len = 32768
         self.encoder_max_partition_size = self.max_seq_len
         self.max_partition_size = self.max_seq_len
         self.max_dec_len = 1024
-        self.bsz = 2
-        self.run_time = 10
-        self.warm_up = 2
+        self.bsz = 64
+        self.run_time = 3
+        self.warm_up = 1
         self.block_size = 64
         self.head_dim = 128
-        self.num_q_head = 32
+        self.num_q_head = 20
         self.num_kv_head = 4
         self.use_qknorm = True
         self.dtype = "bfloat16"
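Note: the snippet below is a minimal NumPy-only sketch, not part of the patch series above. It only illustrates the row-selection math behind the qkv_full slicing in test_consistency_with_multi_tokens (a [bsz * dec_len_q, qkv_dim] buffer laid out batch-major, step-minor); the sizes are illustrative.

import numpy as np

bsz, dec_len_q, qkv_dim = 2, 5, 3  # illustrative sizes only
dec_start, dec_end = 1, 3          # same slice bounds as the test

# Flat buffer: row b * dec_len_q + s holds step s of batch element b.
qkv_full = np.arange(bsz * dec_len_q * qkv_dim).reshape(bsz * dec_len_q, qkv_dim)

# The reshape/slice chain used by the test.
qkv_slice = qkv_full.reshape(bsz, dec_len_q, qkv_dim)[:, dec_start:dec_end, :].reshape(-1, qkv_dim)

# Equivalent explicit row indices.
rows = [b * dec_len_q + s for b in range(bsz) for s in range(dec_start, dec_end)]
assert np.array_equal(qkv_slice, qkv_full[rows])
print(rows)  # [1, 2, 6, 7]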