diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0068b536..1648067d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,7 +71,8 @@ jobs: source "${{ runner.tool_cache }}/${VENV_DIR}/bin/activate" export PYTHONPATH="$(pwd):$PYTHONPATH" echo "PYTHONPATH=$PYTHONPATH" - bash tests/ci_test.sh tileops_test_release.log + set -o pipefail + python -m pytest -q tests | tee tileops_test_release.log shell: bash - name: Cleanup venv @@ -145,7 +146,8 @@ jobs: source "${{ runner.tool_cache }}/${VENV_DIR}/bin/activate" export PYTHONPATH="$(pwd):$PYTHONPATH" echo "PYTHONPATH=$PYTHONPATH" - bash tests/ci_test.sh tileops_test_nightly.log + set -o pipefail + python -m pytest -q tests | tee tileops_test_nightly.log shell: bash - name: Cleanup venv @@ -165,30 +167,3 @@ jobs: name: tileops_test_nightly.log path: tileops_test_nightly.log retention-days: 7 # Equivalent to expire_in: 1 week - - tileops_profile_release: - # needs: [pre-commit, tileops_test_release] - needs: [tileops_test_release] - runs-on: [self-hosted, profile] - steps: - - name: Checkout code - uses: actions/checkout@v3 - with: - fetch-depth: 0 # Equivalent to GIT_STRATEGY: fetch - - - name: Setup & Run tests - run: | - source ~/miniconda3/etc/profile.d/conda.sh - conda activate tileops-release - export PYTHONPATH="$(pwd):$PYTHONPATH" - echo "PYTHONPATH=$PYTHONPATH" - bash benchmarks/profile_run.sh --log profile_out/tileops_profile_release.log - shell: bash - - - name: Upload profile_out artifacts - uses: actions/upload-artifact@v4 - if: always() - with: - name: profile_out - path: profile_out/ - retention-days: 7 diff --git a/benchmarks/flash_attn/mha.py b/benchmarks/flash_attn/mha.py index 0f9bb083..e063b880 100644 --- a/benchmarks/flash_attn/mha.py +++ b/benchmarks/flash_attn/mha.py @@ -148,7 +148,6 @@ def ref_program(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torc q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal) output = output_bhsd.transpose(1, 2).contiguous() - # from IPython import embed; embed() output.backward(grad_output) return q.grad, k.grad, v.grad diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/ci_test.sh b/tests/ci_test.sh deleted file mode 100755 index 3da02ec7..00000000 --- a/tests/ci_test.sh +++ /dev/null @@ -1,78 +0,0 @@ -#!/bin/bash - -# Accept log file name as input parameter, default to tileops_test.log -LOG_FILE="${1:-tileops_test.log}" - -# Run all Python test files in tests directory -echo -e "\033[0;34mRunning all Python test files...\033[0m" | tee -a "$LOG_FILE" - -# Store test results for summary -declare -a test_names -declare -a test_results - -# Initialize counters -passed_count=0 -failed_count=0 - -# Find all .py files in current directory where script is located -script_dir=$(dirname -- "${BASH_SOURCE[0]}") -test_files=$(find "$script_dir" -name "test*.py" -type f | sort) - -if [ -z "$test_files" ]; then - echo "No test files found in $(dirname "$script_dir")" | tee -a "$LOG_FILE" - exit 1 -fi - -# Table header alignment, assuming filename max length of 50 characters -printf "| %-50s | %-8s |\n" "Test File" "Status" | tee -a "$LOG_FILE" -printf "|%s|\n" "--------------------------------------------------|----------" | tee -a "$LOG_FILE" - -# Run each test file -for test_file in $test_files; do - file_name=$(basename "$test_file") - echo -e "\033[0;36mRunning test: $test_file\033[0m" | tee -a "$LOG_FILE" - echo "----------------------------------------" >> "$LOG_FILE" - - if python "$test_file" >> "$LOG_FILE" 2>&1; then - echo -e "\033[0;32m[PASS] $test_file\033[0m" | tee -a "$LOG_FILE" - printf "| %-50s | ✅ Pass |\n" "$file_name" | tee -a "$LOG_FILE" - test_names+=("$file_name") - test_results+=("✅ Pass") - passed_count=$((passed_count + 1)) - else - echo -e "\033[0;31m[FAIL] $test_file\033[0m" | tee -a "$LOG_FILE" - printf "| %-50s | ❌ Fail |\n" "$file_name" | tee -a "$LOG_FILE" - test_names+=("$file_name") - test_results+=("❌ Fail") - failed_count=$((failed_count + 1)) - fi - - echo "----------------------------------------" >> "$LOG_FILE" -done - -# Add statistics summary to log file -echo "" | tee -a "$LOG_FILE" -echo "Summary:" | tee -a "$LOG_FILE" -echo "- Passed: $passed_count" | tee -a "$LOG_FILE" -echo "- Failed: $failed_count" | tee -a "$LOG_FILE" -echo "- Total: $((passed_count + failed_count))" | tee -a "$LOG_FILE" - -# Print test results summary table -echo "" | tee -a "$LOG_FILE" -echo -e "\033[0;34mTest Results Summary:\033[0m" | tee -a "$LOG_FILE" -echo -e "\033[0;34m====================\033[0m" | tee -a "$LOG_FILE" -printf "| %-50s | %-8s |\n" "Test File" "Status" | tee -a "$LOG_FILE" -printf "|%s|\n" "--------------------------------------------------|----------" | tee -a "$LOG_FILE" - -# Print final summary table from stored results -for i in "${!test_names[@]}"; do - printf "| %-50s | %-8s |\n" "${test_names[$i]}" "${test_results[$i]}" | tee -a "$LOG_FILE" -done - -# If there are failed tests, CI fails -if [ $failed_count -gt 0 ]; then - echo -e "\033[0;31mError: $failed_count test(s) failed, stopping pipeline.\033[0m" | tee -a "$LOG_FILE" - exit 1 -else - echo -e "\033[0;32mAll tests passed!\033[0m" | tee -a "$LOG_FILE" -fi diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..08fbfc5c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +import pytest +import torch + + +@pytest.fixture(autouse=True) +def setup() -> None: + torch.manual_seed(1235) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(1235) diff --git a/tests/functions/test_deepseek_dsa_decode_func.py b/tests/functions/test_deepseek_dsa_decode_func.py index 1c5834d7..2ee27ebf 100644 --- a/tests/functions/test_deepseek_dsa_decode_func.py +++ b/tests/functions/test_deepseek_dsa_decode_func.py @@ -1,24 +1,20 @@ -import argparse +import pytest +import torch from benchmarks import DeepSeekSparseAttentionDecodeBenchmark from top.functions import DeepSeekSparseAttentionDecodeWithKVCacheFunc from top.layers import DeepSeekSparseAttentionDecodeLayer -from top.utils import str2dtype -def test_sparse_mla_decode(batch, - heads, - seq_len_q, - seq_len_kv, - dim, - dim_tail, - topk, - stride_kv, - group_kv, - q_start_index_s, - sm_scale, - dtype, - tune=False): +@pytest.mark.parametrize( + "batch, heads, seq_len_q, seq_len_kv, dim, dim_tail, topk, stride_kv, group_kv, q_start_index_s, sm_scale, dtype, tune", + [ + (1, 128, 1024, 2048, 512, 64, 2048, 1, 1, 1024, None, torch.float16, False), + ], +) +def test_sparse_mla_decode(batch: int, heads: int, seq_len_q: int, seq_len_kv: int, dim: int, + dim_tail: int, topk: int, stride_kv: int, group_kv: int, + q_start_index_s: int, sm_scale: float, dtype: torch.dtype, tune: bool): fn = DeepSeekSparseAttentionDecodeWithKVCacheFunc( batch, heads, @@ -81,23 +77,4 @@ def test_sparse_mla_decode(batch, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--seq_len_kv', type=int, default=2048, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=128, help='num heads') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--dim_tail', type=int, default=64, help='tail dim') - parser.add_argument('--topk', type=int, default=2048, help='topk') - parser.add_argument('--stride_kv', type=int, default=1, help='stride_kv') - parser.add_argument('--group_kv', type=int, default=1, help='group_kv') - parser.add_argument('--sm_scale', type=float, default=None, help='softmax scaling factor') - parser.add_argument('--q_start_index_s', type=int, default=1024, help='query start index') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_sparse_mla_decode(args.batch, args.heads, args.seq_len, args.seq_len_kv, args.dim, - args.dim_tail, args.topk, args.stride_kv, args.group_kv, - args.q_start_index_s, args.sm_scale, str2dtype[args.dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_deepseek_mla_decode_func.py b/tests/functions/test_deepseek_mla_decode_func.py index 775b990a..5538b492 100644 --- a/tests/functions/test_deepseek_mla_decode_func.py +++ b/tests/functions/test_deepseek_mla_decode_func.py @@ -1,12 +1,20 @@ -import argparse +import pytest + +import torch from benchmarks import MultiHeadLatentAttentionDecodeBenchmark from top.functions import MultiHeadLatentAttentionDecodeWithKVCacheFunc, mla_decode_with_kvcache from top.layers import MultiHeadLatentAttentionDecodeLayer -from top.utils import str2dtype -def test_mla_decode_fn(batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype): +@pytest.mark.parametrize( + "batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype", + [ + (32, 1, 8192, 128, 512, 64, torch.float16), + ], +) +def test_mla_decode_fn(batch: int, kv_head_num: int, seq_len_kv: int, heads: int, dim: int, + pe_dim: int, dtype: torch.dtype): mla_layer = MultiHeadLatentAttentionDecodeLayer(batch, heads, kv_head_num, seq_len_kv, dim, pe_dim, dtype) @@ -43,17 +51,4 @@ def test_mla_decode_fn(batch, kv_head_num, seq_len_kv, heads, dim, pe_dim, dtype if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=32, help='batch size') - parser.add_argument('--kv_head_num', type=int, default=1, help='number of key/value heads') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=128, help='num heads') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--pe_dim', type=int, default=64, help='positional encoding dim') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_mla_decode_fn(args.batch, args.kv_head_num, args.seq_len_kv, args.heads, args.dim, - args.pe_dim, str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_fp8_lighting_indexer_func.py b/tests/functions/test_fp8_lighting_indexer_func.py index 9cf82834..a98f2306 100644 --- a/tests/functions/test_fp8_lighting_indexer_func.py +++ b/tests/functions/test_fp8_lighting_indexer_func.py @@ -1,11 +1,18 @@ -import argparse +import pytest from benchmarks.deepseek_mla import Fp8LightingIndexerBenchmark from top.functions import Fp8LightingIndexerFunc from top.layers import Fp8LightingIndexerDecodeLayer -def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logits, config): +@pytest.mark.parametrize( + "seq_len, heads, index_dim, seq_len_kv, clean_logits, config", + [ + (4096, 32, 64, 8192, True, None), + ], +) +def test_fp8_lighting_indexer(seq_len: int, heads: int, index_dim: int, seq_len_kv: int, + clean_logits: bool, config): fn = Fp8LightingIndexerFunc(seq_len, heads, index_dim, seq_len_kv, clean_logits, config) layer = Fp8LightingIndexerDecodeLayer(seq_len, heads, index_dim, seq_len_kv, clean_logits, config) @@ -32,64 +39,4 @@ def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logit if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='number of heads') - parser.add_argument('--index_dim', type=int, default=64, help='index dim') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument( - '--clean_logits', - action=argparse.BooleanOptionalAction, - default=True, - help='whether to clean logits outside the valid range') - parser.add_argument('--config', type=str, default=None, help='positional encoding dim') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_fp8_lighting_indexer(args.seq_len, args.heads, args.index_dim, args.seq_len_kv, - args.clean_logits, args.config) - - -def test_fp8_lighting_indexer(seq_len, heads, index_dim, seq_len_kv, clean_logits, config): - fn = Fp8LightingIndexerFunc(seq_len, heads, index_dim, seq_len_kv, clean_logits, config) - layer = Fp8LightingIndexerDecodeLayer(seq_len, heads, index_dim, seq_len_kv, clean_logits, - config) - benchmark = Fp8LightingIndexerBenchmark(seq_len, heads, index_dim, seq_len_kv, clean_logits, - config) - - inputs = benchmark.gen_inputs() - - try: - print("Testing indexer_fn...") - benchmark.check_fn(fn, *inputs, grad=False) - print("✅ indexer_fn test passed") - except Exception as e: - print(f"❌ indexer_fn test failed: {e}") - raise - - try: - print("Testing indexer_layer...") - benchmark.check_fn(layer, *inputs, grad=False) - print("✅ indexer_layer test passed") - except Exception as e: - print(f"❌ indexer_layer test failed: {e}") - raise - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='number of heads') - parser.add_argument('--index_dim', type=int, default=64, help='index dim') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument( - '--clean_logits', - action=argparse.BooleanOptionalAction, - default=True, - help='whether to clean logits outside the valid range') - parser.add_argument('--config', type=str, default=None, help='positional encoding dim') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_fp8_lighting_indexer(args.seq_len, args.heads, args.index_dim, args.seq_len_kv, - args.clean_logits, args.config) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_fp8_quant.py b/tests/functions/test_fp8_quant_func.py similarity index 81% rename from tests/functions/test_fp8_quant.py rename to tests/functions/test_fp8_quant_func.py index eccc5e22..829afbc6 100644 --- a/tests/functions/test_fp8_quant.py +++ b/tests/functions/test_fp8_quant_func.py @@ -15,7 +15,7 @@ (16384, 32, torch.float32, False), ], ) -def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune=False): +def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune): fn = Fp8QuantFunc(seq_len_kv, index_dim, in_dtype, tune=tune) layer = Fp8QuantLayer(seq_len_kv, index_dim, in_dtype, tune=tune) benchmark = Fp8QuantBenchmark(seq_len_kv, index_dim, in_dtype) @@ -39,7 +39,4 @@ def test_fp8_quant(seq_len_kv, index_dim, in_dtype, tune=False): if __name__ == "__main__": - test_fp8_quant(8192, 64, torch.float16, False) - test_fp8_quant(8192, 64, torch.bfloat16, False) - test_fp8_quant(4096, 128, torch.float32, False) - test_fp8_quant(16384, 32, torch.float32, False) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_gqa_decode_func.py b/tests/functions/test_gqa_decode_func.py index a7d49412..f40dfaa5 100644 --- a/tests/functions/test_gqa_decode_func.py +++ b/tests/functions/test_gqa_decode_func.py @@ -1,11 +1,18 @@ -import argparse +import pytest +import torch from benchmarks import GroupQueryAttentionDecodeBenchmark from top.functions import GroupQueryAttentionDecodeWithKVCacheFunc, gqa_decode_with_kvcache -from top.utils import str2dtype -def test_gqa_decode_fn(batch, heads, seq_len_kv, dim, groups, dtype): +@pytest.mark.parametrize( + "batch, heads, seq_len_kv, dim, groups, dtype", + [ + (1, 32, 8192, 128, 1, torch.float16), + ], +) +def test_gqa_decode_fn(batch: int, heads: int, seq_len_kv: int, dim: int, groups: int, + dtype: torch.dtype): benchmark = GroupQueryAttentionDecodeBenchmark(batch, heads, groups, seq_len_kv, dim, dtype) inputs = benchmark.gen_inputs() @@ -19,16 +26,4 @@ def test_gqa_decode_fn(batch, heads, seq_len_kv, dim, groups, dtype): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--groups', type=int, default=1, help='num groups') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_gqa_decode_fn(args.batch, args.heads, args.seq_len_kv, args.dim, args.groups, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_gqa_func.py b/tests/functions/test_gqa_func.py index a593a374..9b383216 100644 --- a/tests/functions/test_gqa_func.py +++ b/tests/functions/test_gqa_func.py @@ -1,12 +1,16 @@ -import argparse - +import pytest import torch from benchmarks import GroupQueryAttentionBenchmark from top.functions import GroupQueryAttentionFunc, gqa -from top.utils import str2dtype +@pytest.mark.parametrize( + "batch, seq_len, heads, heads_kv, dim, causal, dtype", + [ + (8, 1024, 32, 8, 128, False, torch.float16), + ], +) def test_gqa_fn(batch: int, seq_len: int, heads: int, heads_kv: int, dim: int, causal: bool, dtype: torch.dtype) -> None: benchmark = GroupQueryAttentionBenchmark(batch, heads, heads_kv, seq_len, dim, causal, dtype) @@ -22,16 +26,4 @@ def test_gqa_fn(batch: int, seq_len: int, heads: int, heads_kv: int, dim: int, c if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--heads_kv', type=int, default=8, help='num heads for key/value') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--causal', action='store_true', default=False, help='causal attention') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - args = parser.parse_args() - - test_gqa_fn(args.batch, args.seq_len, args.heads, args.heads_kv, args.dim, args.causal, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_grouped_gemm_func.py b/tests/functions/test_grouped_gemm_func.py index 4f26574b..4d4ab791 100644 --- a/tests/functions/test_grouped_gemm_func.py +++ b/tests/functions/test_grouped_gemm_func.py @@ -1,14 +1,19 @@ -import argparse +import pytest import math - import torch from benchmarks import GroupedGemmBenchmark from top.functions import GroupedGemmFunc -from top.utils import str2dtype -def test_grouped_gemm_fn(batch_sizes_list, N, K, padding_M, dtype, tune=False): +@pytest.mark.parametrize( + "batch_sizes_list, N, K, padding_M, dtype, tune", + [ + ([4096, 4096, 4096, 4096], 4864, 8192, 128, torch.float16, False), + ], +) +def test_grouped_gemm_fn(batch_sizes_list: list, N: int, K: int, padding_M: int, dtype: torch.dtype, + tune: bool): batch_sum = sum(batch_sizes_list) batch_count = len(batch_sizes_list) batch_offsets_list = [0] @@ -44,23 +49,4 @@ def test_grouped_gemm_fn(batch_sizes_list, N, K, padding_M, dtype, tune=False): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - '--batch_sizes_list', type=str, default="4096,4096,4096,4096", help='batch size list') - parser.add_argument('--N', type=int, default=4864, help='N') - parser.add_argument('--K', type=int, default=8192, help='K') - parser.add_argument('--padding_M', type=int, default=128, help='padding M') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - batch_sizes_list = [int(x) for x in args.batch_sizes_list.split(',')] - - test_grouped_gemm_fn( - batch_sizes_list=batch_sizes_list, - N=args.N, - K=args.K, - padding_M=args.padding_M, - dtype=str2dtype[args.dtype], - tune=args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_matmul_func.py b/tests/functions/test_matmul_func.py index f3241048..a8c835dd 100644 --- a/tests/functions/test_matmul_func.py +++ b/tests/functions/test_matmul_func.py @@ -1,13 +1,17 @@ -import argparse - +import pytest import torch from benchmarks import MatMulBenchmark from top.functions import MatMulFunc, matmul -from top.utils import str2dtype -def test_matmul(m: int, n: int, k: int, dtype: torch.dtype, tune: bool = False) -> None: +@pytest.mark.parametrize( + "m, n, k, dtype, tune", + [ + (1024, 1024, 1024, torch.float16, False), + ], +) +def test_matmul(m: int, n: int, k: int, dtype: torch.dtype, tune: bool) -> None: benchmark = MatMulBenchmark(m, n, k, dtype) inputs = benchmark.gen_inputs() @@ -21,13 +25,4 @@ def test_matmul(m: int, n: int, k: int, dtype: torch.dtype, tune: bool = False) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--M', type=int, default=1024, help='M') - parser.add_argument('--N', type=int, default=1024, help='N') - parser.add_argument('--K', type=int, default=1024, help='K') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_matmul(args.M, args.N, args.K, str2dtype[args.dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_mha_decode_func.py b/tests/functions/test_mha_decode_func.py index ea211596..a97ff175 100644 --- a/tests/functions/test_mha_decode_func.py +++ b/tests/functions/test_mha_decode_func.py @@ -1,11 +1,18 @@ -import argparse +import pytest +import torch from benchmarks import MultiHeadAttentionDecodeBenchmark from top.functions import MultiHeadAttentionDecodeWithKVCacheFunc, mha_decode_with_kvcache -from top.utils import str2dtype -def test_mha_decode_fn(batch, seq_len_q, seq_len_kv, heads, dim, dtype): +@pytest.mark.parametrize( + "batch, seq_len_q, seq_len_kv, heads, dim, dtype", + [ + (1, 128, 8192, 32, 128, torch.float16), + ], +) +def test_mha_decode_fn(batch: int, seq_len_q: int, seq_len_kv: int, heads: int, dim: int, + dtype: torch.dtype): benchmark = MultiHeadAttentionDecodeBenchmark(batch, heads, seq_len_q, seq_len_kv, dim, dtype) inputs = benchmark.gen_inputs() @@ -19,16 +26,4 @@ def test_mha_decode_fn(batch, seq_len_q, seq_len_kv, heads, dim, dtype): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--seq_len_q', type=int, default=128, help='query sequence length') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_mha_decode_fn(args.batch, args.seq_len_q, args.seq_len_kv, args.heads, args.dim, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_mha_func.py b/tests/functions/test_mha_func.py index 5bac6126..b1cc9cbc 100644 --- a/tests/functions/test_mha_func.py +++ b/tests/functions/test_mha_func.py @@ -1,12 +1,16 @@ -import argparse - +import pytest import torch from benchmarks import MultiHeadAttentionBenchmark from top.functions import MultiHeadAttentionFunc, mha -from top.utils import str2dtype +@pytest.mark.parametrize( + "batch, seq_len, heads, dim, causal, dtype", + [ + (8, 1024, 32, 128, False, torch.float16), + ], +) def test_mha_fn(batch: int, seq_len: int, heads: int, dim: int, causal: bool, dtype: torch.dtype) -> None: benchmark = MultiHeadAttentionBenchmark(batch, heads, seq_len, dim, causal, dtype) @@ -22,14 +26,4 @@ def test_mha_fn(batch: int, seq_len: int, heads: int, dim: int, causal: bool, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--causal', action='store_true', default=False, help='causal attention') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - args = parser.parse_args() - - test_mha_fn(args.batch, args.seq_len, args.heads, args.dim, args.causal, str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/functions/test_topk_selector_func.py b/tests/functions/test_topk_selector_func.py index 97d5e418..b53d28a6 100644 --- a/tests/functions/test_topk_selector_func.py +++ b/tests/functions/test_topk_selector_func.py @@ -1,18 +1,19 @@ -import argparse +import pytest import torch from benchmarks import TopkSelectorBenchmark from top.functions import TopkSelectorFunc from top.layers import TopkSelectorLayer -from top.utils import str2dtype -def test_topk_selector(batch: int, - seq_len: int, - topk: int, - in_dtype: torch.dtype, - out_dtype: torch.dtype, - tune: bool = False) -> None: +@pytest.mark.parametrize( + "batch, seq_len, topk, in_dtype, out_dtype, tune", + [ + (64, 32 * 1024, 2048, torch.float32, torch.int32, False), + ], +) +def test_topk_selector(batch: int, seq_len: int, topk: int, in_dtype: torch.dtype, + out_dtype: torch.dtype, tune: bool) -> None: fn = TopkSelectorFunc(batch, seq_len, topk, in_dtype, out_dtype, tune=tune) layer = TopkSelectorLayer( batch, @@ -50,14 +51,4 @@ def test_topk_selector(batch: int, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--seq_len', type=int, default=32 * 1024, help='sequence length') - parser.add_argument('--topk', type=int, default=2048, help='topk') - parser.add_argument('--in_dtype', type=str, default="float32", help='input type') - parser.add_argument('--out_dtype', type=str, default="int32", help='output type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_topk_selector(args.batch, args.seq_len, args.topk, str2dtype[args.in_dtype], - str2dtype[args.out_dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/layers/test_gqa_decode_layer.py b/tests/layers/test_gqa_decode_layer.py index 68ec0082..22ff559b 100644 --- a/tests/layers/test_gqa_decode_layer.py +++ b/tests/layers/test_gqa_decode_layer.py @@ -1,11 +1,18 @@ -import argparse +import pytest +import torch from benchmarks import GroupQueryAttentionDecodeBenchmark from top.layers import GroupQueryAttentionDecodeLayer -from top.utils import str2dtype -def test_gqa_decode_layer(batch, heads, seq_len_kv, dim, groups, dtype): +@pytest.mark.parametrize( + "batch, heads, seq_len_kv, dim, groups, dtype", + [ + (1, 32, 8192, 128, 1, torch.float16), + ], +) +def test_gqa_decode_layer(batch: int, heads: int, seq_len_kv: int, dim: int, groups: int, + dtype: torch.dtype): fn = GroupQueryAttentionDecodeLayer(batch, heads, groups, seq_len_kv, dim, dtype) benchmark = GroupQueryAttentionDecodeBenchmark(batch, heads, groups, seq_len_kv, dim, dtype) @@ -14,16 +21,4 @@ def test_gqa_decode_layer(batch, heads, seq_len_kv, dim, groups, dtype): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--groups', type=int, default=1, help='num groups') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_gqa_decode_layer(args.batch, args.heads, args.seq_len_kv, args.dim, args.groups, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/layers/test_gqa_layer.py b/tests/layers/test_gqa_layer.py index 7ec3d468..d40f7101 100644 --- a/tests/layers/test_gqa_layer.py +++ b/tests/layers/test_gqa_layer.py @@ -1,12 +1,16 @@ -import argparse - +import pytest import torch from benchmarks import GroupQueryAttentionBenchmark from top.layers import GroupQueryAttentionLayer -from top.utils import str2dtype +@pytest.mark.parametrize( + "batch, seq_len, heads, heads_kv, dim, causal, dtype", + [ + (8, 1024, 32, 32, 128, False, torch.float16), + ], +) def test_gqa_layer(batch: int, seq_len: int, heads: int, heads_kv: int, dim: int, causal: bool, dtype: torch.dtype) -> None: @@ -18,16 +22,4 @@ def test_gqa_layer(batch: int, seq_len: int, heads: int, heads_kv: int, dim: int if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--heads_kv', type=int, default=32, help='num heads for key/value') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--causal', action='store_true', default=False, help='causal attention') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - args = parser.parse_args() - - test_gqa_layer(args.batch, args.seq_len, args.heads, args.heads_kv, args.dim, args.causal, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/layers/test_grouped_gemm_layer.py b/tests/layers/test_grouped_gemm_layer.py index a53a9944..ba69a557 100644 --- a/tests/layers/test_grouped_gemm_layer.py +++ b/tests/layers/test_grouped_gemm_layer.py @@ -1,11 +1,17 @@ -import argparse +import pytest +import torch from benchmarks import GroupedGemmBenchmark from top.layers import GroupedGemmLayer -from top.utils import str2dtype -def test_grouped_gemm_layer(batch_sum, batch_count, N, K, dtype): +@pytest.mark.parametrize( + "batch_sum, batch_count, N, K, dtype", + [ + (16384, 4, 4864, 8192, torch.float16), + ], +) +def test_grouped_gemm_layer(batch_sum: int, batch_count: int, N: int, K: int, dtype: torch.dtype): grouped_gemm = GroupedGemmLayer(batch_sum, batch_count, N, K, dtype) benchmark = GroupedGemmBenchmark(batch_sum, batch_count, N, K, dtype) inputs = benchmark.gen_inputs() @@ -18,13 +24,4 @@ def test_grouped_gemm_layer(batch_sum, batch_count, N, K, dtype): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch_sum', type=int, default=16384, help='sum of batch_size_list') - parser.add_argument('--batch_count', type=int, default=4, help='length of batch_size_list') - parser.add_argument('--N', type=int, default=4864, help='head dim') - parser.add_argument('--K', type=int, default=8192, help='num heads') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - args = parser.parse_args() - - test_grouped_gemm_layer(args.batch_sum, args.batch_count, args.N, args.K, str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/layers/test_linear.py b/tests/layers/test_linear.py index f54965cc..d0e6e2d8 100644 --- a/tests/layers/test_linear.py +++ b/tests/layers/test_linear.py @@ -1,12 +1,16 @@ -import argparse - +import pytest import torch from top.layers import LinearLayer -from top.utils import str2dtype -def test_linear(m: int, n: int, k: int, dtype: torch.dtype, tune: bool = False) -> None: +@pytest.mark.parametrize( + "m, n, k, dtype, tune", + [ + (1024, 1024, 1024, torch.float16, False), + ], +) +def test_linear(m: int, n: int, k: int, dtype: torch.dtype, tune: bool) -> None: linear_layer = LinearLayer(m, n, k, dtype=dtype, tune=tune) input_tensor = torch.randn(m, k, dtype=dtype, device='cuda', requires_grad=True) @@ -20,13 +24,4 @@ def test_linear(m: int, n: int, k: int, dtype: torch.dtype, tune: bool = False) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--M', type=int, default=1024, help='M') - parser.add_argument('--N', type=int, default=1024, help='N') - parser.add_argument('--K', type=int, default=1024, help='K') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_linear(args.M, args.N, args.K, str2dtype[args.dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/layers/test_mha_decode_layer.py b/tests/layers/test_mha_decode_layer.py index 2af00280..d35d57d5 100644 --- a/tests/layers/test_mha_decode_layer.py +++ b/tests/layers/test_mha_decode_layer.py @@ -1,11 +1,18 @@ -import argparse +import pytest +import torch from benchmarks import MultiHeadAttentionDecodeBenchmark from top.layers import MultiHeadAttentionDecodeLayer -from top.utils import str2dtype -def test_mha_decode_layer(batch, seq_len_q, seq_len_kv, heads, dim, dtype): +@pytest.mark.parametrize( + "batch, seq_len_q, seq_len_kv, heads, dim, dtype", + [ + (1, 128, 8192, 32, 128, torch.float16), + ], +) +def test_mha_decode_layer(batch: int, seq_len_q: int, seq_len_kv: int, heads: int, dim: int, + dtype: torch.dtype): fn = MultiHeadAttentionDecodeLayer(batch, heads, seq_len_q, seq_len_kv, dim, dtype) benchmark = MultiHeadAttentionDecodeBenchmark(batch, heads, seq_len_q, seq_len_kv, dim, dtype) @@ -14,16 +21,4 @@ def test_mha_decode_layer(batch, seq_len_q, seq_len_kv, heads, dim, dtype): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--seq_len_q', type=int, default=128, help='query sequence length') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_mha_decode_layer(args.batch, args.seq_len_q, args.seq_len_kv, args.heads, args.dim, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/layers/test_mha_layer.py b/tests/layers/test_mha_layer.py index a33b7071..4a468cd1 100644 --- a/tests/layers/test_mha_layer.py +++ b/tests/layers/test_mha_layer.py @@ -1,12 +1,16 @@ -import argparse - +import pytest import torch from benchmarks import MultiHeadAttentionBenchmark from top.layers import MultiHeadAttentionLayer -from top.utils import str2dtype +@pytest.mark.parametrize( + "batch, seq_len, heads, dim, causal, dtype", + [ + (8, 1024, 32, 128, False, torch.float16), + ], +) def test_mha_layer(batch: int, seq_len: int, heads: int, dim: int, causal: bool, dtype: torch.dtype) -> None: @@ -18,15 +22,4 @@ def test_mha_layer(batch: int, seq_len: int, heads: int, dim: int, causal: bool, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--causal', action='store_true', default=False, help='causal attention') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - args = parser.parse_args() - - test_mha_layer(args.batch, args.seq_len, args.heads, args.dim, args.causal, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_deepseek_dsa_decode.py b/tests/ops/test_deepseek_dsa_decode.py index f42e9234..2b8e9d17 100644 --- a/tests/ops/test_deepseek_dsa_decode.py +++ b/tests/ops/test_deepseek_dsa_decode.py @@ -1,25 +1,20 @@ -import argparse - import torch +import pytest from benchmarks import DeepSeekSparseAttentionDecodeBenchmark from top.ops import DeepSeekSparseAttentionDecodeWithKVCacheOp -from top.utils import str2dtype -def test_sparse_mla_decode(batch: int, - heads: int, - seq_len_q: int, - seq_len_kv: int, - dim: int, - dim_tail: int, - topk: int, - stride_kv: int, - group_kv: int, - q_start_index_s: int, - sm_scale: float, - dtype: torch.dtype, - tune: bool = False) -> None: +@pytest.mark.parametrize( + "batch, heads, seq_len_q, seq_len_kv, dim, dim_tail, topk, stride_kv, group_kv, q_start_index_s, sm_scale, dtype, tune", + [ + (1, 128, 1024, 2048, 512, 64, 2048, 1, 1, 1024, None, torch.float16, False), + ], +) +def test_sparse_mla_decode(batch: int, heads: int, seq_len_q: int, seq_len_kv: int, dim: int, + dim_tail: int, topk: int, stride_kv: int, group_kv: int, + q_start_index_s: int, sm_scale: float, dtype: torch.dtype, + tune: bool) -> None: op = DeepSeekSparseAttentionDecodeWithKVCacheOp( batch, heads, @@ -54,23 +49,4 @@ def test_sparse_mla_decode(batch: int, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--seq_len_kv', type=int, default=2048, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=128, help='num heads') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--dim_tail', type=int, default=64, help='tail dim') - parser.add_argument('--topk', type=int, default=2048, help='topk') - parser.add_argument('--stride_kv', type=int, default=1, help='stride_kv') - parser.add_argument('--group_kv', type=int, default=1, help='group_kv') - parser.add_argument('--sm_scale', type=float, default=None, help='softmax scaling factor') - parser.add_argument('--q_start_index_s', type=int, default=1024, help='query start index') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_sparse_mla_decode(args.batch, args.heads, args.seq_len, args.seq_len_kv, args.dim, - args.dim_tail, args.topk, args.stride_kv, args.group_kv, - args.q_start_index_s, args.sm_scale, str2dtype[args.dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_deepseek_mla_decode.py b/tests/ops/test_deepseek_mla_decode.py index f15c38d7..c3fb4fb7 100644 --- a/tests/ops/test_deepseek_mla_decode.py +++ b/tests/ops/test_deepseek_mla_decode.py @@ -1,11 +1,18 @@ -import argparse +import pytest +import torch from benchmarks import MultiHeadLatentAttentionDecodeBenchmark from top.ops import MultiHeadLatentAttentionDecodeWithKVCacheOp -from top.utils import str2dtype -def test_mla_decode(batch, heads, head_num_kv, seq_len_kv, dim, dim_pe, dtype, tune=False): +@pytest.mark.parametrize( + "batch, heads, head_num_kv, seq_len_kv, dim, dim_pe, dtype, tune", + [ + (32, 128, 1, 8192, 512, 64, torch.float16, False), + ], +) +def test_mla_decode(batch: int, heads: int, head_num_kv: int, seq_len_kv: int, dim: int, + dim_pe: int, dtype: torch.dtype, tune: bool): op = MultiHeadLatentAttentionDecodeWithKVCacheOp( batch, heads, head_num_kv, seq_len_kv, dim, dim_pe, dtype, tune=tune) benchmark = MultiHeadLatentAttentionDecodeBenchmark(batch, heads, head_num_kv, seq_len_kv, dim, @@ -17,17 +24,4 @@ def test_mla_decode(batch, heads, head_num_kv, seq_len_kv, dim, dim_pe, dtype, t if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=32, help='batch size') - parser.add_argument('--head_num_kv', type=int, default=1, help='number of key/value heads') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=128, help='num heads') - parser.add_argument('--dim', type=int, default=512, help='head dim') - parser.add_argument('--dim_pe', type=int, default=64, help='positional encoding dim') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_mla_decode(args.batch, args.heads, args.head_num_kv, args.seq_len_kv, args.dim, - args.dim_pe, str2dtype[args.dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_deepseek_nsa_cmp_fwd.py b/tests/ops/test_deepseek_nsa_cmp_fwd.py index a24d7071..8a92d587 100644 --- a/tests/ops/test_deepseek_nsa_cmp_fwd.py +++ b/tests/ops/test_deepseek_nsa_cmp_fwd.py @@ -5,11 +5,6 @@ from top.ops import NSACmpFwdVarlenOp -@pytest.fixture(autouse=True) -def setup() -> None: - torch.manual_seed(1234) - - @pytest.mark.parametrize( ("seq_num, c_seq_len, heads, dim_k, dim_v, group, scale, bc, bs, bk, bv, " "dtype, accum_dtype, tune"), @@ -48,18 +43,4 @@ def test_nsa_cmp_fwd_varlen_op( if __name__ == "__main__": - test_nsa_cmp_fwd_varlen_op( - seq_num=12, - c_seq_len=8192, - heads=32, - dim_k=128, - dim_v=128, - group=16, - scale=128**-0.5, - bc=32, - bs=32, - bk=128, - bv=128, - dtype=torch.float16, - accum_dtype=torch.float32, - tune=False) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_deepseek_nsa_fwd.py b/tests/ops/test_deepseek_nsa_fwd.py index 0911f6b8..ec92358a 100644 --- a/tests/ops/test_deepseek_nsa_fwd.py +++ b/tests/ops/test_deepseek_nsa_fwd.py @@ -7,12 +7,6 @@ from top.ops import NSAFwdVarlenOp -@pytest.fixture(autouse=True) -def setup() -> None: - """Set up the test environment.""" - torch.manual_seed(1234) - - @pytest.mark.parametrize( ("batch, heads, c_seq_len, dim, is_causal, scale, block_size, " "groups, selected_blocks, dtype, accum_dtype, tune"), @@ -61,7 +55,4 @@ def test_nsa_varlen_op( if __name__ == "__main__": - - test_nsa_varlen_op(1, 16, 1024, 64, True, 0.1, 32, 16, 1, torch.float16, torch.float32, False) - test_nsa_varlen_op(4, 16, 8192, 64, True, 0.1, 32, 16, 1, torch.float16, torch.float32, False) - test_nsa_varlen_op(2, 16, 8192, 64, True, 0.1, 32, 16, 4, torch.float16, torch.float32, False) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py index fd5828ae..30eccf52 100644 --- a/tests/ops/test_deepseek_nsa_gqa_window_sliding.py +++ b/tests/ops/test_deepseek_nsa_gqa_window_sliding.py @@ -7,12 +7,6 @@ from top.ops import GQAWindowSlidingOp -@pytest.fixture(autouse=True) -def setup() -> None: - """Set up the test environment.""" - torch.manual_seed(1234) - - @pytest.mark.parametrize( ("batch_size", "groups", "uq", "ukv", "heads", "dim", "is_causal", "window_size_left", "window_size_right", "dtype", "accum_dtype", "tune"), @@ -57,47 +51,8 @@ def test_nsa_gqa_window_sliding_op( op = GQAWindowSlidingOp(**params) inputs = benchmark.gen_inputs() - benchmark.check(op, *inputs) + benchmark.check(op, *inputs, atol=3e-3, rtol=1e-5) if __name__ == "__main__": - - test_nsa_gqa_window_sliding_op( - batch_size=1, - groups=16, - uq=1024, - ukv=1024, - heads=64, - dim=128, - is_causal=True, - window_size_left=32, - window_size_right=-1, - dtype=torch.float16, - accum_dtype=torch.float32, - tune=False) - test_nsa_gqa_window_sliding_op( - batch_size=3, - groups=16, - uq=8192, - ukv=8192, - heads=64, - dim=128, - is_causal=True, - window_size_left=2048, - window_size_right=0, - dtype=torch.float16, - accum_dtype=torch.float32, - tune=False) - test_nsa_gqa_window_sliding_op( - batch_size=3, - groups=16, - uq=8192, - ukv=8192, - heads=64, - dim=128, - is_causal=False, - window_size_left=-1, - window_size_right=-1, - dtype=torch.float16, - accum_dtype=torch.float32, - tune=False) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_deepseek_nsa_topk.py b/tests/ops/test_deepseek_nsa_topk.py index 3967dcea..9ddc363b 100644 --- a/tests/ops/test_deepseek_nsa_topk.py +++ b/tests/ops/test_deepseek_nsa_topk.py @@ -5,11 +5,6 @@ from top.ops import NSATopkVarlenOp -@pytest.fixture(autouse=True) -def setup() -> None: - torch.manual_seed(1234) - - @pytest.mark.parametrize( ("seq_num, c_seq_len, heads, dim, group, scale, selected_block_num, bc, bs, bk, " "dtype, accum_dtype, tune"), @@ -58,7 +53,4 @@ def test_nsa_topk_varlen_op( if __name__ == "__main__": - test_nsa_topk_varlen_op(5, 1024, 32, 128, 16, 1, 16, 32, 32, 128, torch.float16, torch.float32, - False) - test_nsa_topk_varlen_op(3, 512, 32, 128, 16, 1, 16, 32, 32, 128, torch.float16, torch.float32, - False) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_fp8_lighting_indexer.py b/tests/ops/test_fp8_lighting_indexer.py index e7c056e8..3ae2dec9 100644 --- a/tests/ops/test_fp8_lighting_indexer.py +++ b/tests/ops/test_fp8_lighting_indexer.py @@ -1,17 +1,19 @@ -import argparse from typing import Optional +import pytest + from benchmarks import Fp8LightingIndexerBenchmark from top.ops import Fp8LightingIndexerOp -def test_indexer(seq_len: int, - heads: int, - index_dim: int, - seq_len_kv: int, - clean_logits: bool, - config: Optional[dict], - tune: bool = False) -> None: +@pytest.mark.parametrize( + "seq_len, heads, index_dim, seq_len_kv, clean_logits, config, tune", + [ + (4096, 32, 64, 8192, True, None, False), + ], +) +def test_indexer(seq_len: int, heads: int, index_dim: int, seq_len_kv: int, clean_logits: bool, + config: Optional[dict], tune: bool) -> None: op = Fp8LightingIndexerOp( seq_len, heads, index_dim, seq_len_kv, clean_logits, config, tune=tune) benchmark = Fp8LightingIndexerBenchmark(seq_len, heads, index_dim, seq_len_kv, clean_logits, @@ -24,19 +26,4 @@ def test_indexer(seq_len: int, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='number of heads') - parser.add_argument('--index_dim', type=int, default=64, help='index dim') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument( - '--clean_logits', - action=argparse.BooleanOptionalAction, - default=True, - help='whether to clean logits outside the valid range') - parser.add_argument('--config', type=str, default=None, help='positional encoding dim') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_indexer(args.seq_len, args.heads, args.index_dim, args.seq_len_kv, args.clean_logits, - args.config, args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_fp8_quant.py b/tests/ops/test_fp8_quant.py index e623e2fc..8c0d72d3 100644 --- a/tests/ops/test_fp8_quant.py +++ b/tests/ops/test_fp8_quant.py @@ -30,7 +30,4 @@ def test_fp8_quant_op(seq_len_kv: int, index_dim: int, in_dtype: torch.dtype, tu if __name__ == "__main__": - test_fp8_quant_op(8192, 64, torch.float16, False) - test_fp8_quant_op(8192, 64, torch.bfloat16, False) - test_fp8_quant_op(4096, 128, torch.float32, False) - test_fp8_quant_op(16384, 32, torch.float32, False) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_gemm.py b/tests/ops/test_gemm.py index afea2b2d..b2ee7068 100644 --- a/tests/ops/test_gemm.py +++ b/tests/ops/test_gemm.py @@ -1,19 +1,18 @@ -import argparse - import torch +import pytest from benchmarks import GemmBenchmark from top.ops import GemmOp -from top.utils import str2dtype -def test_gemm(m: int, - n: int, - k: int, - dtype: torch.dtype, - trans_a: bool = False, - trans_b: bool = False, - tune: bool = False) -> None: +@pytest.mark.parametrize( + "m, n, k, dtype, trans_a, trans_b, tune", + [ + (1024, 1024, 1024, torch.float16, False, False, False), + ], +) +def test_gemm(m: int, n: int, k: int, dtype: torch.dtype, trans_a: bool, trans_b: bool, + tune: bool) -> None: op = GemmOp(m, n, k, trans_a=trans_a, trans_b=trans_b, dtype=dtype, tune=tune) benchmark = GemmBenchmark(m, n, k, dtype, trans_a=trans_a, trans_b=trans_b) @@ -23,15 +22,5 @@ def test_gemm(m: int, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--m', type=int, default=1024, help='M') - parser.add_argument('--n', type=int, default=1024, help='N') - parser.add_argument('--k', type=int, default=1024, help='K') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--trans_a', action='store_true', default=False, help='transpose input A') - parser.add_argument('--trans_b', action='store_true', default=False, help='transpose input B') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_gemm(args.m, args.n, args.k, str2dtype[args.dtype], args.trans_a, args.trans_b, args.tune) + # Run tests with pytest + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_gqa.py b/tests/ops/test_gqa.py index ae9f4eaf..0ecd1d55 100644 --- a/tests/ops/test_gqa.py +++ b/tests/ops/test_gqa.py @@ -1,64 +1,41 @@ -import argparse - +import pytest import torch from benchmarks import GroupQueryAttentionBwdBenchmark, GroupQueryAttentionFwdBenchmark from top.ops import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp -from top.utils import str2dtype -def test_gqa_fwd(batch: int, - seq_len: int, - heads: int, - heads_kv: int, - dim: int, - causal: bool, - dtype: torch.dtype, - tune: bool = False) -> None: +@pytest.mark.parametrize("batch, seq_len, heads, heads_kv, dim, causal, dtype, tune", [ + (1, 1024, 8, 4, 64, False, torch.float16, False), + (4, 2048, 64, 4, 128, False, torch.float16, False), + (4, 2048, 64, 4, 128, False, torch.bfloat16, False), +]) +def test_gqa_fwd(batch: int, seq_len: int, heads: int, heads_kv: int, dim: int, causal: bool, + dtype: torch.dtype, tune: bool) -> None: op = GroupQueryAttentionFwdOp(batch, heads, heads_kv, seq_len, dim, causal, dtype, tune=tune) benchmark = GroupQueryAttentionFwdBenchmark(batch, heads, heads_kv, seq_len, dim, causal, dtype) inputs = benchmark.gen_inputs() print("Forward Results:") - benchmark.check(op, *inputs, atol=5e-4, rtol=1e-5) + benchmark.check(op, *inputs, atol=5e-3, rtol=1e-5) benchmark.profile(op, *inputs) -def test_gqa_bwd(batch: int, - seq_len: int, - heads: int, - heads_kv: int, - dim: int, - causal: bool, - dtype: torch.dtype, - tune: bool = False) -> None: +@pytest.mark.parametrize("batch, seq_len, heads, heads_kv, dim, causal, dtype, tune", [ + (1, 1024, 8, 4, 64, False, torch.float16, False), + (4, 2048, 64, 4, 128, False, torch.float16, False), + (4, 2048, 64, 4, 128, False, torch.bfloat16, False), +]) +def test_gqa_bwd(batch: int, seq_len: int, heads: int, heads_kv: int, dim: int, causal: bool, + dtype: torch.dtype, tune: bool) -> None: op = GroupQueryAttentionBwdOp(batch, heads, heads_kv, seq_len, dim, causal, dtype, tune=tune) benchmark = GroupQueryAttentionBwdBenchmark(batch, heads, heads_kv, seq_len, dim, causal, dtype) inputs = benchmark.gen_inputs() print("Backward Results:") - benchmark.check(op, *inputs, atol=5e-4, rtol=1e-5) + benchmark.check(op, *inputs, atol=5e-3, rtol=1e-5) benchmark.profile(op, *inputs) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--heads_kv', type=int, default=8, help='num heads kv') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--causal', action='store_true', default=False, help='causal attention') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - parser.add_argument( - '--disable_bwd', action='store_false', default=True, help='when test fwd profile') - args = parser.parse_args() - - test_gqa_fwd(args.batch, args.seq_len, args.heads, args.heads_kv, args.dim, args.causal, - str2dtype[args.dtype], args.tune) - - if args.disable_bwd: - test_gqa_bwd(args.batch, args.seq_len, args.heads, args.heads_kv, args.dim, args.causal, - str2dtype[args.dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_gqa_decode.py b/tests/ops/test_gqa_decode.py index 98175771..83a33fb3 100644 --- a/tests/ops/test_gqa_decode.py +++ b/tests/ops/test_gqa_decode.py @@ -1,19 +1,18 @@ -import argparse - import torch +import pytest from benchmarks import GroupQueryAttentionDecodeBenchmark from top.ops import GroupQueryAttentionDecodeWithKVCacheOp -from top.utils import str2dtype -def test_gqa_decode(b: int, - h: int, - g: int, - s_kv: int, - d: int, - dtype: torch.dtype, - tune: bool = False) -> None: +@pytest.mark.parametrize( + "b, h, g, s_kv, d, dtype, tune", + [ + (1, 32, 8, 8192, 128, torch.float16, False), + ], +) +def test_gqa_decode(b: int, h: int, g: int, s_kv: int, d: int, dtype: torch.dtype, + tune: bool) -> None: op = GroupQueryAttentionDecodeWithKVCacheOp(b, h, g, s_kv, d, dtype, tune=tune) benchmark = GroupQueryAttentionDecodeBenchmark(b, h, g, s_kv, d, dtype) @@ -23,20 +22,4 @@ def test_gqa_decode(b: int, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--groups', type=int, default=8, help='number of groups') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_gqa_decode(args.batch, args.heads, args.groups, args.seq_len_kv, args.dim, - str2dtype['float16'], args.tune) - test_gqa_decode(args.batch, args.heads, args.groups, args.seq_len_kv, args.dim, - str2dtype['bfloat16'], args.tune) - test_gqa_decode(args.batch, args.heads, args.groups, 10, args.dim, str2dtype[args.dtype], - args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_gqa_decode_paged.py b/tests/ops/test_gqa_decode_paged.py index 76a6a2ef..10c6f74c 100644 --- a/tests/ops/test_gqa_decode_paged.py +++ b/tests/ops/test_gqa_decode_paged.py @@ -10,11 +10,6 @@ from top.ops import GroupQueryAttentionDecodePagedWithKVCacheOp -@pytest.fixture(autouse=True) -def setup() -> None: - torch.manual_seed(12345) - - def _torch_ref_gqa_decode_paged( q: torch.Tensor, k: torch.Tensor, @@ -116,3 +111,7 @@ def test_gqa_decode_paged_op( cos_sim = F.cosine_similarity( output.reshape(batch, -1), output_ref.reshape(batch, -1), dim=-1, eps=1e-8) assert cos_sim.min() > 0.99, f"cosine similarity {cos_sim.min().item()} too low" + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_gqa_decode_paged_legacy.py b/tests/ops/test_gqa_decode_paged_legacy.py index a9357200..aa69e8d9 100644 --- a/tests/ops/test_gqa_decode_paged_legacy.py +++ b/tests/ops/test_gqa_decode_paged_legacy.py @@ -1,14 +1,15 @@ """Legacy-style test for GroupQueryAttentionDecodePagedWithKVCacheOp (argparse + check + profile).""" -import argparse - +import pytest import torch from benchmarks.flash_decode import GroupQueryAttentionDecodePagedBenchmark from top.ops import GroupQueryAttentionDecodePagedWithKVCacheOp -from top.utils import str2dtype +@pytest.mark.parametrize("batch,heads,groups,seqlen_kv,dim,page_size,dtype", [ + (1, 16, 8, 512, 128, 128, torch.float16), +]) def test_gqa_decode_paged( batch: int, heads: int, @@ -19,42 +20,16 @@ def test_gqa_decode_paged( dtype: torch.dtype, tune: bool = False, ) -> None: + torch.manual_seed(123) # 替代 fixture 中的随机种子设置 op = GroupQueryAttentionDecodePagedWithKVCacheOp( batch, heads, groups, seqlen_kv, dim, page_size, dtype, tune=tune) benchmark = GroupQueryAttentionDecodePagedBenchmark(batch, heads, groups, seqlen_kv, dim, page_size, dtype) inputs = benchmark.gen_inputs() - benchmark.check(op, *inputs) + benchmark.check(op, *inputs, atol=1e-2, rtol=1e-2) benchmark.profile(op, *inputs) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--batch", type=int, default=1, help="batch size") - parser.add_argument("--heads", type=int, default=16, help="num heads") - parser.add_argument("--groups", type=int, default=8, help="num kv groups") - parser.add_argument("--seqlen_kv", type=int, default=512, help="key/value sequence length") - parser.add_argument("--dim", type=int, default=128, help="head dim") - parser.add_argument("--page_size", type=int, default=128, help="page size") - parser.add_argument( - "--dtype", - type=str, - default="float16", - choices=["float16", "bfloat16"], - help="data type", - ) - parser.add_argument("--tune", action="store_true", default=False, help="enable autotune") - args = parser.parse_args() - - dtype = str2dtype[args.dtype] - test_gqa_decode_paged( - args.batch, - args.heads, - args.groups, - args.seqlen_kv, - args.dim, - args.page_size, - dtype, - args.tune, - ) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_grouped_gemm.py b/tests/ops/test_grouped_gemm.py index 139dee5a..01220ee2 100644 --- a/tests/ops/test_grouped_gemm.py +++ b/tests/ops/test_grouped_gemm.py @@ -1,7 +1,7 @@ -import argparse import time import torch +import pytest from benchmarks import ( GroupedGemmBenchmark, @@ -11,10 +11,16 @@ GroupedGemmTTBenchmark, ) from top.ops.grouped_gemm import GroupedGemmNNOp, GroupedGemmNTOp, GroupedGemmTNOp, GroupedGemmTTOp -from top.utils import str2dtype -def test_grouped_gemm_nt(batch_sum, batch_count, N, K, dtype, tune=False): +@pytest.mark.parametrize( + "batch_sum, batch_count, N, K, dtype, tune", + [ + (16384, 4, 4864, 4096, torch.float16, False), + ], +) +def test_grouped_gemm_nt(batch_sum: int, batch_count: int, N: int, K: int, dtype: torch.dtype, + tune: bool): op = GroupedGemmNTOp(batch_sum, batch_count, N, K, dtype, tune=tune) benchmark = GroupedGemmNTBenchmark(batch_sum, batch_count, N, K, dtype) @@ -23,7 +29,14 @@ def test_grouped_gemm_nt(batch_sum, batch_count, N, K, dtype, tune=False): benchmark.profile(op, *inputs) -def test_grouped_gemm_nn(batch_sum, batch_count, N, K, dtype, tune=False): +@pytest.mark.parametrize( + "batch_sum, batch_count, N, K, dtype, tune", + [ + (16384, 4, 4864, 4096, torch.float16, False), + ], +) +def test_grouped_gemm_nn(batch_sum: int, batch_count: int, N: int, K: int, dtype: torch.dtype, + tune: bool): op = GroupedGemmNNOp(batch_sum, batch_count, N, K, dtype, tune=tune) benchmark = GroupedGemmNNBenchmark(batch_sum, batch_count, N, K, dtype) @@ -32,7 +45,14 @@ def test_grouped_gemm_nn(batch_sum, batch_count, N, K, dtype, tune=False): benchmark.profile(op, *inputs) -def test_grouped_gemm_tn(batch_sum, batch_count, N, K, dtype, tune=False): +@pytest.mark.parametrize( + "batch_sum, batch_count, N, K, dtype, tune", + [ + (16384, 4, 4864, 4096, torch.float16, False), + ], +) +def test_grouped_gemm_tn(batch_sum: int, batch_count: int, N: int, K: int, dtype: torch.dtype, + tune: bool): op = GroupedGemmTNOp(batch_sum, batch_count, N, K, dtype, tune=tune) benchmark = GroupedGemmTNBenchmark(batch_sum, batch_count, N, K, dtype) @@ -41,7 +61,14 @@ def test_grouped_gemm_tn(batch_sum, batch_count, N, K, dtype, tune=False): benchmark.profile(op, *inputs) -def test_grouped_gemm_tt(batch_sum, batch_count, N, K, dtype, tune=False): +@pytest.mark.parametrize( + "batch_sum, batch_count, N, K, dtype, tune", + [ + (16384, 4, 4864, 4096, torch.float16, False), + ], +) +def test_grouped_gemm_tt(batch_sum: int, batch_count: int, N: int, K: int, dtype: torch.dtype, + tune: bool): op = GroupedGemmTTOp(batch_sum, batch_count, N, K, dtype, tune=tune) benchmark = GroupedGemmTTBenchmark(batch_sum, batch_count, N, K, dtype) @@ -50,7 +77,14 @@ def test_grouped_gemm_tt(batch_sum, batch_count, N, K, dtype, tune=False): benchmark.profile(op, *inputs) -def test_grouped_gemm_complete(batch_sum, batch_count, N, K, dtype, tune=False): +@pytest.mark.parametrize( + "batch_sum, batch_count, N, K, dtype, tune", + [ + (16384, 4, 4864, 4096, torch.float16, False), + ], +) +def test_grouped_gemm_complete(batch_sum: int, batch_count: int, N: int, K: int, dtype: torch.dtype, + tune: bool): from top.functions.grouped_gemm import GroupedGemmFunc op = GroupedGemmFunc(batch_sum, batch_count, N, K, dtype, tune=tune) @@ -84,29 +118,4 @@ def test_grouped_gemm_complete(batch_sum, batch_count, N, K, dtype, tune=False): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch_sum', type=int, default=16384, help='sum of batch_size_list') - parser.add_argument('--batch_count', type=int, default=4, help='length of batch_size_list') - parser.add_argument('--N', type=int, default=4864, help='head dim') - parser.add_argument('--K', type=int, default=4096, help='num heads') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', help='enable autotune') - - args = parser.parse_args() - - print("Testing grouped_gemm_nt (forward)...") - test_grouped_gemm_nt(args.batch_sum, args.batch_count, args.N, args.K, str2dtype[args.dtype], - args.tune) - print("Testing grouped_gemm_nn (backward dA)...") - test_grouped_gemm_nn(args.batch_sum, args.batch_count, args.N, args.K, str2dtype[args.dtype], - args.tune) - print("Testing grouped_gemm_tn (backward dB)...") - test_grouped_gemm_tn(args.batch_sum, args.batch_count, args.N, args.K, str2dtype[args.dtype], - args.tune) - print("Testing grouped_gemm_tt (backward dB)...") - test_grouped_gemm_tt(args.batch_sum, args.batch_count, args.N, args.K, str2dtype[args.dtype], - args.tune) - print("Testing complete grouped_gemm function...") - test_grouped_gemm_complete(args.batch_sum, args.batch_count, args.N, args.K, - str2dtype[args.dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_mean_pooling_ops.py b/tests/ops/test_mean_pooling_ops.py index 8c3b132e..2d3f6d59 100644 --- a/tests/ops/test_mean_pooling_ops.py +++ b/tests/ops/test_mean_pooling_ops.py @@ -71,13 +71,4 @@ def test_mean_pooling_op(batch_size: int, seq_len: int, heads: int, dim: int, ch if __name__ == "__main__": - test_mean_pooling_op(1, 8192, 64, 128, 64, torch.float16, torch.float32, False, None) - test_mean_pooling_op(1, 8192, 64, 128, 64, torch.float16, torch.float32, True, None) - test_mean_pooling_op(2, 2049, 64, 128, 64, torch.float16, torch.float32, False, None) - test_mean_pooling_op(1, 1024, 64, 128, 64, torch.float16, torch.float32, False, - torch.tensor([0, 256, 768, 1024], dtype=torch.int32, device='cuda')) - test_mean_pooling_op( - 1, 8192, 64, 128, 64, torch.float16, torch.float32, True, - torch.tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32, device='cuda')) - test_mean_pooling_op(1, 1000, 64, 128, 32, torch.float16, torch.float32, True, - torch.tensor([0, 100, 300, 600, 1000], dtype=torch.int32, device='cuda')) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_mha.py b/tests/ops/test_mha.py index b0b87a3a..8c040890 100644 --- a/tests/ops/test_mha.py +++ b/tests/ops/test_mha.py @@ -1,46 +1,41 @@ -import argparse +import pytest +import torch from benchmarks import MultiHeadAttentionBwdBenchmark, MultiHeadAttentionFwdBenchmark from top.ops import MultiHeadAttentionBwdOp, MultiHeadAttentionFwdOp -from top.utils import str2dtype -def test_mha_fwd(batch, seq_len, heads, dim, causal, dtype, tune=False): +@pytest.mark.parametrize("batch, seq_len, heads, dim, causal, dtype, tune", [ + (1, 1024, 8, 64, False, torch.float16, False), + (16, 2048, 16, 128, False, torch.float16, False), + (4, 4096, 16, 128, False, torch.bfloat16, True), +]) +def test_mha_fwd(batch: int, seq_len: int, heads: int, dim: int, causal: bool, dtype: torch.dtype, + tune: bool) -> None: op = MultiHeadAttentionFwdOp(batch, heads, seq_len, dim, causal, dtype, tune=tune) benchmark = MultiHeadAttentionFwdBenchmark(batch, heads, seq_len, dim, causal, dtype) inputs = benchmark.gen_inputs() print("Forward Results:") - benchmark.check(op, *inputs, atol=5e-4, rtol=1e-5) + benchmark.check(op, *inputs, atol=5e-3, rtol=1e-5) benchmark.profile(op, *inputs) -def test_mha_bwd(batch, seq_len, heads, dim, causal, dtype, tune=False): +@pytest.mark.parametrize("batch, seq_len, heads, dim, causal, dtype, tune", [ + (1, 1024, 8, 64, False, torch.float16, False), + (16, 2048, 16, 128, False, torch.float16, False), + (4, 4096, 16, 128, False, torch.bfloat16, True), +]) +def test_mha_bwd(batch: int, seq_len: int, heads: int, dim: int, causal: bool, dtype: torch.dtype, + tune: bool) -> None: op = MultiHeadAttentionBwdOp(batch, heads, seq_len, dim, causal, dtype, tune=tune) benchmark = MultiHeadAttentionBwdBenchmark(batch, heads, seq_len, dim, causal, dtype) inputs = benchmark.gen_inputs() print("Backward Results:") - benchmark.check(op, *inputs, atol=5e-4, rtol=1e-5) + benchmark.check(op, *inputs, atol=5e-3, rtol=1e-5) benchmark.profile(op, *inputs) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--causal', action='store_true', default=False, help='causal attention') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - parser.add_argument( - '--disable_bwd', action='store_false', default=True, help='when test fwd profile') - args = parser.parse_args() - - test_mha_fwd(args.batch, args.seq_len, args.heads, args.dim, args.causal, str2dtype[args.dtype], - args.tune) - if args.disable_bwd: - test_mha_bwd(args.batch, args.seq_len, args.heads, args.dim, args.causal, - str2dtype[args.dtype], args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_mha_decode.py b/tests/ops/test_mha_decode.py index aa433481..3e9a7103 100644 --- a/tests/ops/test_mha_decode.py +++ b/tests/ops/test_mha_decode.py @@ -1,24 +1,20 @@ -import argparse - import torch +import pytest from benchmarks import MultiHeadAttentionDecodeBenchmark from top.ops import MultiHeadAttentionDecodeWithKVCacheOp -from top.utils import str2dtype - -# Set fixed seed for reproducibility -torch.manual_seed(42) -if torch.cuda.is_available(): - torch.cuda.manual_seed_all(42) -def test_mha_decode(b: int, - h: int, - s_q: int, - s_kv: int, - d: int, - dtype: torch.dtype, - tune: bool = False) -> None: +@pytest.mark.parametrize( + ("b", "h", "s_q", "s_kv", "d", "dtype", "tune"), + [ + (1, 32, 128, 8192, 128, torch.float16, False), + (1, 32, 128, 8192, 128, torch.bfloat16, False), + (1, 32, 128, 5, 128, torch.float16, False), + ], +) +def test_mha_decode(b: int, h: int, s_q: int, s_kv: int, d: int, dtype: torch.dtype, + tune: bool) -> None: op = MultiHeadAttentionDecodeWithKVCacheOp(b, h, s_q, s_kv, d, dtype, tune=tune) benchmark = MultiHeadAttentionDecodeBenchmark(b, h, s_q, s_kv, d, dtype) @@ -28,20 +24,4 @@ def test_mha_decode(b: int, if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=1, help='batch size') - parser.add_argument('--seq_len_q', type=int, default=128, help='query sequence length') - parser.add_argument('--seq_len_kv', type=int, default=8192, help='key/value sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument( - '--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='data type') - parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') - args = parser.parse_args() - - test_mha_decode(args.batch, args.heads, args.seq_len_q, args.seq_len_kv, args.dim, - str2dtype["bfloat16"], args.tune) - test_mha_decode(args.batch, args.heads, args.seq_len_q, args.seq_len_kv, args.dim, - str2dtype["float16"], args.tune) - test_mha_decode(args.batch, args.heads, args.seq_len_q, 5, args.dim, str2dtype["float16"], - args.tune) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_mha_decode_paged.py b/tests/ops/test_mha_decode_paged.py index 9939d395..7e1d068c 100644 --- a/tests/ops/test_mha_decode_paged.py +++ b/tests/ops/test_mha_decode_paged.py @@ -9,12 +9,6 @@ from top.ops import MultiHeadAttentionDecodePagedWithKVCacheOp -@pytest.fixture(autouse=True) -def setup() -> None: - """Set up the test environment.""" - torch.manual_seed(12345) - - def _torch_ref_mha_decode_paged( q: torch.Tensor, k: torch.Tensor, @@ -112,3 +106,7 @@ def test_mha_decode_paged_op( output.reshape(batch, -1), output_ref.reshape(batch, -1), dim=-1, eps=1e-8) assert cos_sim.min() > 0.99, f"cosine similarity {cos_sim.min().item()} too low" torch.cuda.empty_cache() + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_mha_decode_paged_legacy.py b/tests/ops/test_mha_decode_paged_legacy.py index 6efd8804..25015c83 100644 --- a/tests/ops/test_mha_decode_paged_legacy.py +++ b/tests/ops/test_mha_decode_paged_legacy.py @@ -1,14 +1,15 @@ """Legacy-style test for MultiHeadAttentionDecodePagedWithKVCacheOp (argparse + check + profile).""" -import argparse - +import pytest import torch from benchmarks.flash_decode import MultiHeadAttentionDecodePagedBenchmark from top.ops import MultiHeadAttentionDecodePagedWithKVCacheOp -from top.utils import str2dtype +@pytest.mark.parametrize("batch,heads,seqlen_q,seqlen_kv,dim,page_size,is_causal,dtype", [ + (1, 16, 1, 512, 128, 128, False, torch.float16), +]) def test_mha_decode_paged( batch: int, heads: int, @@ -26,38 +27,9 @@ def test_mha_decode_paged( page_size, is_causal, dtype) inputs = benchmark.gen_inputs() - benchmark.check(op, *inputs) + benchmark.check(op, *inputs, atol=2e-3, rtol=1e-5) benchmark.profile(op, *inputs) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--batch", type=int, default=1, help="batch size") - parser.add_argument("--heads", type=int, default=16, help="num heads") - parser.add_argument("--seqlen_q", type=int, default=1, help="query sequence length") - parser.add_argument("--seqlen_kv", type=int, default=512, help="key/value sequence length") - parser.add_argument("--dim", type=int, default=128, help="head dim") - parser.add_argument("--page_size", type=int, default=128, help="page size") - parser.add_argument("--is_causal", action="store_true", default=False, help="causal mask") - parser.add_argument( - "--dtype", - type=str, - default="float16", - choices=["float16", "bfloat16"], - help="data type", - ) - parser.add_argument("--tune", action="store_true", default=False, help="enable autotune") - args = parser.parse_args() - - dtype = str2dtype[args.dtype] - test_mha_decode_paged( - args.batch, - args.heads, - args.seqlen_q, - args.seqlen_kv, - args.dim, - args.page_size, - args.is_causal, - dtype, - args.tune, - ) + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_mha_decode_pytest.py b/tests/ops/test_mha_decode_pytest.py deleted file mode 100644 index f93d029e..00000000 --- a/tests/ops/test_mha_decode_pytest.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Pytest version of test_mha_decode: parametrized correctness check for MultiHeadAttentionDecodeWithKVCacheOp.""" - -import pytest -import torch - -from benchmarks import MultiHeadAttentionDecodeBenchmark -from top.ops import MultiHeadAttentionDecodeWithKVCacheOp - - -@pytest.fixture(autouse=True) -def setup() -> None: - torch.manual_seed(12345) - - -@pytest.mark.parametrize( - ("batch", "heads", "seq_len_q", "seq_len_kv", "dim", "dtype", "tune"), - [ - (1, 32, 128, 8192, 128, torch.float16, False), - (1, 32, 128, 8192, 128, torch.bfloat16, False), - (1, 32, 128, 5, 128, torch.float16, False), - ], -) -def test_mha_decode_op( - batch: int, - heads: int, - seq_len_q: int, - seq_len_kv: int, - dim: int, - dtype: torch.dtype, - tune: bool, -) -> None: - op = MultiHeadAttentionDecodeWithKVCacheOp( - batch, heads, seq_len_q, seq_len_kv, dim, dtype, tune=tune) - benchmark = MultiHeadAttentionDecodeBenchmark(batch, heads, seq_len_q, seq_len_kv, dim, dtype) - inputs = benchmark.gen_inputs() - benchmark.check(op, *inputs) diff --git a/tests/ops/test_mhc_post.py b/tests/ops/test_mhc_post.py index 97f661d8..8a5d91b1 100644 --- a/tests/ops/test_mhc_post.py +++ b/tests/ops/test_mhc_post.py @@ -6,12 +6,6 @@ from top.ops import ManifoldConstrainedHyperConnectionPostOp -@pytest.fixture(autouse=True) -def setup() -> None: - """Set up the test environment.""" - torch.manual_seed(12345) - - @pytest.mark.parametrize( ("batch, n_expand, c_x, dtype, tune"), [ @@ -43,3 +37,7 @@ def test_mhc_post_op( cos_sim_x_out = torch.nn.functional.cosine_similarity(x_out_ref, x_out, dim=-1, eps=1e-8) assert cos_sim_x_out.min() > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_mhc_pre.py b/tests/ops/test_mhc_pre.py index 15eab89b..7ce023dd 100644 --- a/tests/ops/test_mhc_pre.py +++ b/tests/ops/test_mhc_pre.py @@ -8,12 +8,6 @@ from top.ops import ManifoldConstrainedHyperConnectionPreOp -@pytest.fixture(autouse=True) -def setup() -> None: - """Set up the test environment.""" - torch.manual_seed(1235) - - @pytest.mark.parametrize( ("batch, n_expand, c_x, dtype, tune"), [ @@ -94,3 +88,7 @@ def test_mhc_pre_op( cos_sim_x_layer = torch.nn.functional.cosine_similarity(x_layer_ref, x_layer, dim=-1, eps=1e-8) assert cos_sim_x_layer.min() > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_topk_selector.py b/tests/ops/test_topk_selector.py index 95d7dff4..7350575f 100644 --- a/tests/ops/test_topk_selector.py +++ b/tests/ops/test_topk_selector.py @@ -1,4 +1,5 @@ import pytest + from benchmarks import TopkSelectorBenchmark from top.ops import TopkSelectorOp from top.utils import str2dtype @@ -25,7 +26,4 @@ def test_topk_selector_op(batch: int, seq_len: int, topk: int, in_dtype: str, ou if __name__ == "__main__": - test_topk_selector_op(64, 32 * 1024, 1024, "float32", "int32", False) - test_topk_selector_op(64, 32 * 1024, 2048, "float32", "int32", False) - test_topk_selector_op(128, 64 * 1024, 1024, "float32", "int32", False) - test_topk_selector_op(128, 64 * 1024, 2048, "float32", "int32", False) + pytest.main([__file__, "-vvs"]) diff --git a/tests/test_autotune.py b/tests/test_autotune.py index ff417322..6f619d8f 100644 --- a/tests/test_autotune.py +++ b/tests/test_autotune.py @@ -1,10 +1,16 @@ -import argparse +import pytest +import torch from top.ops import MultiHeadAttentionFwdOp -from top.utils import str2dtype -def test_mha_kernel_autotune(B, S, H, D, causal, dtype): +@pytest.mark.parametrize( + "B, S, H, D, causal, dtype", + [ + (8, 1024, 32, 128, False, torch.float16), + ], +) +def test_mha_kernel_autotune(B: int, S: int, H: int, D: int, causal: bool, dtype: torch.dtype): # 1. test autotune at initialization op = MultiHeadAttentionFwdOp(B, H, S, D, causal, dtype, tune=True) @@ -13,15 +19,4 @@ def test_mha_kernel_autotune(B, S, H, D, causal, dtype): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--causal', action='store_true', default=False, help='causal attention') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - args = parser.parse_args() - - test_mha_kernel_autotune(args.batch, args.seq_len, args.heads, args.dim, args.causal, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/test_compile.py b/tests/test_compile.py index 5c0e5d12..4c5a3eca 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -1,38 +1,33 @@ # This test validates the compatibility of TileOps operators with torch.compile(). # Check: https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html -import argparse - +import pytest import torch from benchmarks import MultiHeadAttentionFwdBenchmark from top.ops import MultiHeadAttentionFwdOp -from top.utils import str2dtype -def test_mha_kernel_compile(B, S, H, D, causal, dtype): +@pytest.mark.parametrize( + "B, S, H, D, causal, dtype", + [ + (8, 1024, 32, 128, False, torch.float16), + (4, 512, 16, 64, True, torch.bfloat16), + (2, 2048, 64, 128, False, torch.float16), + ], +) +def test_mha_kernel_compile(B: int, S: int, H: int, D: int, causal: bool, dtype: torch.dtype): op = MultiHeadAttentionFwdOp(B, H, S, D, causal, dtype) benchmark = MultiHeadAttentionFwdBenchmark(B, H, S, D, causal, dtype) compiled_op = torch.compile(op, fullgraph=True) inputs = benchmark.gen_inputs() benchmark.check( - compiled_op, *inputs, atol=3e-4, rtol=1e-5) # will throw an error if not compatible + compiled_op, *inputs, atol=5e-3, rtol=1e-5) # will throw an error if not compatible benchmark.profile(compiled_op, *inputs) print('Successfully validate the compatibility with torch.compile().✅') if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') - parser.add_argument('--heads', type=int, default=32, help='num heads') - parser.add_argument('--dim', type=int, default=128, help='head dim') - parser.add_argument('--causal', action='store_true', default=False, help='causal attention') - parser.add_argument( - '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') - args = parser.parse_args() - - test_mha_kernel_compile(args.batch, args.seq_len, args.heads, args.dim, args.causal, - str2dtype[args.dtype]) + pytest.main([__file__, "-vvs"]) diff --git a/tests/test_gemm_torch.py b/tests/test_gemm_torch.py index 7a5792bd..1e4edd64 100644 --- a/tests/test_gemm_torch.py +++ b/tests/test_gemm_torch.py @@ -1,6 +1,6 @@ -import argparse import time +import pytest import torch import torch.nn as nn @@ -9,7 +9,13 @@ def calculate_gemm_flops(M, N, K): return 2.0 * M * N * K -def benchmark_pytorch_gemm(M, N, K, dtype, num_iter=100): +@pytest.mark.parametrize( + "M, N, K, dtype, num_iter", + [ + (16384, 8192, 13824, torch.float16, 100), + ], +) +def test_pytorch_gemm(M: int, N: int, K: int, dtype, num_iter: int): device = 'cuda' A = torch.randn(M, K, device=device, dtype=dtype) B = torch.randn(K, N, device=device, dtype=dtype) @@ -29,7 +35,13 @@ def benchmark_pytorch_gemm(M, N, K, dtype, num_iter=100): return elapsed_time, tflops, flops -def benchmark_cublas_gemm(M, N, K, dtype, num_iter=100): +@pytest.mark.parametrize( + "M, N, K, dtype, num_iter", + [ + (16384, 8192, 13824, torch.float16, 100), + ], +) +def test_cublas_gemm(M: int, N: int, K: int, dtype, num_iter: int): device = 'cuda' linear = nn.Linear(K, N, bias=False).to(device).to(dtype) input_tensor = torch.randn(M, K, device=device, dtype=dtype) @@ -50,30 +62,4 @@ def benchmark_cublas_gemm(M, N, K, dtype, num_iter=100): if __name__ == "__main__": - parser = argparse.ArgumentParser(description='GEMM Performance Benchmark') - parser.add_argument('--M', type=int, default=16384, help='Matrix A rows') - parser.add_argument('--N', type=int, default=8192, help='Matrix B columns') - parser.add_argument('--K', type=int, default=13824, help='Matrix A columns / Matrix B rows') - parser.add_argument( - '--dtype', - type=str, - default='float16', - choices=['float16', 'float32', 'bfloat16'], - help='Data type') - args = parser.parse_args() - dtype_map = {'float16': torch.float16, 'float32': torch.float32, 'bfloat16': torch.bfloat16} - M = args.M - N = args.N - K = args.K - dtype = dtype_map[args.dtype] - print("=" * 60) - print("GEMM Performance Benchmark") - print("=" * 60) - print("Configuration:") - print(f" M: {M}, N: {N}, K: {K}") - print(f" Data type: {dtype}") - base_time, base_tflops, flops = benchmark_pytorch_gemm(M, N, K, dtype) - print("\nPyTorch torch.matmul:") - print(f" Time: {base_time * 1000:.4f} ms") - print(f" Performance: {base_tflops:.2f} TFLOPS") - print(f" Total FLOPs: {flops / 1e12:.2f} TFLOPs") + pytest.main([__file__, "-vvs"]) diff --git a/tests/test_gemm_triton.py b/tests/test_gemm_triton.py index 0aa55f25..a26231b5 100644 --- a/tests/test_gemm_triton.py +++ b/tests/test_gemm_triton.py @@ -1,6 +1,6 @@ -import argparse import time +import pytest import torch import triton import triton.language as tl @@ -79,7 +79,13 @@ def calculate_gemm_flops(M, N, K): return 2.0 * M * N * K -def benchmark_triton_gemm_fp16(M, N, K, dtype, num_iter=100): +@pytest.mark.parametrize( + "M, N, K, dtype, num_iter", + [ + (4096, 4864, 8192, torch.float16, 100), + ], +) +def test_benchmark_triton_gemm_fp16(M: int, N: int, K: int, dtype, num_iter: int): device = 'cuda' A = torch.randn(M, K, device=device, dtype=dtype) B = torch.randn(K, N, device=device, dtype=dtype) @@ -138,7 +144,13 @@ def benchmark_triton_gemm_fp16(M, N, K, dtype, num_iter=100): return results -def verify_triton_gemm_fp16(M, N, K, dtype): +@pytest.mark.parametrize( + "M, N, K, dtype", + [ + (512, 512, 512, torch.float16), + ], +) +def test_verify_triton_gemm_fp16(M: int, N: int, K: int, dtype): print("Verifying Triton GEMM correctness (fp16 accumulation)...") device = 'cuda' A = torch.randn(M, K, device=device, dtype=dtype) @@ -158,30 +170,4 @@ def verify_triton_gemm_fp16(M, N, K, dtype): if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Triton GEMM performance test - fp16 accumulation') - parser.add_argument('--M', type=int, default=4096, help='Number of rows in matrix A') - parser.add_argument('--N', type=int, default=4864, help='Number of columns in matrix B') - parser.add_argument( - '--K', type=int, default=8192, help='Number of columns in matrix A / rows in matrix B') - parser.add_argument( - '--dtype', - type=str, - default='float16', - choices=['float16'], - help='Data type (only float16 supported)') - parser.add_argument('--verify', action='store_true', help='Verify correctness') - args = parser.parse_args() - dtype = torch.float16 - M = args.M - N = args.N - K = args.K - print("Triton GEMM standalone performance test (fp16 computation and accumulation)") - print("=" * 60) - print(f"Matrix dimensions: A[{M}, {K}] × B[{K}, {N}] = C[{M}, {N}]") - print(f"Data type: {dtype} (fp16 computation and accumulation)") - print(f"Total computation: {calculate_gemm_flops(M, N, K) / 1e12:.2f} TFLOPs") - print() - if args.verify: - verify_triton_gemm_fp16(M, N, K, dtype) - print() - benchmark_triton_gemm_fp16(M, N, K, dtype) + pytest.main([__file__, "-vvs"]) diff --git a/tests/test_grouped_gemm_torch.py b/tests/test_grouped_gemm_torch.py index f7089d79..91beda8a 100644 --- a/tests/test_grouped_gemm_torch.py +++ b/tests/test_grouped_gemm_torch.py @@ -1,5 +1,6 @@ import time +import pytest import torch @@ -82,7 +83,13 @@ def benchmark_single(gemm, a, b, batch_sizes, num_iter=100): return (time.time() - start_time) / num_iter -def test_all_grouped_gemm(batch_sum=4096, batch_count=4, k=8192, n=4864, dtype=torch.float16): +@pytest.mark.parametrize( + "batch_sum, batch_count, k, n, dtype", + [ + (4096, 4, 8192, 4864, torch.float16), + ], +) +def test_all_grouped_gemm(batch_sum, batch_count, k, n, dtype): print("=" * 70) print("PyTorch Grouped GEMM Performance Test") print("=" * 70) @@ -191,4 +198,4 @@ def test_all_grouped_gemm(batch_sum=4096, batch_count=4, k=8192, n=4864, dtype=t if __name__ == "__main__": - test_all_grouped_gemm() + pytest.main([__file__, "-vvs"]) diff --git a/tests/test_grouped_gemm_triton.py b/tests/test_grouped_gemm_triton.py index d8032ae0..1ecc2373 100644 --- a/tests/test_grouped_gemm_triton.py +++ b/tests/test_grouped_gemm_triton.py @@ -1,7 +1,9 @@ import argparse + import math import time +import pytest import torch import triton import triton.language as tl @@ -829,7 +831,13 @@ def calculate_flops_tt(batch_sizes, K, N): return 2.0 * sum(size * N * K for size in batch_sizes) -def test_grouped_gemm_nt(batch_sum: int, batch_count: int, K: int, N: int, dtype=torch.float16): +@pytest.mark.parametrize( + "batch_sum, batch_count, K, N, dtype", + [ + (16384, 4, 8192, 13824, torch.float16), + ], +) +def test_grouped_gemm_nt(batch_sum: int, batch_count: int, K: int, N: int, dtype: torch.dtype): print("Testing grouped_gemm_nt (forward)...") inputs = prepare_nt_inputs(batch_sum, batch_count, K, N, dtype) @@ -844,7 +852,13 @@ def test_grouped_gemm_nt(batch_sum: int, batch_count: int, K: int, N: int, dtype return success -def test_grouped_gemm_nn(batch_sum: int, batch_count: int, K: int, N: int, dtype=torch.float16): +@pytest.mark.parametrize( + "batch_sum, batch_count, K, N, dtype", + [ + (16384, 4, 8192, 13824, torch.float16), + ], +) +def test_grouped_gemm_nn(batch_sum: int, batch_count: int, K: int, N: int, dtype: torch.dtype): print("\nTesting grouped_gemm_nn (backward dA)...") inputs = prepare_nn_inputs(batch_sum, batch_count, K, N, dtype) @@ -863,7 +877,13 @@ def test_grouped_gemm_nn(batch_sum: int, batch_count: int, K: int, N: int, dtype return success -def test_grouped_gemm_tn(batch_sum: int, batch_count: int, K: int, N: int, dtype=torch.float16): +@pytest.mark.parametrize( + "batch_sum, batch_count, K, N, dtype", + [ + (16384, 4, 8192, 13824, torch.float16), + ], +) +def test_grouped_gemm_tn(batch_sum: int, batch_count: int, K: int, N: int, dtype: torch.dtype): print("\nTesting grouped_gemm_tn (backward dB)...") inputs = prepare_tn_inputs(batch_sum, batch_count, K, N, dtype) @@ -878,8 +898,14 @@ def test_grouped_gemm_tn(batch_sum: int, batch_count: int, K: int, N: int, dtype return success -def test_grouped_gemm_tt(batch_sum: int, batch_count: int, K: int, N: int, dtype=torch.float16): - print("\nTesting grouped_gemm_tn (backward dB)...") +@pytest.mark.parametrize( + "batch_sum, batch_count, K, N, dtype", + [ + (16384, 4, 8192, 13824, torch.float16), + ], +) +def test_grouped_gemm_tt(batch_sum: int, batch_count: int, K: int, N: int, dtype: torch.dtype): + print("\nTesting grouped_gemm_tt (backward dB)...") inputs = prepare_tt_inputs(batch_sum, batch_count, K, N, dtype) batch_sizes = inputs[2] @@ -921,4 +947,4 @@ def main(): if __name__ == "__main__": - main() + pytest.main([__file__, "-vvs"])