From 4a4ebc6f62b12447fe6a18dfa5c098b0674decb0 Mon Sep 17 00:00:00 2001
From: superAngGao
Date: Wed, 11 Feb 2026 10:45:10 +0800
Subject: [PATCH 1/2] manually updating the branch

---
 tests/ops/test_gqa_decode_paged_legacy.py | 35 -----------------------
 tests/ops/test_mha_decode_paged_legacy.py | 35 -----------------------
 tests/ops/test_mhc_pre.py                 |  4 +--
 top/kernels/mhc/mhc_pre.py                | 23 +++++++++++----
 top/ops/mhc_pre.py                        | 16 ++++++++---
 5 files changed, 31 insertions(+), 82 deletions(-)
 delete mode 100644 tests/ops/test_gqa_decode_paged_legacy.py
 delete mode 100644 tests/ops/test_mha_decode_paged_legacy.py

diff --git a/tests/ops/test_gqa_decode_paged_legacy.py b/tests/ops/test_gqa_decode_paged_legacy.py
deleted file mode 100644
index aa69e8d..0000000
--- a/tests/ops/test_gqa_decode_paged_legacy.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""Legacy-style test for GroupQueryAttentionDecodePagedWithKVCacheOp (argparse + check + profile)."""
-
-import pytest
-import torch
-
-from benchmarks.flash_decode import GroupQueryAttentionDecodePagedBenchmark
-from top.ops import GroupQueryAttentionDecodePagedWithKVCacheOp
-
-
-@pytest.mark.parametrize("batch,heads,groups,seqlen_kv,dim,page_size,dtype", [
-    (1, 16, 8, 512, 128, 128, torch.float16),
-])
-def test_gqa_decode_paged(
-    batch: int,
-    heads: int,
-    groups: int,
-    seqlen_kv: int,
-    dim: int,
-    page_size: int,
-    dtype: torch.dtype,
-    tune: bool = False,
-) -> None:
-    torch.manual_seed(123)  # 替代 fixture 中的随机种子设置
-    op = GroupQueryAttentionDecodePagedWithKVCacheOp(
-        batch, heads, groups, seqlen_kv, dim, page_size, dtype, tune=tune)
-    benchmark = GroupQueryAttentionDecodePagedBenchmark(batch, heads, groups, seqlen_kv, dim,
-                                                        page_size, dtype)
-
-    inputs = benchmark.gen_inputs()
-    benchmark.check(op, *inputs, atol=1e-2, rtol=1e-2)
-    benchmark.profile(op, *inputs)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__, "-vvs"])
diff --git a/tests/ops/test_mha_decode_paged_legacy.py b/tests/ops/test_mha_decode_paged_legacy.py
deleted file mode 100644
index 25015c8..0000000
--- a/tests/ops/test_mha_decode_paged_legacy.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""Legacy-style test for MultiHeadAttentionDecodePagedWithKVCacheOp (argparse + check + profile)."""
-
-import pytest
-import torch
-
-from benchmarks.flash_decode import MultiHeadAttentionDecodePagedBenchmark
-from top.ops import MultiHeadAttentionDecodePagedWithKVCacheOp
-
-
-@pytest.mark.parametrize("batch,heads,seqlen_q,seqlen_kv,dim,page_size,is_causal,dtype", [
-    (1, 16, 1, 512, 128, 128, False, torch.float16),
-])
-def test_mha_decode_paged(
-    batch: int,
-    heads: int,
-    seqlen_q: int,
-    seqlen_kv: int,
-    dim: int,
-    page_size: int,
-    is_causal: bool,
-    dtype: torch.dtype,
-    tune: bool = False,
-) -> None:
-    op = MultiHeadAttentionDecodePagedWithKVCacheOp(
-        batch, heads, seqlen_q, seqlen_kv, dim, page_size, is_causal, dtype, tune=tune)
-    benchmark = MultiHeadAttentionDecodePagedBenchmark(batch, heads, seqlen_q, seqlen_kv, dim,
-                                                       page_size, is_causal, dtype)
-
-    inputs = benchmark.gen_inputs()
-    benchmark.check(op, *inputs, atol=2e-3, rtol=1e-5)
-    benchmark.profile(op, *inputs)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__, "-vvs"])
diff --git a/tests/ops/test_mhc_pre.py b/tests/ops/test_mhc_pre.py
index 7ce023d..b6df898 100644
--- a/tests/ops/test_mhc_pre.py
+++ b/tests/ops/test_mhc_pre.py
@@ -33,11 +33,12 @@ def test_mhc_pre_op(
     alpha_post = torch.randn(())
     alpha_res = torch.randn(())
     sinkhorn_repeat = 20
+    eps = 0.02
     #test_mhc_kernel = mhc_pre_kernel(batch, n_expand, c_x, dtype=torch.bfloat16)
     test_mhc_pre_op = ManifoldConstrainedHyperConnectionPreOp(
         batch, n_expand, c_x, dtype=torch.bfloat16)
     x_res, x_layer = test_mhc_pre_op.forward(phi, x, b, alpha_pre, alpha_post, alpha_res,
-                                             sinkhorn_repeat)
+                                             sinkhorn_repeat, eps)
 
     # check the correctness with torch...
     xsqr = x * x  # the square of x
@@ -57,7 +58,6 @@ def test_mhc_pre_op(
     H_pre_ref = torch.sigmoid(alpha_pre * H_pre_ref / r_ref.unsqueeze(-1) + b_pre_ref)
     H_res_ref = alpha_res * H_res_ref / r_ref.unsqueeze(-1).unsqueeze(-1) + b_res_ref
 
-    eps = 0.0001
     H_res_ref_tmp = H_res_ref.max(dim=-1, keepdim=True).values
     H_res_ref = torch.exp(H_res_ref - H_res_ref_tmp)
 
diff --git a/top/kernels/mhc/mhc_pre.py b/top/kernels/mhc/mhc_pre.py
index 53d863c..db7688f 100644
--- a/top/kernels/mhc/mhc_pre.py
+++ b/top/kernels/mhc/mhc_pre.py
@@ -147,6 +147,7 @@ def _get_H_1(
     def _get_H_res(
             H_res_0: T.Tensor([batch, n_expand, n_expand], dtype),
             sinkhorn_repeat: T.int,
+            sinkhorn_eps: T.float,
             H_res: T.Tensor([batch, n_expand, n_expand], dtype),
     ):
         with T.Kernel(batch, threads=threads) as (bx):
@@ -156,7 +157,7 @@ def _get_H_res(
             tmp2 = T.alloc_fragment([n_expand], dtype)
             h_out_shared = T.alloc_shared([n_expand, n_expand], dtype)
 
-            eps = 0.0001
+            eps = sinkhorn_eps
             # exponential function...
             # get the max value first...
             for i, j in T.Parallel(n_expand, n_expand):
@@ -239,13 +240,14 @@ def mhc_pre(
             H_post: T.Tensor([batch, n_expand], dtype),
             H_res_0: T.Tensor([batch, n_expand, n_expand], dtype),
             sinkhorn_repeat: T.int,
+            sinkhorn_eps: T.float,
             H_res: T.Tensor([batch, n_expand, n_expand], dtype),
             x_res: T.Tensor([batch, n_expand * c_x], x_dtype),
             x_layer: T.Tensor([batch, c_x], x_dtype),
     ):
         _get_H_0_no_split(phi, x, r, H)
         _get_H_1(H, r, b, alpha_pre, alpha_post, alpha_res, H_pre, H_post, H_res_0)
-        _get_H_res(H_res_0, sinkhorn_repeat, H_res)
+        _get_H_res(H_res_0, sinkhorn_repeat, sinkhorn_eps, H_res)
         _get_x(x, H_pre, H_res, x_res, x_layer)
 
     return mhc_pre
@@ -259,12 +261,13 @@ def _mhc_pre_wrapped_kernel(batch: int, n_expand: int, c_x: int, dtype: str, blo
                             phi: torch.Tensor, x: torch.Tensor, H: torch.Tensor, r: torch.Tensor,
                             b: torch.Tensor, alpha_pre: float, alpha_post: float, alpha_res: float,
                             H_pre: torch.Tensor, H_post: torch.Tensor, H_res_0: torch.Tensor,
-                            sinkhorn_repeat: int, H_res: torch.Tensor,
+                            sinkhorn_repeat: int, sinkhorn_eps: float, H_res: torch.Tensor,
                             x_res: torch.Tensor) -> torch.Tensor:
     return _mhc_pre_kernel(batch, n_expand, c_x, dtype)(block_x_b, block_C, num_stages,
                                                         threads)(phi, x, H, r, b, alpha_pre,
                                                                  alpha_post, alpha_res, H_pre,
-                                                                 H_post, H_res_0, sinkhorn_repeat, H_res, x_res)
+                                                                 H_post, H_res_0, sinkhorn_repeat, sinkhorn_eps, H_res,
+                                                                 x_res)
 
 
 @_mhc_pre_wrapped_kernel.register_fake
@@ -320,7 +323,15 @@ def autotune_configs(self) -> list[dict]:
         } for c in _configs]
         return configs
 
-    def forward(self, phi, x, b, alpha_pre, alpha_post, alpha_res, sinkhorn_repeat):
+    def forward(self,
+                phi,
+                x,
+                b,
+                alpha_pre,
+                alpha_post,
+                alpha_res,
+                sinkhorn_repeat,
+                sinkhorn_eps=0.02):
         # H_pre, H_post, H_res_0, H_res are tensors need to be allocated....
         r = torch.empty([self.batch], device=x.device, dtype=self.weights_dtype)
         H = torch.empty([self.batch, self.n_expand * self.n_expand + 2 * self.n_expand],
@@ -340,5 +351,5 @@ def forward(self, phi, x, b, alpha_pre, alpha_post, alpha_res, sinkhorn_repeat):
                                          self.config["block_x_b"], self.config["block_C"],
                                          self.config["num_stages"], self.config["threads"], phi,
                                          x, H, r, b, alpha_pre, alpha_post, alpha_res, H_pre, H_post,
-                                         H_res_0, sinkhorn_repeat, H_res, x_res)
+                                         H_res_0, sinkhorn_repeat, sinkhorn_eps, H_res, x_res)
         return x_res, result
diff --git a/top/ops/mhc_pre.py b/top/ops/mhc_pre.py
index 3aa91c9..9947023 100644
--- a/top/ops/mhc_pre.py
+++ b/top/ops/mhc_pre.py
@@ -33,7 +33,15 @@ def __init__(self,
     def default_kernel_map(self) -> Dict[str, Kernel]:
         return {"mhc_pre_kernel": mhc_pre_kernel}
 
-    def forward(self, phi: torch.Tensor, x: torch.Tensor, b: torch.Tensor, alpha_pre: float,
-                alpha_post: float, alpha_res: float, sinkhorn_repeat: int) -> torch.Tensor:
-
-        return self.kernel(phi, x, b, alpha_pre, alpha_post, alpha_res, sinkhorn_repeat)
+    def forward(self,
+                phi: torch.Tensor,
+                x: torch.Tensor,
+                b: torch.Tensor,
+                alpha_pre: float,
+                alpha_post: float,
+                alpha_res: float,
+                sinkhorn_repeat: int,
+                sinkhorn_eps: float = 0.02) -> torch.Tensor:
+
+        return self.kernel(phi, x, b, alpha_pre, alpha_post, alpha_res, sinkhorn_repeat,
+                           sinkhorn_eps)

From f536f7284ea9e84327ef31d330d1ffb942f4226a Mon Sep 17 00:00:00 2001
From: superAngGao
Date: Wed, 11 Feb 2026 11:04:52 +0800
Subject: [PATCH 2/2] add pytest parameters for gqa_decode

---
 tests/ops/test_gqa_decode.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/ops/test_gqa_decode.py b/tests/ops/test_gqa_decode.py
index 83a33fb..d2a38b3 100644
--- a/tests/ops/test_gqa_decode.py
+++ b/tests/ops/test_gqa_decode.py
@@ -9,6 +9,8 @@
     "b, h, g, s_kv, d, dtype, tune",
    [
         (1, 32, 8, 8192, 128, torch.float16, False),
+        (4, 32, 4, 4096, 128, torch.bfloat16, False),
+        (8, 64, 16, 8192, 128, torch.float16, False),
     ],
 )
 def test_gqa_decode(b: int, h: int, g: int, s_kv: int, d: int, dtype: torch.dtype,
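
The new sinkhorn_eps argument threads a tunable epsilon (default 0.02) through
ManifoldConstrainedHyperConnectionPreOp.forward down to _get_H_res, replacing the
hardcoded eps = 0.0001 inside the kernel and the matching constant in the test's
reference check. For orientation only, below is a minimal PyTorch sketch of the
Sinkhorn-style row/column normalization that H_res_0 goes through; the helper name
sinkhorn_normalize is made up for illustration, and placing eps in the row/column
sums for numerical stability is an assumption, since the full body of _get_H_res is
outside the hunks shown above.

    import torch

    def sinkhorn_normalize(H: torch.Tensor, repeat: int = 20, eps: float = 0.02) -> torch.Tensor:
        # Stabilized exponential, mirroring the reference check in
        # tests/ops/test_mhc_pre.py: subtract the row-wise max, then exp.
        H = torch.exp(H - H.max(dim=-1, keepdim=True).values)
        # Alternate row and column normalization for `repeat` iterations;
        # eps keeps the divisions well-behaved when a row or column sum is tiny.
        for _ in range(repeat):
            H = H / (H.sum(dim=-1, keepdim=True) + eps)  # normalize rows
            H = H / (H.sum(dim=-2, keepdim=True) + eps)  # normalize columns
        return H

Callers that depended on the old hardcoded value can pass sinkhorn_eps=0.0001
explicitly; omitting the argument now picks up the 0.02 default, which is what the
updated eps = 0.02 in tests/ops/test_mhc_pre.py exercises.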