Skip to content

Conversation

@arai713
Copy link
Contributor

@arai713 arai713 commented Oct 17, 2025

Proposed changes

This PR is expanding the supported datatypes for Stream-K Gemm by adding examples for fp8 and bf8. Previously examples for fp16 and bf 16 were added through this PR. Currently these examples only support atomic reduction. Unit tests for fp8 and bf8 have been added through a test suite.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

Copy link
Collaborator

@ecamartins ecamartins Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to overwrite the existing types, or just add a new type for BF8 and FP8? (Right now they are overwritten)

Comment on lines +23 to +26
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F8, F8, F32, F16>,
std::tuple< Row, Col, Row, BF8, BF8, F32, F16>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming we don't want to overwrite the existing smoke tests, I think we can remove FP16 and BF16 from these types. Also, can we please add RRR, CRR, and CCR for FP8 and BF8?

using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

// clang-format off
using KernelTypesStreamK = ::testing::Types<
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a duplicate of KernelTypesStreamK in test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp.

const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, /*kbatch*/ 1, max_accumulated_value);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may need to use estimate_num_wgs_per_tile here because for some test cases, we have >1 workgroup atomically adding to the same C macro tile. This could lead to round-off error and result in failing tests in a sporadic nature.

@@ -0,0 +1,282 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a test harness already in test_gemm_streamk.hpp, I don't think we need to duplicate it here

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are types needed for the rest of the smoke tests


include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})

add_gtest_executable(test_ck_tile_streamk test_gemm_streamk_fp8_bf8.cpp)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should look at adding the fp8 tests to the smoke test

Copy link
Collaborator

@cgmillette cgmillette left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to consolidate new fp8 tests into smoke tests with existing test harness.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants