diff --git a/.github/workflows/ci-gpu.yaml b/.github/workflows/ci-gpu.yaml index f61779396..8189a4959 100644 --- a/.github/workflows/ci-gpu.yaml +++ b/.github/workflows/ci-gpu.yaml @@ -287,6 +287,10 @@ jobs: cmake --build . - name: Test water + # As we are mixing static LLVM build with shared Water we will get a + # usual double option registration LLVM issue. Only compile Water in + # this case to look for missing libraries. + if: ${{ matrix.shared_libs == 'OFF' }} run: | cd cmake_build cmake --build . --target check-water diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td index 8ef72c96a..c64609557 100644 --- a/water/include/water/Transforms/Passes.td +++ b/water/include/water/Transforms/Passes.td @@ -104,4 +104,13 @@ def WaterGreedySLPVectorizer : Pass<"water-greedy-slp-vectorizer"> { ]; } +def WaterGPUToGPURuntime : Pass<"water-gpu-to-gpu-runtime", "::mlir::ModuleOp"> { + let summary = "Lower GPU dialect ops to runtime calls"; + let description = [{ + This pass lowers operations from the GPU dialect to calls into GPU runtime + functions. + }]; + let dependentDialects = ["::mlir::LLVM::LLVMDialect"]; +} + #endif // WATER_PASSES diff --git a/water/lib/Transforms/CMakeLists.txt b/water/lib/Transforms/CMakeLists.txt index 08aa31a50..c441b192b 100644 --- a/water/lib/Transforms/CMakeLists.txt +++ b/water/lib/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRWaterTransforms AccessCheckers.cpp CheckStaticAssertions.cpp + GPUToGPURuntime.cpp SLPVectorizer.cpp ADDITIONAL_HEADER_DIRS @@ -14,7 +15,9 @@ add_mlir_dialect_library(MLIRWaterTransforms MLIRArithDialect MLIRControlFlowDialect MLIRFuncDialect + MLIRGPUDialect MLIRIR + MLIRLLVMDialect MLIRMemRefDialect MLIRPass MLIRRewrite diff --git a/water/lib/Transforms/GPUToGPURuntime.cpp b/water/lib/Transforms/GPUToGPURuntime.cpp new file mode 100644 index 000000000..9513c9a8f --- /dev/null +++ b/water/lib/Transforms/GPUToGPURuntime.cpp @@ -0,0 +1,264 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "water/Transforms/Passes.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::water { +#define GEN_PASS_DEF_WATERGPUTOGPURUNTIME +#include "water/Transforms/Passes.h.inc" +} // namespace mlir::water + +using namespace mlir; +using namespace mlir::water; + +namespace { +/// Generate a unique LLVM global name for a given source name. +static SmallString<128> getUniqueLLVMGlobalName(ModuleOp mod, + SymbolTable &table, + const llvm::Twine &srcName) { + unsigned counter = 0; + return SymbolTable::generateSymbolName<128>( + srcName.str(), + [&](StringRef candidate) { return table.lookupSymbolIn(mod, candidate); }, + counter); +} + +/// Helper to build a function call to a given function name with the given +/// return type and argument types. +struct FunctionCallBuilder { + // TODO: cannot use TypeRange as `LLVM::LLVMFunctionType::get` refuses to + // accept it. + FunctionCallBuilder(StringRef functionName, Type returnType, + ArrayRef argumentTypes) + : functionName(functionName), + functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {} + LLVM::CallOp create(Location loc, OpBuilder &builder, + ValueRange arguments) const { + Operation *module = builder.getBlock() + ->getParentOp() + ->getParentWithTrait(); + assert(module && "module not found"); + SymbolTable symbolTable(module); + auto function = [&] { + if (auto function = symbolTable.lookup(functionName)) + return function; + + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(&module->getRegion(0).front()); + return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType); + }(); + return LLVM::CallOp::create(builder, loc, function, arguments); + } + + StringRef functionName; + LLVM::LLVMFunctionType functionType; +}; + +/// Create a unique LLVM global for a kernel handle. +static Value createKernelHandle(OpBuilder &builder, SymbolTable &symbolTable, + Type globalType, ModuleOp mod, + const llvm::Twine &name) { + Type ptrType = LLVM::LLVMPointerType::get(builder.getContext()); + Location loc = builder.getUnknownLoc(); + LLVM::GlobalOp handle; + { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(mod.getBody()); + SmallString<128> handleName = + getUniqueLLVMGlobalName(mod, symbolTable, name); + handle = LLVM::GlobalOp::create( + builder, loc, globalType, /*isConstant*/ false, LLVM::Linkage::Internal, + handleName, LLVM::ZeroAttr::get(builder.getContext())); + } + return LLVM::AddressOfOp::create(builder, loc, ptrType, handle.getSymName()); +} + +/// Get the object from a gpu binary op. +static gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) { + ArrayRef objects = op.getObjectsAttr().getValue(); + + // Obtain the index of the object to select. + int64_t index = -1; + if (Attribute target = + cast(op.getOffloadingHandlerAttr()) + .getTarget()) { + // If the target attribute is a number it is the index. Otherwise compare + // the attribute to every target inside the object array to find the index. + if (auto indexAttr = dyn_cast(target)) { + index = indexAttr.getInt(); + } else { + for (auto &&[i, attr] : llvm::enumerate(objects)) { + auto obj = dyn_cast(attr); + if (obj && obj.getTarget() == target) { + index = i; + break; + } + } + } + } else { + // If the target attribute is null then it's selecting the first object in + // the object array. + index = 0; + } + + if (index < 0 || index >= static_cast(objects.size())) { + op->emitError("the requested target object couldn't be found"); + return nullptr; + } + auto result = dyn_cast(objects[index]); + if (!result) + op->emitError("invalid object type"); + + return result; +} + +/// Lookup the binary holding the kernel from enclosing module. +static gpu::ObjectAttr getBinary(gpu::LaunchFuncOp op) { + auto kernelBinary = SymbolTable::lookupNearestSymbolFrom( + op, op.getKernelModuleName()); + if (!kernelBinary) { + op.emitError("couldn't find the binary holding the kernel: " + + op.getKernelModuleName().getValue()); + return nullptr; + } + + return getSelectedObject(kernelBinary); +} + +struct WaterGPUToGPURuntimePass final + : public water::impl::WaterGPUToGPURuntimeBase { + using WaterGPUToGPURuntimeBase::WaterGPUToGPURuntimeBase; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *context = &getContext(); + IRRewriter builder(context); + + Type i32Type = builder.getI32Type(); + Type i64Type = builder.getI64Type(); + Type ptrType = LLVM::LLVMPointerType::get(context); + Type voidType = LLVM::LLVMVoidType::get(context); + FunctionCallBuilder loadFuncBuilder("wave_load_kernel", ptrType, + { + ptrType, // stream + ptrType, // cached kernel handle + ptrType, // binary pointer + i64Type, // binary size + ptrType // function name + }); + FunctionCallBuilder launchFuncBuilder("wave_launch_kernel", voidType, + { + ptrType, // stream + ptrType, // function + i32Type, // shared memory bytes + i64Type, // gridX + i64Type, // gridY + i64Type, // gridZ + i64Type, // blockX + i64Type, // blockY + i64Type, // blockZ + ptrType, // kernel operands + i32Type // kernel operands count + }); + + SymbolTable symbolTable(mod); + + auto visitor = [&](gpu::LaunchFuncOp op) -> WalkResult { + auto func = op->getParentOfType(); + if (!func) { + op.emitError("launch func op must have a func op parent"); + return WalkResult::interrupt(); + } + ValueRange blockArgs = func.getFunctionBody().front().getArguments(); + if (blockArgs.empty()) { + op.emitError("func op must have at least one argument"); + return WalkResult::interrupt(); + } + // First argument is stream pointer + Value stream = blockArgs.front(); + if (!isa(stream.getType())) { + op.emitError("stream argument must be a pointer"); + return WalkResult::interrupt(); + } + + gpu::ObjectAttr object = getBinary(op); + if (!object) + return WalkResult::interrupt(); + + StringRef objData = object.getObject(); + + Location loc = op.getLoc(); + auto getStr = [&](StringRef varName, StringRef str) -> Value { + Twine strVal = str + StringRef("\0", 1); + return LLVM::createGlobalString( + loc, builder, getUniqueLLVMGlobalName(mod, symbolTable, varName), + strVal.str(), LLVM::Linkage::Internal); + }; + + auto createConst = [&](Type type, int64_t val) -> Value { + return LLVM::ConstantOp::create(builder, loc, type, + builder.getIntegerAttr(type, val)); + }; + + auto createAlloca = [&](Type elemType, int64_t size) -> Value { + Value sizeVal = createConst(i64Type, size); + return LLVM::AllocaOp::create(builder, loc, ptrType, elemType, sizeVal, + 0); + }; + + builder.setInsertionPoint(op); + StringRef kernelName = op.getKernelName(); + Value kernelHandle = createKernelHandle(builder, symbolTable, ptrType, + mod, kernelName + "_handle"); + Value kernelNameStr = getStr(kernelName, kernelName); + + Value dataPtr = LLVM::createGlobalString( + loc, builder, + getUniqueLLVMGlobalName(mod, symbolTable, kernelName + "_data"), + objData, LLVM::Linkage::Internal); + Value dataSize = createConst(i64Type, objData.size()); + + Value funcObject = + loadFuncBuilder + .create(loc, builder, + {stream, kernelHandle, dataPtr, dataSize, kernelNameStr}) + ->getResult(0); + + Value sharedMemoryBytes = createConst(i32Type, 0); + ValueRange args = op.getKernelOperands(); + auto argsPtrArrayType = LLVM::LLVMArrayType::get(ptrType, args.size()); + Value argsArray = LLVM::PoisonOp::create(builder, loc, argsPtrArrayType); + + for (auto &&[i, arg] : llvm::enumerate(args)) { + Value argData = createAlloca(arg.getType(), 1); + LLVM::StoreOp::create(builder, loc, arg, argData); + argsArray = + LLVM::InsertValueOp::create(builder, loc, argsArray, argData, i); + } + Value argsArrayPtr = createAlloca(argsPtrArrayType, 1); + LLVM::StoreOp::create(builder, loc, argsArray, argsArrayPtr); + Value argsCount = createConst(i32Type, args.size()); + + launchFuncBuilder.create( + loc, builder, + {stream, funcObject, sharedMemoryBytes, op.getGridSizeX(), + op.getGridSizeY(), op.getGridSizeZ(), op.getBlockSizeX(), + op.getBlockSizeY(), op.getBlockSizeZ(), argsArrayPtr, argsCount}); + builder.eraseOp(op); + return WalkResult::advance(); + }; + if (mod.walk(visitor).wasInterrupted()) + return signalPassFailure(); + + mod->walk([&](gpu::BinaryOp op) { builder.eraseOp(op); }); + } +}; +} // namespace diff --git a/water/test/Transforms/gpu-to-gpu-runtime.mlir b/water/test/Transforms/gpu-to-gpu-runtime.mlir new file mode 100644 index 000000000..3699dad6a --- /dev/null +++ b/water/test/Transforms/gpu-to-gpu-runtime.mlir @@ -0,0 +1,207 @@ +// RUN: water-opt %s --water-gpu-to-gpu-runtime --split-input-file --verify-diagnostics | FileCheck %s + +module attributes {gpu.container_module} { + // CHECK: llvm.mlir.global internal constant @[[KERNEL_DATA:my_kernel_data[_0-9]*]] + // CHECK: llvm.mlir.global internal constant @[[KERNEL_NAME:my_kernel[_0-9]*]] + // CHECK: llvm.mlir.global internal @[[KERNEL_HANDLE:my_kernel_handle[_0-9]*]] + + gpu.binary @kernel_binary [ + #gpu.object<#rocdl.target, "\00\01\02\03"> + ] + + // CHECK-LABEL: llvm.func @test_launch + // CHECK-SAME: (%[[STREAM:.*]]: !llvm.ptr, %[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) + llvm.func @test_launch(%stream: !llvm.ptr, %arg0: f32, %arg1: i64) { + %c128 = arith.constant 128 : i64 + %c256 = arith.constant 256 : i64 + %c1 = arith.constant 1 : i64 + + // CHECK-DAG: %[[HANDLE_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL_HANDLE]] + // CHECK-DAG: %[[DATA_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL_DATA]] + // CHECK-DAG: %[[NAME_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL_NAME]] + + // CHECK-DAG: %[[DATA_ADDR_GEP:.*]] = llvm.getelementptr %[[DATA_ADDR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i8> + // CHECK-DAG: %[[NAME_ADDR_GEP:.*]] = llvm.getelementptr %[[NAME_ADDR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i8> + + // CHECK: %[[DATA_SIZE:.*]] = llvm.mlir.constant(4 : i64) : i64 + + // CHECK: %[[FUNC:.*]] = llvm.call @wave_load_kernel(%[[STREAM]], %[[HANDLE_ADDR]], %[[DATA_ADDR_GEP]], %[[DATA_SIZE]], %[[NAME_ADDR_GEP]]) + // CHECK-SAME: : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> !llvm.ptr + + // CHECK-DAG: %[[SHARED_MEM:.*]] = llvm.mlir.constant(0 : i32) : i32 + + // CHECK-DAG: %[[ARG0_ALLOCA:.*]] = llvm.alloca %{{.*}} x f32 + // CHECK-DAG: llvm.store %[[ARG0]], %[[ARG0_ALLOCA]] + + // CHECK-DAG: %[[ARG1_ALLOCA:.*]] = llvm.alloca %{{.*}} x i64 + // CHECK-DAG: llvm.store %[[ARG1]], %[[ARG1_ALLOCA]] + + // CHECK-DAG: %[[ARGS_ARRAY:.*]] = llvm.mlir.poison : !llvm.array<2 x ptr> + // CHECK-DAG: %[[ARGS_ARRAY_1:.*]] = llvm.insertvalue %[[ARG0_ALLOCA]], %[[ARGS_ARRAY]][0] + // CHECK-DAG: %[[ARGS_ARRAY_2:.*]] = llvm.insertvalue %[[ARG1_ALLOCA]], %[[ARGS_ARRAY_1]][1] + + // CHECK: %[[ARGS_PTR:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x ptr> + // CHECK: llvm.store %[[ARGS_ARRAY_2]], %[[ARGS_PTR]] + + // CHECK: %[[ARGS_COUNT:.*]] = llvm.mlir.constant(2 : i32) : i32 + + // CHECK: llvm.call @wave_launch_kernel(%[[STREAM]], %[[FUNC]], %[[SHARED_MEM]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ARGS_PTR]], %[[ARGS_COUNT]]) + // CHECK-SAME: : (!llvm.ptr, !llvm.ptr, i32, i64, i64, i64, i64, i64, i64, !llvm.ptr, i32) -> () + + // CHECK-NOT: gpu.launch_func + gpu.launch_func @kernel_binary::@my_kernel + blocks in (%c128, %c1, %c1) + threads in (%c256, %c1, %c1) : i64 + args(%arg0: f32, %arg1: i64) + + llvm.return + } + + // CHECK-NOT: gpu.binary +} + +// ----- + +module attributes {gpu.container_module} { + // CHECK: llvm.mlir.global internal constant @[[KERNEL3_DATA:kernel_a_data[_0-9]*]] + // CHECK: llvm.mlir.global internal constant @[[KERNEL3_NAME:kernel_a[_0-9]*]] + // CHECK: llvm.mlir.global internal @[[KERNEL3_HANDLE:kernel_a_handle[_0-9]*]] + + // CHECK: llvm.mlir.global internal constant @[[KERNEL2_DATA:kernel_b_data[_0-9]*]] + // CHECK: llvm.mlir.global internal constant @[[KERNEL2_NAME:kernel_b[_0-9]*]] + // CHECK: llvm.mlir.global internal @[[KERNEL2_HANDLE:kernel_b_handle[_0-9]*]] + + // CHECK: llvm.mlir.global internal constant @[[KERNEL1_DATA:kernel_a_data[_0-9]*]] + // CHECK: llvm.mlir.global internal constant @[[KERNEL1_NAME:kernel_a[_0-9]*]] + // CHECK: llvm.mlir.global internal @[[KERNEL1_HANDLE:kernel_a_handle[_0-9]*]] + + gpu.binary @kernel_binary_a [ + #gpu.object<#rocdl.target, "\00\01\02\03\04\05"> + ] + + gpu.binary @kernel_binary_b [ + #gpu.object<#rocdl.target, "\10\11\12\13"> + ] + + // CHECK-LABEL: llvm.func @test_multiple_launches + // CHECK-SAME: (%[[STREAM:.*]]: !llvm.ptr, %[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64, %[[ARG2:.*]]: i32) + llvm.func @test_multiple_launches(%stream: !llvm.ptr, %arg0: f32, %arg1: i64, %arg2: i32) { + %c64 = arith.constant 64 : i64 + %c128 = arith.constant 128 : i64 + %c256 = arith.constant 256 : i64 + %c512 = arith.constant 512 : i64 + %c1 = arith.constant 1 : i64 + %c2 = arith.constant 2 : i64 + + // First launch + // CHECK-DAG: %[[HANDLE1_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL1_HANDLE]] + // CHECK-DAG: %[[NAME1_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL1_NAME]] + // CHECK-DAG: %[[DATA1_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL1_DATA]] + + // CHECK-DAG: %[[DATA1_ADDR_GEP:.*]] = llvm.getelementptr %[[DATA1_ADDR]][0, 0] + // CHECK-DAG: %[[NAME1_ADDR_GEP:.*]] = llvm.getelementptr %[[NAME1_ADDR]][0, 0] + + // CHECK: %[[DATA1_SIZE:.*]] = llvm.mlir.constant(6 : i64) : i64 + + // CHECK: %[[FUNC1:.*]] = llvm.call @wave_load_kernel(%[[STREAM]], %[[HANDLE1_ADDR]], %[[DATA1_ADDR_GEP]], %[[DATA1_SIZE]], %[[NAME1_ADDR_GEP]]) + // CHECK-SAME: : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> !llvm.ptr + + // CHECK: %[[SHARED_MEM1:.*]] = llvm.mlir.constant(0 : i32) : i32 + + // CHECK: llvm.call @wave_launch_kernel(%[[STREAM]], %[[FUNC1]], %[[SHARED_MEM1]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + + // CHECK-NOT: gpu.launch_func @kernel_binary_a + gpu.launch_func @kernel_binary_a::@kernel_a + blocks in (%c128, %c1, %c1) + threads in (%c256, %c1, %c1) : i64 + args(%arg0: f32, %arg1: i64) + + // Second launch + // CHECK-DAG: %[[HANDLE2_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL2_HANDLE]] + // CHECK-DAG: %[[DATA2_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL2_DATA]] + // CHECK-DAG: %[[NAME2_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL2_NAME]] + + // CHECK-DAG: %[[DATA2_ADDR_GEP:.*]] = llvm.getelementptr %[[DATA2_ADDR]][0, 0] + // CHECK-DAG: %[[NAME2_ADDR_GEP:.*]] = llvm.getelementptr %[[NAME2_ADDR]][0, 0] + + // CHECK: %[[DATA2_SIZE:.*]] = llvm.mlir.constant(4 : i64) : i64 + + // CHECK: %[[FUNC2:.*]] = llvm.call @wave_load_kernel(%[[STREAM]], %[[HANDLE2_ADDR]], %[[DATA2_ADDR_GEP]], %[[DATA2_SIZE]], %[[NAME2_ADDR_GEP]]) + // CHECK-SAME: : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> !llvm.ptr + + // CHECK: %[[SHARED_MEM2:.*]] = llvm.mlir.constant(0 : i32) : i32 + + // CHECK: llvm.call @wave_launch_kernel(%[[STREAM]], %[[FUNC2]], %[[SHARED_MEM2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + + // CHECK-NOT: gpu.launch_func @kernel_binary_b + gpu.launch_func @kernel_binary_b::@kernel_b + blocks in (%c64, %c2, %c1) + threads in (%c512, %c1, %c1) : i64 + args(%arg2: i32) + + // Third launch + // CHECK-DAG: %[[HANDLE3_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL3_HANDLE]] + // CHECK-DAG: %[[DATA3_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL3_DATA]] + // CHECK-DAG: %[[NAME3_ADDR:.*]] = llvm.mlir.addressof @[[KERNEL3_NAME]] + + // CHECK-DAG: %[[DATA3_ADDR_GEP:.*]] = llvm.getelementptr %[[DATA3_ADDR]][0, 0] + // CHECK-DAG: %[[NAME3_ADDR_GEP:.*]] = llvm.getelementptr %[[NAME3_ADDR]][0, 0] + + // CHECK: %[[DATA3_SIZE:.*]] = llvm.mlir.constant(6 : i64) : i64 + + // CHECK: %[[FUNC3:.*]] = llvm.call @wave_load_kernel(%[[STREAM]], %[[HANDLE3_ADDR]], %[[DATA3_ADDR_GEP]], %[[DATA3_SIZE]], %[[NAME3_ADDR_GEP]]) + // CHECK-SAME: : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> !llvm.ptr + + // CHECK: %[[SHARED_MEM3:.*]] = llvm.mlir.constant(0 : i32) : i32 + + // CHECK: llvm.call @wave_launch_kernel(%[[STREAM]], %[[FUNC3]], %[[SHARED_MEM3]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + + // CHECK-NOT: gpu.launch_func @kernel_binary_a + gpu.launch_func @kernel_binary_a::@kernel_a + blocks in (%c256, %c1, %c1) + threads in (%c128, %c1, %c1) : i64 + args(%arg1: i64, %arg2: i32) + + llvm.return + } + + // CHECK-NOT: gpu.binary +} + +// ----- + +module attributes {gpu.container_module} { + gpu.binary @kernel_binary [ + #gpu.object<#rocdl.target, "\00\01\02\03"> + ] + + llvm.func @test_no_arguments() { + %c1 = arith.constant 1 : i64 + + // expected-error @+1 {{func op must have at least one argument}} + gpu.launch_func @kernel_binary::@my_kernel + blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) : i64 + + llvm.return + } +} + +// ----- + +module attributes {gpu.container_module} { + gpu.binary @kernel_binary [ + #gpu.object<#rocdl.target, "\00\01\02\03"> + ] + + llvm.func @test_non_pointer_stream(%stream: i64) { + %c1 = arith.constant 1 : i64 + + // expected-error @+1 {{stream argument must be a pointer}} + gpu.launch_func @kernel_binary::@my_kernel + blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) : i64 + + llvm.return + } +} diff --git a/water/tools/water-opt/CMakeLists.txt b/water/tools/water-opt/CMakeLists.txt index 9f4e1a79f..1af3103c0 100644 --- a/water/tools/water-opt/CMakeLists.txt +++ b/water/tools/water-opt/CMakeLists.txt @@ -6,6 +6,7 @@ set(LIBS MLIRArithDialect MLIRGPUDialect MLIROptLib + MLIRROCDLDialect MLIRWaterTransforms MLIRWaveTransforms diff --git a/water/tools/water-opt/water-opt.cpp b/water/tools/water-opt/water-opt.cpp index 50f45c7f0..edb7ae681 100644 --- a/water/tools/water-opt/water-opt.cpp +++ b/water/tools/water-opt/water-opt.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -49,9 +50,9 @@ int main(int argc, char **argv) { registry.insert(); + mlir::LLVM::LLVMDialect, mlir::ROCDL::ROCDLDialect, + mlir::memref::MemRefDialect, mlir::scf::SCFDialect, + mlir::vector::VectorDialect, wave::WaveDialect>(); mlir::water::test::registerWaterTestDialect(registry);