Skip to content

Commit 6418393

Browse files
committed
feat: add quant grouped gemm tests cases for aquant (regular and transpose C) and non-persistent kernel
1 parent 472fcc0 commit 6418393

File tree

7 files changed

+410
-68
lines changed

7 files changed

+410
-68
lines changed

test/ck_tile/grouped_gemm_quant/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
1414
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
1515
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
1616

17+
add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp)
18+
target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
19+
1720
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
1821
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
1922
endif()

test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,41 @@ using True = ck_tile::bool_constant<true>;
1818
using False = ck_tile::bool_constant<false>;
1919
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
2020
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
21+
using AQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
2122
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
2223

2324
// clang-format off
2425
using KernelTypes = ::testing::Types<
25-
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
26-
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
27-
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
28-
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
29-
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
30-
31-
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
32-
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
33-
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
34-
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
35-
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
36-
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
37-
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
38-
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
39-
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
40-
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
41-
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
42-
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
43-
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
44-
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>,
45-
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>,
46-
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
26+
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
27+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
28+
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
29+
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
30+
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
31+
32+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
33+
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
34+
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
35+
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>,
36+
37+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
38+
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
39+
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
40+
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
41+
42+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
43+
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
44+
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
45+
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>,
46+
47+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
48+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
49+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>,
50+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>,
51+
52+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
53+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
54+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>,
55+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>
4756
>;
4857
// clang-format on
4958

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include <tuple>
5+
6+
#include "gtest/gtest.h"
7+
8+
#include "ck_tile/host.hpp"
9+
#include "test_grouped_gemm_util_quant.hpp"
10+
11+
using F16 = ck_tile::half_t;
12+
using F32 = float;
13+
using FP8 = ck_tile::fp8_t;
14+
using BF8 = ck_tile::bf8_t;
15+
using Row = ck_tile::tensor_layout::gemm::RowMajor;
16+
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
17+
using True = ck_tile::bool_constant<true>;
18+
using False = ck_tile::bool_constant<false>;
19+
using AQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
20+
21+
// clang-format off
22+
using KernelTypes_AQuant = ::testing::Types<
23+
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
24+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>,
25+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>,
26+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>,
27+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>,
28+
29+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>,
30+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False>
31+
>;
32+
// clang-format on
33+
34+
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_AQuant, KernelTypes_AQuant);
35+
36+
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_AQuant
37+
#include "test_grouped_gemm_quant_ut_cases.inc"
38+
#undef TEST_CLASS_NAME

test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,14 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQ
2020

2121
// clang-format off
2222
using KernelTypes_BQuant = ::testing::Types<
23-
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
24-
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
25-
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
23+
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
24+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>,
25+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>,
26+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>,
27+
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>,
28+
29+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>,
30+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False>
2631
>;
2732
// clang-format on
2833

test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@ using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
2020

2121
// clang-format off
2222
using KernelTypes_RowCol = ::testing::Types<
23-
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
24-
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
25-
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
26-
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
27-
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>
23+
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
24+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
25+
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
26+
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
27+
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>,
28+
29+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>,
30+
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>
2831
>;
2932
// clang-format on
3033

test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@ using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
2020

2121
// clang-format off
2222
using KernelTypes_Tensor = ::testing::Types<
23-
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
24-
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
25-
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
26-
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
27-
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>
23+
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC
24+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
25+
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
26+
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
27+
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>,
28+
29+
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>,
30+
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>
2831
>;
2932
// clang-format on
3033

0 commit comments

Comments
 (0)