diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 3c514c57ad4b..084ba36e6721 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -165,13 +165,9 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
       problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
       transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
       /*mustBeAligned*/ mustBeAligned, doCPromotion);
-  if (!schedule) {
-    // Then try again by allowing upcasting accumulator.
-    schedule = deduceMMASchedule(
-        problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
-        transposedLhs, transposedRhs, /*canUpcastAcc=*/true,
-        /*mustBeAligned*/ mustBeAligned, doCPromotion);
-  }
+  // TODO (nirvedhmeshram) : Add support for upcasting accumulator schedule.
+  // Currently we dont have this for TileAndFuse path, see
+  // https://github.com/iree-org/iree/issues/19532
   return schedule;
 }
@@ -393,9 +389,16 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
   std::array<int64_t, 3> workgroupSize = {configAndWgSize->second, 1, 1};
   LoweringConfigAttr loweringConfig = configAndWgSize->first;
+  bool usePrefetchSharedMemory = true;
+  // Prefetching has issues when doing c promotion, see
+  // https://github.com/iree-org/iree/issues/19612.
+  if (llvm::any_of(getPromotedOperandList(loweringConfig).value(),
+                   [](int64_t promote) { return promote == 2; })) {
+    usePrefetchSharedMemory = false;
+  }
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
-      linalgOp->getContext(), /*prefetchSharedMemory=*/true,
+      linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory,
       /*no_reduce_shared_memory_bank_conflicts=*/false,
       /*use_igemm_convolution=*/true,
       /*reorder_workgroups_strategy=*/std::nullopt);
@@ -436,9 +439,16 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   std::array<int64_t, 3> workgroupSize = {configAndWgSize->second, 1, 1};
   LoweringConfigAttr loweringConfig = configAndWgSize->first;
+  bool usePrefetchSharedMemory = true;
+  // Prefetching has issues when doing c promotion, see
+  // https://github.com/iree-org/iree/issues/19612.
+  if (llvm::any_of(getPromotedOperandList(loweringConfig).value(),
+                   [](int64_t promote) { return promote == 2; })) {
+    usePrefetchSharedMemory = false;
+  }
   SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
-      linalgOp->getContext(), /*prefetchSharedMemory=*/true,
+      linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory,
       /*no_reduce_shared_memory_bank_conflicts=*/false,
       /*use_igemm_convolution=*/false,
       /*reorder_workgroups_strategy=*/std::nullopt);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index e09a72b52df5..2ba637f9889c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -48,10 +48,10 @@
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 namespace mlir::iree_compiler {
-llvm::cl::opt<bool> clGPUTestTileAndFuseMatmul(
-    "iree-codegen-llvmgpu-test-tile-and-fuse-matmul",
+llvm::cl::opt<bool> clGPUEnableTileAndFuseMatmul(
+    "iree-codegen-llvmgpu-enable-tile-and-fuse-matmul",
     llvm::cl::desc("test the the tile and fuse pipeline for matmul"),
-    llvm::cl::init(false));
+    llvm::cl::init(true));
 llvm::cl::opt<bool> clGPUTestTileAndFuseVectorize(
     "iree-codegen-llvmgpu-test-tile-and-fuse-vectorize",
@@ -2352,7 +2352,7 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
       LDBG("Tile and fuse data tiled multi_mma config");
       return success();
     }
-    if (clGPUTestTileAndFuseMatmul) {
+    if (clGPUEnableTileAndFuseMatmul) {
       if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
                                                        computeOp))) {
         LDBG("Tile and fuse matmul config");
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
index cf170ef7d930..d8af22e58664 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
@@ -76,7 +76,7 @@ func.func @nhwc_conv_unaligned_mfma() {
 // CHECK-LABEL: func.func @nhwc_conv_unaligned_mfma
 // CHECK-SAME: #iree_codegen.translation_info
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
-func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf16> {
+func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>) -> tensor<2x10x64x64xf32> {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
-  %5 = tensor.empty() : tensor<2x10x64x64xf16>
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<2x10x64x64xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x10x64x64xf32>) -> tensor<2x10x64x64xf32>
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
-    ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor<2x10x64x64xf16>
-  return %7 : tensor<2x10x64x64xf16>
+    ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor<2x10x64x64xf32>
+  return %7 : tensor<2x10x64x64xf32>
 }
 // CHECK-LABEL: func.func @expanded_matmul_transpose_b
@@ -46,21 +48,23 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
 #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> {
+func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf32> {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
-  %5 = tensor.empty() : tensor<10x4x32x32xf16>
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16>
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<10x4x32x32xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<10x4x32x32xf32>) -> tensor<10x4x32x32xf32>
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-    ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor<10x4x32x32xf16>
-  return %7 : tensor<10x4x32x32xf16>
+    ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor<10x4x32x32xf32>
+  return %7 : tensor<10x4x32x32xf32>
 }
 // CHECK-LABEL: func.func @multi_dim_mma_schedule
