diff --git a/requirements.txt b/requirements.txt index 4b91af08827..df12ac4266a 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On +/home/mhalilce/composable_kernel --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@ac8d49a80e8cb7c4f0c11417a458a0de7c7d02b1 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 82cc1fb0a3c..b5a8e37b793 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -57,7 +57,11 @@ else() endif() if(MIGRAPHX_USE_COMPOSABLEKERNEL) - find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library) + find_package(composable_kernel_host 1.0.0 REQUIRED) + if(NOT TARGET composable_kernel::ck_host) + # Manually including targets + include(${composable_kernel_host_TARGET_FILE}) + endif() endif() if(BUILD_DEV) @@ -122,7 +126,7 @@ target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_WAVEFRONTSIZE=64 target_include_directories(kernel_file_check PRIVATE $) target_link_libraries(kernel_file_check compile_for_gpu) if(MIGRAPHX_USE_COMPOSABLEKERNEL) - target_link_libraries(kernel_file_check composable_kernel::jit_library) + target_link_libraries(kernel_file_check composable_kernel::ck_host) endif() rocm_clang_tidy_check(kernel_file_check) @@ -392,7 +396,7 @@ else() endif() target_link_libraries(migraphx_gpu PRIVATE migraphx_kernels) if(MIGRAPHX_USE_COMPOSABLEKERNEL) - target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library) + target_link_libraries(migraphx_gpu PRIVATE composable_kernel::ck_host) target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1) endif() diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index bf9a269f3e1..c60ca163a45 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -90,11 +90,11 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) return false; if(not ck_gemm::is_ck_supported_type(ins->get_shape().type())) return false; - auto a = ins->inputs().front()->get_shape(); - auto b = ins->inputs().back()->get_shape(); - auto m = a.lens()[a.lens().size() - 2]; - auto n = b.lens().back(); - auto k = a.lens().back(); + auto a = ins->inputs().front()->get_shape(); + auto b = ins->inputs().back()->get_shape(); + auto m = a.lens()[a.lens().size() - 2]; + auto n = b.lens().back(); + auto k = a.lens().back(); auto batch_size = std::accumulate( a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies()); // Integer gemms must be divisible by 4 in ck @@ -118,7 +118,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) } return true; } - return k <= 2048; + return k <= 1024; } struct find_ck_gemm_pointwise @@ -207,7 +207,8 @@ struct find_ck_gemm_softmax_gemm void fuse_ck::apply(module_pass_manager& mpm) const { - match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{}); + match::find_matches(mpm, find_ck_gemm_softmax_gemm{}); + match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm{}); } diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp index 18d4dce25a2..ea41b252547 100644 --- a/src/targets/gpu/include/migraphx/gpu/ck.hpp +++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp @@ -30,8 +30,9 @@ #include #include -#include "ck/host/device_gemm_multiple_d.hpp" -#include "ck/host/device_batched_gemm_softmax_gemm.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp index 392eaa0c67b..23376d6c666 100644 --- a/src/targets/gpu/jit/ck_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm.cpp @@ -37,6 +37,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -136,15 +137,19 @@ struct ck_gemm_compiler : compiler { const auto& c_shape = inputs.back(); auto tuning_value = v.get("tuning_value", 34); - auto batch_count = get_batch_count(c_shape); - auto problem = create_problem(inputs, v); + auto batch_count = get_batch_count(c_shape); + auto problem = create_problem(inputs, v); - const auto include_header = problem.GetIncludeHeader(); - const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); + const auto include_header = problem.GetIncludeHeader(); + const auto solutions = + problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); const auto& solution = solutions.at(tuning_value); - const auto template_str = solution.template_str; - const auto blocks_per_batch = solution.grid_size; - const auto block_size = solution.block_size; + const auto template_str = solution.ToTemplateString(); + const auto block_size = solution.GetTemplateParameter("BlockSize"); + const auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + const auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) * + ck::host::integer_divide_ceil(problem.N, n_per_block); hip_compile_options options; options.additional_src_files = ck_headers(); @@ -221,7 +226,7 @@ struct ck_gemm_compiler : compiler tuning_config tc; auto shapes = to_shapes(ins->inputs()); auto problem = create_problem(shapes, create_settings(ins, op)); - auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); + auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); tc.solutions.resize(solutions.size()); std::iota(tc.solutions.begin(), tc.solutions.end(), 0); std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()}; diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index 693153d0982..ad40d84161d 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -37,6 +37,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -134,12 +135,16 @@ struct ck_gemm_softmax_gemm_compiler : compiler auto batch_count = get_batch_count(c_shape); auto problem = create_problem(inputs, v); - const auto include_header = problem.GetIncludeHeader(); - const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); + const auto include_header = problem.GetIncludeHeader(); + const auto solutions = + problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); const auto& solution = solutions.at(tuning_value); - const auto template_str = solution.template_str; - const auto blocks_per_batch = solution.grid_size; - const auto block_size = solution.block_size; + const auto template_str = solution.ToTemplateString(); + const auto block_size = solution.GetTemplateParameter("BlockSize"); + const auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock"); + const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock"); + const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) * + ck::host::integer_divide_ceil(problem.O, n1_per_block); hip_compile_options options; options.additional_src_files = ck_headers(); @@ -222,7 +227,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler tuning_config tc; auto shapes = to_shapes(ins->inputs()); auto problem = create_problem(shapes, create_settings(ins, op)); - auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); + auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); tc.solutions.resize(solutions.size()); std::iota(tc.solutions.begin(), tc.solutions.end(), 0); std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()};