Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addressing (Post Merge) code review comments for PR 1845 #1883

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,12 @@ def cmake_build(Map conf=[:]){
}
if (params.RUN_CK_TILE_GEMM_TESTS){
try{
archiveArtifacts "perf_tile_gemm_*.log"
archiveArtifacts "perf_tile_gemm_**.log"
if (arch_type == 1){
stash includes: "perf_tile_gemm_**_fp16_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
}
else if (arch_type == 2){
stash includes: "perf_tile_gemm_**_fp16_gfx942.log", name: "perf_tile_gemm_log_gfx942"
stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942"
}
}
catch(Exception err){
Expand Down Expand Up @@ -795,8 +795,8 @@ pipeline {
description: "Run the ck_tile FMHA tests (default: OFF)")
booleanParam(
name: "RUN_CK_TILE_GEMM_TESTS",
defaultValue: true,
description: "Run the ck_tile GEMM tests (default: ON)")
defaultValue: false,
description: "Run the ck_tile GEMM tests (default: OFF)")
booleanParam(
name: "BUILD_INSTANCES_ONLY",
defaultValue: false,
Expand Down
56 changes: 31 additions & 25 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,45 +99,51 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&

#include "run_gemm_example.inc"

template <typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<PrecType>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}

int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");

if(a_layout == "R" && b_layout == "C")
if(data_type == "fp16")
{
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
return run_gemm_example_prec_type<ck_tile::fp8_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<ck_tile::bf8_t>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
throw std::runtime_error("Unsupported Precision Type for the input arguments !");
}
}

Expand Down
10 changes: 5 additions & 5 deletions example/ck_tile/03_gemm/gemm_basic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
#endif

template <typename DataType>
struct GemmBasicTypeConfig;
struct GemmTypeConfig;

template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
struct GemmTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
Expand All @@ -49,7 +49,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
};

template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
struct GemmTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
Expand All @@ -58,7 +58,7 @@ struct GemmBasicTypeConfig<ck_tile::bf16_t>
};

template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
struct GemmTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
Expand All @@ -67,7 +67,7 @@ struct GemmBasicTypeConfig<ck_tile::fp8_t>
};

template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
struct GemmTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
Expand Down
8 changes: 4 additions & 4 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;

using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using ADataType = typename GemmTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmTypeConfig<PrecType>::AccDataType;

ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
Expand Down
Empty file modified example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
100644 → 100755
Empty file.
Empty file modified example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
100644 → 100755
Empty file.
Empty file modified example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
100644 → 100755
Empty file.
Empty file.
Empty file modified example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
100644 → 100755
Empty file.
Empty file modified example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
100644 → 100755
Empty file.
13 changes: 5 additions & 8 deletions example/ck_tile/03_gemm/script/run_full_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,11 @@ function print_log_header(){
}

# run verification tests
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh

# run performance benchmarks
export gemm_basic_log="perf_tile_gemm_basic_fp16_$GPU_arch.log"
print_log_header $gemm_basic_log $env_type $branch $host_name
example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 | tee -a $gemm_basic_log

export gemm_mem_pipeline_log="perf_tile_gemm_mem_pipeline_fp16_$GPU_arch.log"
print_log_header $gemm_mem_pipeline_log $env_type $branch $host_name
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 | tee -a $gemm_mem_pipeline_log
for dtype in fp16 bf16 fp8 bf8; do
export gemm_log="perf_tile_gemm_mem_pipeline_${dtype}_${GPU_arch}.log"
print_log_header $gemm_log $env_type $branch $host_name
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_$dtype.sh 2>&1 | tee -a $gemm_log
done
125 changes: 37 additions & 88 deletions example/ck_tile/03_gemm/universal_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,114 +267,63 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&

#include "run_gemm_example.inc"

template <typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<PrecType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<PrecType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts<PrecType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<PrecType>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices !!!");
}
}

int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");

if(a_layout == "R" && b_layout == "R")
if(data_type == "fp16")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(a_layout == "R" && b_layout == "C")
else if(data_type == "bf16")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
}
else if(a_layout == "C" && b_layout == "C")
else if(data_type == "fp8")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
return run_gemm_example_prec_type<ck_tile::fp8_t>(a_layout, b_layout, argc, argv);
}
else if(a_layout == "C" && b_layout == "R")
else if(data_type == "bf8")
{
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
return run_gemm_example_prec_type<ck_tile::bf8_t>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}

Expand Down
3 changes: 1 addition & 2 deletions include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ struct CShuffleEpilogue
*
* @return The vector store size for C tensor.
*/
template <typename ODataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
Expand Down Expand Up @@ -143,7 +142,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC<ODataType>(),
GetVectorSizeC(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();

Expand Down
10 changes: 5 additions & 5 deletions include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ struct GemmKernel

CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
Expand Down Expand Up @@ -248,7 +248,7 @@ struct GemmKernel
<< std::endl;
return false;
}
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false;
Expand All @@ -263,7 +263,7 @@ struct GemmKernel
<< std::endl;
return false;
}
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false;
Expand Down Expand Up @@ -329,7 +329,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
Expand Down Expand Up @@ -597,7 +597,7 @@ struct GemmKernel
{
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
Expand Down
Loading