[GPU] Enable tile and fuse matmul by default
Signed-off-by: Nirvedh <nirvedh@gmail.com>
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
nirvedhmeshram committed Jan 10, 2025
1 parent 106371d commit b5ed37c
Showing 11 changed files with 88 additions and 59 deletions.
@@ -165,13 +165,9 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
/*mustBeAligned*/ mustBeAligned, doCPromotion);
-if (!schedule) {
-// Then try again by allowing upcasting accumulator.
-schedule = deduceMMASchedule(
-problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
-transposedLhs, transposedRhs, /*canUpcastAcc=*/true,
-/*mustBeAligned*/ mustBeAligned, doCPromotion);
-}
+// TODO (nirvedhmeshram) : Add support for upcasting accumulator schedule.
+// Currently we dont have this for TileAndFuse path, see
+// https://github.com/iree-org/iree/issues/19532
return schedule;
}

@@ -393,9 +389,16 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
std::array<int64_t, 3> workgroupSize = {configAndWgSize->second, 1, 1};
LoweringConfigAttr loweringConfig = configAndWgSize->first;

+bool usePrefetchSharedMemory = true;
+// Prefetching has issues when doing c promotion, see
+// https://github.com/iree-org/iree/issues/19612.
+if (llvm::any_of(getPromotedOperandList(loweringConfig).value(),
+[](int64_t promote) { return promote == 2; })) {
+usePrefetchSharedMemory = false;
+}
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
-linalgOp->getContext(), /*prefetchSharedMemory=*/true,
+linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory,
/*no_reduce_shared_memory_bank_conflicts=*/false,
/*use_igemm_convolution=*/true,
/*reorder_workgroups_strategy=*/std::nullopt);
@@ -436,9 +439,16 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
std::array<int64_t, 3> workgroupSize = {configAndWgSize->second, 1, 1};
LoweringConfigAttr loweringConfig = configAndWgSize->first;

+bool usePrefetchSharedMemory = true;
+// Prefetching has issues when doing c promotion, see
+// https://github.com/iree-org/iree/issues/19612.
+if (llvm::any_of(getPromotedOperandList(loweringConfig).value(),
+[](int64_t promote) { return promote == 2; })) {
+usePrefetchSharedMemory = false;
+}
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
-linalgOp->getContext(), /*prefetchSharedMemory=*/true,
+linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory,
/*no_reduce_shared_memory_bank_conflicts=*/false,
/*use_igemm_convolution=*/false,
/*reorder_workgroups_strategy=*/std::nullopt);
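Both setIGEMMConvolutionLoweringConfig and setMatmulLoweringConfig above now gate prefetching on whether operand 2 (the C/accumulator operand) appears in the lowering config's promoted-operand list. A minimal standalone sketch of that decision logic, using std::any_of and a plain vector as stand-ins for llvm::any_of and getPromotedOperandList(loweringConfig).value() (both substitutions are assumptions for illustration, not IREE's API):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the promoted-operand list attached to a lowering
// config: indices of operands promoted to shared memory (0 = A, 1 = B, 2 = C).
static bool shouldPrefetchSharedMemory(const std::vector<int64_t> &promotedOperands) {
  // Mirrors the llvm::any_of check in this commit: keep prefetching enabled
  // unless the accumulator (operand 2) is promoted, which currently conflicts
  // with prefetching (https://github.com/iree-org/iree/issues/19612).
  bool cPromoted = std::any_of(promotedOperands.begin(), promotedOperands.end(),
                               [](int64_t promote) { return promote == 2; });
  return !cPromoted;
}

int main() {
  std::cout << shouldPrefetchSharedMemory({0, 1}) << "\n";    // 1: prefetch stays on
  std::cout << shouldPrefetchSharedMemory({0, 1, 2}) << "\n"; // 0: prefetch disabled
  return 0;
}

This is also why the unaligned conv and batched-matmul CHECK lines later in this commit, whose configs use promote_operands = [0, 1, 2], now expect prefetch_shared_memory = false.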
8 changes: 4 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -48,10 +48,10 @@
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace mlir::iree_compiler {

-llvm::cl::opt<bool> clGPUTestTileAndFuseMatmul(
-"iree-codegen-llvmgpu-test-tile-and-fuse-matmul",
+llvm::cl::opt<bool> clGPUEnableTileAndFuseMatmul(
+"iree-codegen-llvmgpu-enable-tile-and-fuse-matmul",
llvm::cl::desc("test the the tile and fuse pipeline for matmul"),
-llvm::cl::init(false));
+llvm::cl::init(true));

llvm::cl::opt<bool> clGPUTestTileAndFuseVectorize(
"iree-codegen-llvmgpu-test-tile-and-fuse-vectorize",
@@ -2352,7 +2352,7 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
LDBG("Tile and fuse data tiled multi_mma config");
return success();
}
-if (clGPUTestTileAndFuseMatmul) {
+if (clGPUEnableTileAndFuseMatmul) {
if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
computeOp))) {
LDBG("Tile and fuse matmul config");
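Because the renamed flag now defaults to true, the tile-and-fuse matmul path is taken without any extra flags and the legacy behavior has to be requested explicitly. A sketch of an opt-out invocation, modeled on the updated RUN lines below (the input file name is a placeholder):

iree-opt --iree-gpu-test-target=gfx942 \
  --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
  --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" \
  input.mlir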
@@ -76,7 +76,7 @@ func.func @nhwc_conv_unaligned_mfma() {

// CHECK-LABEL: func.func @nhwc_conv_unaligned_mfma
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
-// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
+// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false
// CHECK-SAME: use_igemm_convolution = true

// CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config
@@ -106,7 +106,7 @@ func.func @nchw_conv_unaligned_mfma() {

// CHECK-LABEL: func.func @nchw_conv_unaligned_mfma
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
-// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
+// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false
// CHECK-SAME: use_igemm_convolution = true

// CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config
@@ -1,5 +1,5 @@
// RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx942 \
-// RUN: --iree-codegen-llvmgpu-test-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
+// RUN: --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
// RUN: --iree-codegen-llvmgpu-use-igemm=false \
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s

@@ -10,21 +10,23 @@
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
-func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf16> {
+func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf32> {
%c0 = arith.constant 0 : index
-%cst = arith.constant 0.000000e+00 : f16
-%5 = tensor.empty() : tensor<2x10x64x64xf16>
-%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+%cst = arith.constant 0.000000e+00 : f32
+%5 = tensor.empty() : tensor<2x10x64x64xf32>
+%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x10x64x64xf32>) -> tensor<2x10x64x64xf32>
%7 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
-ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
-^bb0(%in: f16, %in_0: f16, %out: f16):
-%8 = arith.mulf %in, %in_0 : f16
-%9 = arith.addf %8, %out : f16
-linalg.yield %9 : f16
-} -> tensor<2x10x64x64xf16>
-return %7 : tensor<2x10x64x64xf16>
+ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf32>) {
+^bb0(%in: f16, %in_0: f16, %out: f32):
+%8 = arith.extf %in : f16 to f32
+%9 = arith.extf %in_0 : f16 to f32
+%10 = arith.mulf %8, %9 : f32
+%11 = arith.addf %10, %out : f32
+linalg.yield %11 : f32
+} -> tensor<2x10x64x64xf32>
+return %7 : tensor<2x10x64x64xf32>
}

// CHECK-LABEL: func.func @expanded_matmul_transpose_b
@@ -46,21 +48,23 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> {
+func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf32> {
%c0 = arith.constant 0 : index
-%cst = arith.constant 0.000000e+00 : f16
-%5 = tensor.empty() : tensor<10x4x32x32xf16>
-%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16>
+%cst = arith.constant 0.000000e+00 : f32
+%5 = tensor.empty() : tensor<10x4x32x32xf32>
+%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<10x4x32x32xf32>) -> tensor<10x4x32x32xf32>
%7 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) {
-^bb0(%in: f16, %in_0: f16, %out: f16):
-%8 = arith.mulf %in, %in_0 : f16
-%9 = arith.addf %8, %out : f16
-linalg.yield %9 : f16
-} -> tensor<10x4x32x32xf16>
-return %7 : tensor<10x4x32x32xf16>
+ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf32>) {
+^bb0(%in: f16, %in_0: f16, %out: f32):
+%8 = arith.extf %in : f16 to f32
+%9 = arith.extf %in_0 : f16 to f32
+%10 = arith.mulf %8, %9 : f32
+%11 = arith.addf %10, %out : f32
+linalg.yield %11 : f32
+} -> tensor<10x4x32x32xf32>
+return %7 : tensor<10x4x32x32xf32>
}

// CHECK-LABEL: func.func @multi_dim_mma_schedule
@@ -79,23 +83,25 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor<?x6x16x?x16xf16>, %rhs: tensor<?x32x?x16xf16>) -> tensor<?x6x?x16x32xf16> {
+func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor<?x6x16x?x16xf16>, %rhs: tensor<?x32x?x16xf16>) -> tensor<?x6x?x16x32xf32> {
%c0 = arith.constant 0 : index
-%cst = arith.constant 0.000000e+00 : f16
+%cst = arith.constant 0.000000e+00 : f32
%d0 = tensor.dim %lhs, %c0 : tensor<?x6x16x?x16xf16>
%d2 = tensor.dim %rhs, %c0 : tensor<?x32x?x16xf16>
-%5 = tensor.empty(%d0, %d2) : tensor<?x6x?x16x32xf16>
-%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<?x6x?x16x32xf16>) -> tensor<?x6x?x16x32xf16>
+%5 = tensor.empty(%d0, %d2) : tensor<?x6x?x16x32xf32>
+%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<?x6x?x16x32xf32>) -> tensor<?x6x?x16x32xf32>
%7 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-ins(%lhs, %rhs : tensor<?x6x16x?x16xf16>, tensor<?x32x?x16xf16>) outs(%6 : tensor<?x6x?x16x32xf16>) {
-^bb0(%in: f16, %in_0: f16, %out: f16):
-%8 = arith.mulf %in, %in_0 : f16
-%9 = arith.addf %8, %out : f16
-linalg.yield %9 : f16
-} -> tensor<?x6x?x16x32xf16>
-return %7 : tensor<?x6x?x16x32xf16>
+ins(%lhs, %rhs : tensor<?x6x16x?x16xf16>, tensor<?x32x?x16xf16>) outs(%6 : tensor<?x6x?x16x32xf32>) {
+^bb0(%in: f16, %in_0: f16, %out: f32):
+%8 = arith.extf %in : f16 to f32
+%9 = arith.extf %in_0 : f16 to f32
+%10 = arith.mulf %8, %9 : f32
+%11 = arith.addf %10, %out : f32
+linalg.yield %11 : f32
+} -> tensor<?x6x?x16x32xf32>
+return %7 : tensor<?x6x?x16x32xf32>
}

// CHECK-LABEL: func.func @dynamic_multi_dim_mma_schedule
@@ -271,7 +277,7 @@ func.func @unaligned_to_intrinsic_batched_matmul(%lhs : tensor<12x577x577xf32>,

// CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
-// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
+// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
// CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: padding = [1, 16, 16, 4]
// CHECK-SAME: promote_operands = [0, 1, 2]
@@ -300,7 +306,7 @@ func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x5

// CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul_tiling_check
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
-// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
+// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
// CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: padding = [1, 16, 512, 4]
// CHECK-SAME: promote_operands = [0, 1, 2]
@@ -1,4 +1,5 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=WMMA

// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
@@ -1,5 +1,6 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --iree-codegen-llvmgpu-use-vector-distribution \
// RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution --iree-codegen-llvmgpu-use-igemm=false \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s

// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
@@ -33,14 +33,14 @@ func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
return %1 : tensor<384x128xf32>
}
// CHECK: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0]]>
-// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64,
+// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64,
// CHECK: func @custom_op
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: iree_linalg_ext.custom_op
// CHECK-SAME: lowering_config = #[[CONFIG]]
// CHECK: ^bb
// CHECK: linalg.matmul
-// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, promote_operands = [0, 1], reduction = [0, 0, 32], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 64, 0]}>
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, promote_operands = [0, 1], reduction = [0, 0, 8], subgroup = [2, 2, 0], workgroup = [64, 64, 0]}>
// CHECK: iree_linalg_ext.yield

// -----
@@ -1,7 +1,7 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \
-// RUN: --iree-gpu-test-target=sm_60 %s | FileCheck %s
+// RUN: --iree-gpu-test-target=sm_60 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \
-// RUN: --iree-gpu-test-target=sm_80 %s | FileCheck %s --check-prefix=SM80
+// RUN: --iree-gpu-test-target=sm_80 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s --check-prefix=SM80

// Transform dialect attributes are tested separately.

@@ -1,4 +1,6 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s

// This test checks that the lowering of nvvm includes the extraction
// and optimization of address computations.
@@ -1,4 +1,7 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s

// Verify that a simple element wise op gets lowered succefully all the way to
// nvvm/llvm dialect via mma.sync path.
@@ -1,5 +1,11 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80

// Verify that a simple element wise op gets lowered succefully all the way to
// nvvm/llvm dialect.
