2 changes: 2 additions & 0 deletions tests/ops/test_gqa_decode.py
@@ -9,6 +9,8 @@
"b, h, g, s_kv, d, dtype, tune",
[
(1, 32, 8, 8192, 128, torch.float16, False),
(4, 32, 4, 4096, 128, torch.bfloat16, False),
(8, 64, 16, 8192, 128, torch.float16, False),
],
)
def test_gqa_decode(b: int, h: int, g: int, s_kv: int, d: int, dtype: torch.dtype,
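
The two new cases extend coverage to a bf16 configuration and to larger batch and head counts. For readers of the test alone, a minimal torch reference for GQA decode clarifies what the tuple fields are taken to mean here: b = batch, h = query heads, g = KV heads (each KV head shared by h // g query heads), s_kv = KV sequence length, d = head dim. These names are inferred from the test signature rather than the kernel source, so treat the sketch below as an assumption, not the kernel's definition:

import torch

def gqa_decode_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: [b, h, 1, d] (single decode step); k, v: [b, g, s_kv, d] with h % g == 0.
    b, h, _, d = q.shape
    g = k.shape[1]
    # Each KV head serves h // g query heads, so repeat it that many times.
    k = k.repeat_interleave(h // g, dim=1)
    v = v.repeat_interleave(h // g, dim=1)
    scores = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) / d**0.5
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", probs, v.float()).to(q.dtype)

# e.g. the new (4, 32, 4, 4096, 128, torch.bfloat16, False) case:
q = torch.randn(4, 32, 1, 128, dtype=torch.bfloat16)
k = torch.randn(4, 4, 4096, 128, dtype=torch.bfloat16)
v = torch.randn(4, 4, 4096, 128, dtype=torch.bfloat16)
out = gqa_decode_ref(q, k, v)  # [4, 32, 1, 128]
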
35 changes: 0 additions & 35 deletions tests/ops/test_gqa_decode_paged_legacy.py

This file was deleted.

35 changes: 0 additions & 35 deletions tests/ops/test_mha_decode_paged_legacy.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/ops/test_mhc_pre.py
@@ -33,11 +33,12 @@ def test_mhc_pre_op(
alpha_post = torch.randn(())
alpha_res = torch.randn(())
sinkhorn_repeat = 20
eps = 0.02
#test_mhc_kernel = mhc_pre_kernel(batch, n_expand, c_x, dtype=torch.bfloat16)
test_mhc_pre_op = ManifoldConstrainedHyperConnectionPreOp(
batch, n_expand, c_x, dtype=torch.bfloat16)
x_res, x_layer = test_mhc_pre_op.forward(phi, x, b, alpha_pre, alpha_post, alpha_res,
sinkhorn_repeat)
sinkhorn_repeat, eps)

# check the correctness with torch...
xsqr = x * x # the square of x
@@ -57,7 +58,6 @@ def test_mhc_pre_op(
H_pre_ref = torch.sigmoid(alpha_pre * H_pre_ref / r_ref.unsqueeze(-1) + b_pre_ref)
H_res_ref = alpha_res * H_res_ref / r_ref.unsqueeze(-1).unsqueeze(-1) + b_res_ref

eps = 0.0001
H_res_ref_tmp = H_res_ref.max(dim=-1, keepdim=True).values

H_res_ref = torch.exp(H_res_ref - H_res_ref_tmp)
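
The reference computation (and the kernel, below) previously hard-coded eps = 0.0001; it now takes the value passed into forward, with 0.02 as the test's choice. For context, here is a rough torch sketch of the Sinkhorn-style row/column normalization the reference is building up to. The placement of eps as a stabilizer in the normalization denominators is an assumption, since the rest of the reference loop is elided from this hunk:

import torch

def sinkhorn_ref(H_res: torch.Tensor, repeat: int = 20, eps: float = 0.02) -> torch.Tensor:
    # H_res: [batch, n_expand, n_expand] logits. Subtract the row-wise max before
    # exponentiating for numerical stability, exactly as the test does above.
    P = torch.exp(H_res - H_res.max(dim=-1, keepdim=True).values)
    for _ in range(repeat):
        P = P / (P.sum(dim=-1, keepdim=True) + eps)  # normalize rows
        P = P / (P.sum(dim=-2, keepdim=True) + eps)  # normalize columns
    return P
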
23 changes: 17 additions & 6 deletions top/kernels/mhc/mhc_pre.py
@@ -147,6 +147,7 @@ def _get_H_1(
def _get_H_res(
H_res_0: T.Tensor([batch, n_expand, n_expand], dtype),
sinkhorn_repeat: T.int,
sinkhorn_eps: T.float,
H_res: T.Tensor([batch, n_expand, n_expand], dtype),
):
with T.Kernel(batch, threads=threads) as (bx):
@@ -156,7 +157,7 @@ def _get_H_res(
tmp2 = T.alloc_fragment([n_expand], dtype)
h_out_shared = T.alloc_shared([n_expand, n_expand], dtype)

eps = 0.0001
eps = sinkhorn_eps
# exponential function...
# get the max value first...
for i, j in T.Parallel(n_expand, n_expand):
@@ -239,13 +240,14 @@ def mhc_pre(
H_post: T.Tensor([batch, n_expand], dtype),
H_res_0: T.Tensor([batch, n_expand, n_expand], dtype),
sinkhorn_repeat: T.int,
sinkhorn_eps: T.float,
H_res: T.Tensor([batch, n_expand, n_expand], dtype),
x_res: T.Tensor([batch, n_expand * c_x], x_dtype),
x_layer: T.Tensor([batch, c_x], x_dtype),
):
_get_H_0_no_split(phi, x, r, H)
_get_H_1(H, r, b, alpha_pre, alpha_post, alpha_res, H_pre, H_post, H_res_0)
_get_H_res(H_res_0, sinkhorn_repeat, H_res)
_get_H_res(H_res_0, sinkhorn_repeat, sinkhorn_eps, H_res)
_get_x(x, H_pre, H_res, x_res, x_layer)

return mhc_pre
@@ -259,12 +261,13 @@ def _mhc_pre_wrapped_kernel(batch: int, n_expand: int, c_x: int, dtype: str, blo
x: torch.Tensor, H: torch.Tensor, r: torch.Tensor, b: torch.Tensor,
alpha_pre: float, alpha_post: float, alpha_res: float,
H_pre: torch.Tensor, H_post: torch.Tensor, H_res_0: torch.Tensor,
sinkhorn_repeat: int, H_res: torch.Tensor,
sinkhorn_repeat: int, sinkhorn_eps: float, H_res: torch.Tensor,
x_res: torch.Tensor) -> torch.Tensor:
return _mhc_pre_kernel(batch, n_expand, c_x,
dtype)(block_x_b, block_C, num_stages,
threads)(phi, x, H, r, b, alpha_pre, alpha_post, alpha_res, H_pre,
H_post, H_res_0, sinkhorn_repeat, H_res, x_res)
H_post, H_res_0, sinkhorn_repeat, sinkhorn_eps, H_res,
x_res)


@_mhc_pre_wrapped_kernel.register_fake
@@ -320,7 +323,15 @@ def autotune_configs(self) -> list[dict]:
} for c in _configs]
return configs

def forward(self, phi, x, b, alpha_pre, alpha_post, alpha_res, sinkhorn_repeat):
def forward(self,
phi,
x,
b,
alpha_pre,
alpha_post,
alpha_res,
sinkhorn_repeat,
sinkhorn_eps=0.02):
# H_pre, H_post, H_res_0, H_res are tensors that need to be allocated.
r = torch.empty([self.batch], device=x.device, dtype=self.weights_dtype)
H = torch.empty([self.batch, self.n_expand * self.n_expand + 2 * self.n_expand],
@@ -340,5 +351,5 @@ def forward(self, phi, x, b, alpha_pre, alpha_post, alpha_res, sinkhorn_repeat):
self.config["block_x_b"], self.config["block_C"],
self.config["num_stages"], self.config["threads"], phi, x,
H, r, b, alpha_pre, alpha_post, alpha_res, H_pre, H_post,
H_res_0, sinkhorn_repeat, H_res, x_res)
H_res_0, sinkhorn_repeat, sinkhorn_eps, H_res, x_res)
return x_res, result
16 changes: 12 additions & 4 deletions top/ops/mhc_pre.py
@@ -33,7 +33,15 @@ def __init__(self,
def default_kernel_map(self) -> Dict[str, Kernel]:
return {"mhc_pre_kernel": mhc_pre_kernel}

def forward(self, phi: torch.Tensor, x: torch.Tensor, b: torch.Tensor, alpha_pre: float,
alpha_post: float, alpha_res: float, sinkhorn_repeat: int) -> torch.Tensor:

return self.kernel(phi, x, b, alpha_pre, alpha_post, alpha_res, sinkhorn_repeat)
def forward(self,
phi: torch.Tensor,
x: torch.Tensor,
b: torch.Tensor,
alpha_pre: float,
alpha_post: float,
alpha_res: float,
sinkhorn_repeat: int,
sinkhorn_eps: float = 0.02) -> torch.Tensor:

return self.kernel(phi, x, b, alpha_pre, alpha_post, alpha_res, sinkhorn_repeat,
sinkhorn_eps)
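
Existing callers keep working unchanged and pick up the new 0.02 default, while tests (and users) can override it per call. A usage sketch follows, assuming op is the ManifoldConstrainedHyperConnectionPreOp instance used by tests/ops/test_mhc_pre.py and that phi, x and b are prepared as in that test (their construction is elided from the hunks above); the override value is purely illustrative:

# Old call sites need no change; sinkhorn_eps now defaults to 0.02:
#   x_res, x_layer = op.forward(phi, x, b, alpha_pre, alpha_post, alpha_res,
#                               sinkhorn_repeat)
# Explicit override of the Sinkhorn stabilizer (illustrative value):
x_res, x_layer = op.forward(phi, x, b, alpha_pre, alpha_post, alpha_res,
                            sinkhorn_repeat, sinkhorn_eps=0.05)
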