33 changes: 4 additions & 29 deletions .github/workflows/ci.yml
@@ -71,7 +71,8 @@ jobs:
source "${{ runner.tool_cache }}/${VENV_DIR}/bin/activate"
export PYTHONPATH="$(pwd):$PYTHONPATH"
echo "PYTHONPATH=$PYTHONPATH"
bash tests/ci_test.sh tileops_test_release.log
set -o pipefail
python -m pytest -q tests | tee tileops_test_release.log
shell: bash

- name: Cleanup venv
@@ -145,7 +146,8 @@ jobs:
source "${{ runner.tool_cache }}/${VENV_DIR}/bin/activate"
export PYTHONPATH="$(pwd):$PYTHONPATH"
echo "PYTHONPATH=$PYTHONPATH"
bash tests/ci_test.sh tileops_test_nightly.log
set -o pipefail
python -m pytest -q tests | tee tileops_test_nightly.log
shell: bash

- name: Cleanup venv
@@ -165,30 +167,3 @@ jobs:
name: tileops_test_nightly.log
path: tileops_test_nightly.log
retention-days: 7 # Equivalent to expire_in: 1 week

tileops_profile_release:
# needs: [pre-commit, tileops_test_release]
needs: [tileops_test_release]
runs-on: [self-hosted, profile]
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0 # Equivalent to GIT_STRATEGY: fetch

- name: Setup & Run tests
run: |
source ~/miniconda3/etc/profile.d/conda.sh
conda activate tileops-release
export PYTHONPATH="$(pwd):$PYTHONPATH"
echo "PYTHONPATH=$PYTHONPATH"
bash benchmarks/profile_run.sh --log profile_out/tileops_profile_release.log
shell: bash

- name: Upload profile_out artifacts
uses: actions/upload-artifact@v4
if: always()
with:
name: profile_out
path: profile_out/
retention-days: 7
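
Note on the two set -o pipefail additions above: a shell pipeline normally exits with the status of its last command, so python -m pytest -q tests | tee <log> would report tee's exit code (almost always 0) and a failing test run could not fail the CI step. With pipefail the step propagates pytest's non-zero status while tee still captures the log.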
1 change: 0 additions & 1 deletion benchmarks/flash_attn/mha.py
@@ -148,7 +148,6 @@ def ref_program(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torc
q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal)
output = output_bhsd.transpose(1, 2).contiguous()

# from IPython import embed; embed()
output.backward(grad_output)
return q.grad, k.grad, v.grad

Empty file removed tests/__init__.py
Empty file.
78 changes: 0 additions & 78 deletions tests/ci_test.sh

This file was deleted.
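Its job is now done inline by the set -o pipefail / python -m pytest -q tests | tee ... steps in ci.yml above, so the test invocation lives in one place.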

9 changes: 9 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,9 @@
import pytest
import torch


@pytest.fixture(autouse=True)
def setup() -> None:
torch.manual_seed(1235)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(1235)
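
The autouse fixture above reseeds torch before every test, so tests that draw random tensors are reproducible from run to run without seeding themselves. A minimal sketch of what that buys, as a hypothetical test that is not part of this PR:

import torch


def test_randn_is_reproducible() -> None:
    # conftest.py's autouse fixture has already called torch.manual_seed(1235)
    a = torch.randn(4)
    # re-seeding and drawing again must reproduce the same tensor
    torch.manual_seed(1235)
    b = torch.randn(4)
    assert torch.equal(a, b)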
47 changes: 12 additions & 35 deletions tests/functions/test_deepseek_dsa_decode_func.py
@@ -1,24 +1,20 @@
import argparse
import pytest
import torch

from benchmarks import DeepSeekSparseAttentionDecodeBenchmark
from top.functions import DeepSeekSparseAttentionDecodeWithKVCacheFunc
from top.layers import DeepSeekSparseAttentionDecodeLayer
from top.utils import str2dtype


def test_sparse_mla_decode(batch,
heads,
seq_len_q,
seq_len_kv,
dim,
dim_tail,
topk,
stride_kv,
group_kv,
q_start_index_s,
sm_scale,
dtype,
tune=False):
@pytest.mark.parametrize(
"batch, heads, seq_len_q, seq_len_kv, dim, dim_tail, topk, stride_kv, group_kv, q_start_index_s, sm_scale, dtype, tune",
[
(1, 128, 1024, 2048, 512, 64, 2048, 1, 1, 1024, None, torch.float16, False),
],
)
def test_sparse_mla_decode(batch: int, heads: int, seq_len_q: int, seq_len_kv: int, dim: int,
dim_tail: int, topk: int, stride_kv: int, group_kv: int,
q_start_index_s: int, sm_scale: float, dtype: torch.dtype, tune: bool):
fn = DeepSeekSparseAttentionDecodeWithKVCacheFunc(
batch,
heads,
@@ -81,23 +77,4 @@ def test_sparse_mla_decode(batch,


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--seq_len', type=int, default=1024, help='sequence length')
parser.add_argument('--seq_len_kv', type=int, default=2048, help='key/value sequence length')
parser.add_argument('--heads', type=int, default=128, help='num heads')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--dim_tail', type=int, default=64, help='tail dim')
parser.add_argument('--topk', type=int, default=2048, help='topk')
parser.add_argument('--stride_kv', type=int, default=1, help='stride_kv')
parser.add_argument('--group_kv', type=int, default=1, help='group_kv')
parser.add_argument('--sm_scale', type=float, default=None, help='softmax scaling factor')
parser.add_argument('--q_start_index_s', type=int, default=1024, help='query start index')
parser.add_argument(
'--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
args = parser.parse_args()

test_sparse_mla_decode(args.batch, args.heads, args.seq_len, args.seq_len_kv, args.dim,
args.dim_tail, args.topk, args.stride_kv, args.group_kv,
args.q_start_index_s, args.sm_scale, str2dtype[args.dtype], args.tune)
pytest.main([__file__, "-vvs"])
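
This argparse-to-parametrize conversion is the pattern repeated in the remaining test files: the configuration matrix moves into the decorator, and the __main__ block shrinks to pytest.main([__file__, "-vvs"]), so running the file directly with python still works. Growing the matrix is now a one-line change per case; a sketch of the decorator above with a second row (the bfloat16 case is hypothetical and assumes the kernel supports that dtype):

@pytest.mark.parametrize(
    "batch, heads, seq_len_q, seq_len_kv, dim, dim_tail, topk, stride_kv, group_kv, q_start_index_s, sm_scale, dtype, tune",
    [
        (1, 128, 1024, 2048, 512, 64, 2048, 1, 1, 1024, None, torch.float16, False),
        # hypothetical second case: same shapes, bfloat16
        (1, 128, 1024, 2048, 512, 64, 2048, 1, 1, 1024, None, torch.bfloat16, False),
    ],
)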
29 changes: 12 additions & 17 deletions tests/functions/test_deepseek_mla_decode_func.py
@@ -1,12 +1,20 @@
import argparse
import pytest

import torch

from benchmarks import MultiHeadLatentAttentionDecodeBenchmark
from top.functions import MultiHeadLatentAttentionDecodeWithKVCacheFunc, mla_decode_with_kvcache
from top.layers import MultiHeadLatentAttentionDecodeLayer
from top.utils import str2dtype


def test_mla_decode_fn(batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype):
@pytest.mark.parametrize(
"batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype",
[
(32, 1, 8192, 128, 512, 64, torch.float16),
],
)
def test_mla_decode_fn(batch: int, kv_head_num: int, seq_len_kv: int, heads: int, dim: int,
pe_dim: int, dtype: torch.dtype):

mla_layer = MultiHeadLatentAttentionDecodeLayer(batch, heads, kv_head_num, seq_len_kv, dim,
pe_dim, dtype)
@@ -43,17 +51,4 @@ def test_mla_decode_fn(batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=32, help='batch size')
parser.add_argument('--kv_head_num', type=int, default=1, help='number of key/value heads')
parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length')
parser.add_argument('--heads', type=int, default=128, help='num heads')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='positional encoding dim')
parser.add_argument(
'--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
args = parser.parse_args()

test_mla_decode_fn(args.batch, args.kv_head_num, args.seq_len_kv, args.heads, args.dim,
args.pe_dim, str2dtype[args.dtype])
pytest.main([__file__, "-vvs"])
73 changes: 10 additions & 63 deletions tests/functions/test_fp8_lighting_indexer_func.py
@@ -1,11 +1,18 @@
import argparse
import pytest

from benchmarks.deepseek_mla import Fp8LightingIndexerBenchmark
from top.functions import Fp8LightingIndexerFunc
from top.layers import Fp8LightingIndexerDecodeLayer


def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logits, config):
@pytest.mark.parametrize(
"seq_len, heads, index_dim, seq_len_kv, clean_logits, config",
[
(4096, 32, 64, 8192, True, None),
],
)
def test_fp8_lighting_indexer(seq_len: int, heads: int, index_dim: int, seq_len_kv: int,
clean_logits: bool, config):
fn = Fp8LightingIndexerFunc(seq_len, heads, index_dim, seq_len_kv, clean_logits, config)
layer = Fp8LightingIndexerDecodeLayer(seq_len, heads, index_dim, seq_len_kv, clean_logits,
config)
@@ -32,64 +39,4 @@ def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logit


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--heads', type=int, default=32, help='number of heads')
parser.add_argument('--index_dim', type=int, default=64, help='index dim')
parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length')
parser.add_argument(
'--clean_logits',
action=argparse.BooleanOptionalAction,
default=True,
help='whether to clean logits outside the valid range')
parser.add_argument('--config', type=str, default=None, help='positional encoding dim')
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
args = parser.parse_args()

test_fp8_lighting_indexer(args.seq_len, args.heads, args.index_dim, args.seq_len_kv,
args.clean_logits, args.config)


def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logits, config):
fn = Fp8LightingIndexerFunc(seq_len, heads, index_dim, seq_len_kv, clean_logits, config)
layer = Fp8LightingIndexerDecodeLayer(seq_len, heads, index_dim, seq_len_kv, clean_logits,
config)
benchmark = Fp8LightingIndexerBenchmark(seq_len, heads, index_dim, seq_len_kv, clean_logits,
config)

inputs = benchmark.gen_inputs()

try:
print("Testing indexer_fn...")
benchmark.check_fn(fn, *inputs, grad=False)
print("✅ indexer_fn test passed")
except Exception as e:
print(f"❌ indexer_fn test failed: {e}")
raise

try:
print("Testing indexer_layer...")
benchmark.check_fn(layer, *inputs, grad=False)
print("✅ indexer_layer test passed")
except Exception as e:
print(f"❌ indexer_layer test failed: {e}")
raise


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--heads', type=int, default=32, help='number of heads')
parser.add_argument('--index_dim', type=int, default=64, help='index dim')
parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length')
parser.add_argument(
'--clean_logits',
action=argparse.BooleanOptionalAction,
default=True,
help='whether to clean logits outside the valid range')
parser.add_argument('--config', type=str, default=None, help='positional encoding dim')
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
args = parser.parse_args()

test_fp8_lighting_indexer(args.seq_len, args.heads, args.index_dim, args.seq_len_kv,
args.clean_logits, args.config)
pytest.main([__file__, "-vvs"])
tests/functions/test_fp8_quant_func.py
@@ -15,7 +15,7 @@
(16384, 32, torch.float32, False),
],
)
def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune=False):
def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune):
fn = Fp8QuantFunc(seq_len_kv, index_dim, in_dtype, tune=tune)
layer = Fp8QuantLayer(seq_len_kv, index_dim, in_dtype, tune=tune)
benchmark = Fp8QuantBenchmark(seq_len_kv, index_dim, in_dtype)
@@ -39,7 +39,4 @@ def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune=False):


if __name__ == "__main__":
test_fp8_quant(8192, 64, torch.float16, False)
test_fp8_quant(8192, 64, torch.bfloat16, False)
test_fp8_quant(4096, 128, torch.float32, False)
test_fp8_quant(16384, 32, torch.float32, False)
pytest.main([__file__, "-vvs"])
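
The signature change from tune=False to tune above is hygiene rather than behavior: the decorator already supplies tune in every row, and a default on an argument that parametrize provides is never consulted, so the bare name keeps the signature honest about where its values come from.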
27 changes: 11 additions & 16 deletions tests/functions/test_gqa_decode_func.py
@@ -1,11 +1,18 @@
import argparse
import pytest
import torch

from benchmarks import GroupQueryAttentionDecodeBenchmark
from top.functions import GroupQueryAttentionDecodeWithKVCacheFunc, gqa_decode_with_kvcache
from top.utils import str2dtype


def test_gqa_decode_fn(batch, heads, seq_len_kv, dim, groups, dtype):
@pytest.mark.parametrize(
"batch, heads, seq_len_kv, dim, groups, dtype",
[
(1, 32, 8192, 128, 1, torch.float16),
],
)
def test_gqa_decode_fn(batch: int, heads: int, seq_len_kv: int, dim: int, groups: int,
dtype: torch.dtype):
benchmark = GroupQueryAttentionDecodeBenchmark(batch, heads, groups, seq_len_kv, dim, dtype)

inputs = benchmark.gen_inputs()
@@ -19,16 +26,4 @@ def test_gqa_decode_fn(batch, heads, seq_len_kv, dim, groups, dtype):


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--groups', type=int, default=1, help='num groups')
parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length')
parser.add_argument('--heads', type=int, default=32, help='num heads')
parser.add_argument('--dim', type=int, default=128, help='head dim')
parser.add_argument(
'--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
args = parser.parse_args()

test_gqa_decode_fn(args.batch, args.heads, args.seq_len_kv, args.dim, args.groups,
str2dtype[args.dtype])
pytest.main([__file__, "-vvs"])