@@ -79,23 +83,25 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor, %rhs: tensor) -> tensor {
+func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor, %rhs: tensor) -> tensor {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f16
+  %cst = arith.constant 0.000000e+00 : f32
   %d0 = tensor.dim %lhs, %c0 : tensor
   %d2 = tensor.dim %rhs, %c0 : tensor
-  %5 = tensor.empty(%d0, %d2) : tensor
-  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor) -> tensor
+  %5 = tensor.empty(%d0, %d2) : tensor
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor) -> tensor
   %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
-    ins(%lhs, %rhs : tensor, tensor) outs(%6 : tensor) {
-  ^bb0(%in: f16, %in_0: f16, %out: f16):
-    %8 = arith.mulf %in, %in_0 : f16
-    %9 = arith.addf %8, %out : f16
-    linalg.yield %9 : f16
-  } -> tensor
-  return %7 : tensor
+    ins(%lhs, %rhs : tensor, tensor) outs(%6 : tensor) {
+  ^bb0(%in: f16, %in_0: f16, %out: f32):
+    %8 = arith.extf %in : f16 to f32
+    %9 = arith.extf %in_0 : f16 to f32
+    %10 = arith.mulf %8, %9 : f32
+    %11 = arith.addf %10, %out : f32
+    linalg.yield %11 : f32
+  } -> tensor
+  return %7 : tensor
 }
 // CHECK-LABEL: func.func @dynamic_multi_dim_mma_schedule
@@ -271,7 +277,7 @@ func.func @unaligned_to_intrinsic_batched_matmul(%lhs : tensor<12x577x577xf32>,
 // CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul
 // CHECK-SAME: #iree_codegen.translation_info}
+// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options}
 // CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: padding = [1, 16, 16, 4]
 // CHECK-SAME: promote_operands = [0, 1, 2]
@@ -300,7 +306,7 @@ func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x5
 // CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul_tiling_check
 // CHECK-SAME: #iree_codegen.translation_info}
+// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options}
 // CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 // CHECK-SAME: padding = [1, 16, 512, 4]
 // CHECK-SAME: promote_operands = [0, 1, 2]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir
index 3198f1592bdd..e42bdc266742 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir
@@ -1,4 +1,5 @@
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=WMMA
 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir
index 373b67b04e8f..f43881200de0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir
@@ -1,5 +1,6 @@
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --iree-codegen-llvmgpu-use-vector-distribution \
 // RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution --iree-codegen-llvmgpu-use-igemm=false \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
index 62ccec73c67a..bea2f2abe738 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir
@@ -33,14 +33,14 @@ func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>,
   return %1 : tensor<384x128xf32>
 }
 // CHECK: #[[CONFIG:.+]] = #iree_codegen.lowering_config
-// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info, promote_operands = [0, 1], reduction = [0, 0, 32], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 64, 0]}>
+// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 8], subgroup = [2, 2, 0], workgroup = [64, 64, 0]}>
 // CHECK: iree_linalg_ext.yield

// -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
index 642c6ed1a179..cd93a7b0268b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir
@@ -1,7 +1,7 @@
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \
-// RUN: --iree-gpu-test-target=sm_60 %s | FileCheck %s
+// RUN: --iree-gpu-test-target=sm_60 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \
-// RUN: --iree-gpu-test-target=sm_80 %s | FileCheck %s --check-prefix=SM80
+// RUN: --iree-gpu-test-target=sm_80 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s --check-prefix=SM80
 // Transform dialect attributes are tested separately.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
index ef9e587da95b..ae250684dc07 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
@@ -1,4 +1,6 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s
 // This test checks that the lowering of nvvm includes the extraction
 // and optimization of address computations.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
index c0cd53377863..2065390cd199 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir
@@ -1,4 +1,7 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s
 // Verify that a simple element wise op gets lowered succefully all the way to
 // nvvm/llvm dialect via mma.sync path.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
index ad6aad32420c..b210a806ae3a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir
@@ -1,5 +1,11 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \
+// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \
+// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80
 // Verify that a simple element wise op gets lowered succefully all the way to
 // nvvm/llvm dialect.