diff --git a/requirements.txt b/requirements.txt
index 4b91af08827..df12ac4266a 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0
pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
+ROCm/composable_kernel@develop --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@ac8d49a80e8cb7c4f0c11417a458a0de7c7d02b1 -DBUILD_FAT_LIBROCKCOMPILER=On
diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt
index 82cc1fb0a3c..b5a8e37b793 100644
--- a/src/targets/gpu/CMakeLists.txt
+++ b/src/targets/gpu/CMakeLists.txt
@@ -57,7 +57,11 @@ else()
endif()
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
+ find_package(composable_kernel_host 1.0.0 REQUIRED)
+ if(NOT TARGET composable_kernel::ck_host)
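+ # find_package() did not define composable_kernel::ck_host, so load the exported targets file (composable_kernel_host_TARGET_FILE) directly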
+ include(${composable_kernel_host_TARGET_FILE})
+ endif()
endif()
if(BUILD_DEV)
@@ -122,7 +126,7 @@ target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_WAVEFRONTSIZE=64
target_include_directories(kernel_file_check PRIVATE $)
target_link_libraries(kernel_file_check compile_for_gpu)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- target_link_libraries(kernel_file_check composable_kernel::jit_library)
+ target_link_libraries(kernel_file_check composable_kernel::ck_host)
endif()
rocm_clang_tidy_check(kernel_file_check)
@@ -392,7 +396,7 @@ else()
endif()
target_link_libraries(migraphx_gpu PRIVATE migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
+ target_link_libraries(migraphx_gpu PRIVATE composable_kernel::ck_host)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif()
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index bf9a269f3e1..c60ca163a45 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -90,11 +90,11 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
- auto a = ins->inputs().front()->get_shape();
- auto b = ins->inputs().back()->get_shape();
- auto m = a.lens()[a.lens().size() - 2];
- auto n = b.lens().back();
- auto k = a.lens().back();
+ auto a = ins->inputs().front()->get_shape();
+ auto b = ins->inputs().back()->get_shape();
+ auto m = a.lens()[a.lens().size() - 2];
+ auto n = b.lens().back();
+ auto k = a.lens().back();
auto batch_size = std::accumulate(
a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies());
// Integer gemms must be divisible by 4 in ck
@@ -118,7 +118,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
}
return true;
}
- return k <= 2048;
+ return k <= 1024;
}
struct find_ck_gemm_pointwise
@@ -207,7 +207,8 @@ struct find_ck_gemm_softmax_gemm
void fuse_ck::apply(module_pass_manager& mpm) const
{
- match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
+ match::find_matches(mpm, find_ck_gemm_softmax_gemm{});
+ match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp
index 18d4dce25a2..ea41b252547 100644
--- a/src/targets/gpu/include/migraphx/gpu/ck.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp
@@ -30,8 +30,9 @@
#include
#include
-#include "ck/host/device_gemm_multiple_d.hpp"
-#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
+#include "ck/host/headers.hpp"
+#include "ck/host/device_gemm_multiple_d/problem.hpp"
+#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp
index 392eaa0c67b..23376d6c666 100644
--- a/src/targets/gpu/jit/ck_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm.cpp
@@ -37,6 +37,7 @@
#include
#include
#include
+#include
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -136,15 +137,19 @@ struct ck_gemm_compiler : compiler
{
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 34);
- auto batch_count = get_batch_count(c_shape);
- auto problem = create_problem(inputs, v);
+ auto batch_count = get_batch_count(c_shape);
+ auto problem = create_problem(inputs, v);
- const auto include_header = problem.GetIncludeHeader();
- const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+ const auto include_header = problem.GetIncludeHeader();
+ const auto solutions =
+ problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
const auto& solution = solutions.at(tuning_value);
- const auto template_str = solution.template_str;
- const auto blocks_per_batch = solution.grid_size;
- const auto block_size = solution.block_size;
+ const auto template_str = solution.ToTemplateString();
+ const auto block_size = solution.GetTemplateParameter("BlockSize");
+ const auto m_per_block = solution.GetTemplateParameter("MPerBlock");
+ const auto n_per_block = solution.GetTemplateParameter("NPerBlock");
+ const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
+ ck::host::integer_divide_ceil(problem.N, n_per_block);
hip_compile_options options;
options.additional_src_files = ck_headers();
@@ -221,7 +226,7 @@ struct ck_gemm_compiler : compiler
tuning_config tc;
auto shapes = to_shapes(ins->inputs());
auto problem = create_problem(shapes, create_settings(ins, op));
- auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+ auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
tc.solutions.resize(solutions.size());
std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()};
diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
index 693153d0982..ad40d84161d 100644
--- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
@@ -37,6 +37,7 @@
#include
#include
#include
+#include
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -134,12 +135,16 @@ struct ck_gemm_softmax_gemm_compiler : compiler
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
- const auto include_header = problem.GetIncludeHeader();
- const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+ const auto include_header = problem.GetIncludeHeader();
+ const auto solutions =
+ problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
const auto& solution = solutions.at(tuning_value);
- const auto template_str = solution.template_str;
- const auto blocks_per_batch = solution.grid_size;
- const auto block_size = solution.block_size;
+ const auto template_str = solution.ToTemplateString();
+ const auto block_size = solution.GetTemplateParameter("BlockSize");
+ const auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock");
+ const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock");
+ const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
+ ck::host::integer_divide_ceil(problem.O, n1_per_block);
hip_compile_options options;
options.additional_src_files = ck_headers();
@@ -222,7 +227,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler
tuning_config tc;
auto shapes = to_shapes(ins->inputs());
auto problem = create_problem(shapes, create_settings(ins, op));
- auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+ auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
tc.solutions.resize(solutions.size());
std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()};