diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD
index fd02c7a6b..d15407139 100644
--- a/src/enzyme_ad/jax/BUILD
+++ b/src/enzyme_ad/jax/BUILD
@@ -839,13 +839,15 @@ cc_library(
 
 cc_library(
     name = "XLADerivatives",
-    srcs = glob([
-        "Implementations/*.cpp",
-        "Passes/*.cpp",
-        "Dialect/*.cpp",
-        "Dialect/Distributed/*.cpp",
-        "Dialect/Tessera/*.cpp",
-    ]) + [
+    srcs = glob(
+        [
+            "Implementations/*.cpp",
+            "Passes/*.cpp",
+            "Dialect/*.cpp",
+            "Dialect/Distributed/*.cpp",
+            "Dialect/Tessera/*.cpp",
+        ],
+    ) + [
        "Utils.cpp",
     ],
     hdrs = glob([
@@ -905,6 +907,9 @@ cc_library(
         "@llvm-project//llvm:Passes",
         "@llvm-project//llvm:Scalar",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AMDGPUDialect",
+        "@llvm-project//mlir:AMDGPUToROCDL",
+        "@llvm-project//mlir:AMDGPUUtils",
         "@llvm-project//mlir:AffineAnalysis",
         "@llvm-project//mlir:AffineDialect",
         "@llvm-project//mlir:AffineToStandard",
@@ -935,6 +940,7 @@ cc_library(
         "@llvm-project//mlir:GPUPipelines",
         "@llvm-project//mlir:GPUToGPURuntimeTransforms",
         "@llvm-project//mlir:GPUToNVVMTransforms",
+        "@llvm-project//mlir:GPUToROCDLTransforms",
         "@llvm-project//mlir:GPUTransforms",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:IndexToLLVM",
@@ -946,6 +952,7 @@ cc_library(
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MathToLLVM",
         "@llvm-project//mlir:MathToLibm",
+        "@llvm-project//mlir:MathToROCDL",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:MemRefToLLVM",
         "@llvm-project//mlir:MemRefTransforms",
diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
index 79aab32f4..36aa03fed 100644
--- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
+++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
@@ -20,15 +20,19 @@
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -932,20 +936,33 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering {
     auto ptrty = LLVM::LLVMPointerType::get(op.getContext());
 
+    auto i32 = rewriter.getIntegerType(32);
+
+    std::string memcpyFuncName;
+
+    bool xla = backend.starts_with("xla");
+
+    if (xla) {
+      memcpyFuncName = "reactantXLAMemcpy";
+    } else if (backend == "cuda") {
+      memcpyFuncName = "cudaMemcpy";
+    } else if (backend == "rocm") {
+      memcpyFuncName = "hipMemcpy";
+    } else {
+      // Preserve the previous behavior: every non-XLA backend used to lower
+      // to cudaMemcpy, so fall back to it rather than emitting a call to an
+      // empty function name.
+      memcpyFuncName = "cudaMemcpy";
+    }
+
     SmallVector<Type> tys = {ptrty, ptrty, size.getType(),
                              rewriter.getIntegerType(32)};
     if (backend.starts_with("xla")) {
       tys.insert(tys.begin(), ptrty);
     }
-    auto i32 = rewriter.getIntegerType(32);
-    bool xla = backend.starts_with("xla");
-    auto cudaMemcpyFn = LLVM::lookupOrCreateFn(
-        rewriter, moduleOp, xla ? "reactantXLAMemcpy" : "cudaMemcpy", tys,
"reactantXLAMemcpy" : "cudaMemcpy", tys, + auto memcpyFn = LLVM::lookupOrCreateFn( + rewriter, moduleOp, memcpyFuncName, tys, xla ? (mlir::Type)LLVM::LLVMVoidType::get(rewriter.getContext()) : (mlir::Type)i32); - if (failed(cudaMemcpyFn)) + if (failed(memcpyFn)) { return failure(); + } SmallVector args = {dst, src, size, LLVM::ConstantOp::create(rewriter, op.getLoc(), @@ -961,7 +978,7 @@ struct CMemcpyOpLowering : public CLoadStoreOpLowering { args.insert(args.begin(), xdata); } - LLVM::CallOp::create(rewriter, op.getLoc(), cudaMemcpyFn.value(), args); + LLVM::CallOp::create(rewriter, op.getLoc(), memcpyFn.value(), args); rewriter.eraseOp(op); return success(); } @@ -1553,7 +1570,7 @@ struct LowerGPUAlternativesOp auto kernelId = LLVM::createGlobalString( loc, rewriter, std::string("kernelId.") + std::to_string(num++), nullTermLocStr, LLVM::Linkage::Internal, /*opaquePointers*/ true); - auto totalAlternatives = LLVM::ConstantOp::create(rewriter, + auto totalAlternatives = LLVM::ConstantOp::create(rewriter, loc, llvmInt32Type, gao->getNumRegions()); auto alternative = rtPGOGetAlternativeCallBuilder @@ -1562,7 +1579,7 @@ struct LowerGPUAlternativesOp int i = 0; for (auto ®ion : gao->getRegions()) { - auto cmpOp = arith::CmpIOp::create(rewriter, + auto cmpOp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, alternative, arith::ConstantIntOp::create(rewriter, loc, i, 32)); auto ifOp = scf::IfOp::create(rewriter, loc, cmpOp, /* hasElse */ true); @@ -1752,6 +1769,24 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, auto loc = kernelModule.getLoc(); auto ctorloc = rewriter.getUnknownLoc(); + + std::string registerFatBinaryFuncName; + std::string registerFunctionFuncName; + std::string registerVarFuncName; + std::string unregisterFatBinaryFuncName; + + if (gpuTarget == "cuda") { + registerFatBinaryFuncName = "__cudaRegisterFatBinary"; + registerFunctionFuncName = "__cudaRegisterFunction"; + registerVarFuncName = "__cudaRegisterVar"; + unregisterFatBinaryFuncName = "__cudaUnregisterFatBinary"; + } else { + registerFatBinaryFuncName = "__hipRegisterFatBinary"; + registerFunctionFuncName = "__hipRegisterFunction"; + registerVarFuncName = "__hipRegisterVar"; + unregisterFatBinaryFuncName = "__hipUnregisterFatBinary"; + } + rewriter.modifyOpInPlace(kernelModule, [&]() { kernelModule->setAttr("polygeist_stubs", rewriter.getUnitAttr()); }); @@ -1818,6 +1853,7 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, moduleIDPrefix = "__hip_"; fatMagic = HIPFatMagic; } + (void)fatbinConstantName; (void)moduleIDSectionName; @@ -1882,10 +1918,11 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, ctorBuilder, ctorloc, llvmPointerType, addressOfWrapper); auto cudaRegisterFatbinFn = - LLVM::lookupOrCreateFn(rewriter, moduleOp, "__cudaRegisterFatBinary", + LLVM::lookupOrCreateFn(rewriter, moduleOp, registerFatBinaryFuncName, llvmPointerType, llvmPointerType); + if (failed(cudaRegisterFatbinFn)) { - llvm::errs() << " cudamalloc already exists with different types\n"; + llvm::errs() << "cudamalloc already exists with different types\n"; return failure(); } @@ -1946,12 +1983,15 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule, llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType}; + auto cudaRegisterFn = LLVM::lookupOrCreateFn( - rewriter, moduleOp, "__cudaRegisterFunction", tys, llvmInt32Type); + rewriter, moduleOp, registerFunctionFuncName, tys, llvmInt32Type); + if 
   if (failed(cudaRegisterFn)) {
     llvm::errs() << " cudamalloc already exists with different types\n";
     return failure();
   }
+
   Value args[] = {
       module.getResult(),
       bitcast,
@@ -2012,7 +2052,8 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule,
             0)});
       }
     }
-    // TODO this has to happen only for some CUDA versions
+    // TODO this has to happen only for some CUDA versions (CUDA 11.x needs
+    // __cudaRegisterFatBinaryEnd); HIP has no finalize call.
     if (gpuTarget == "cuda") {
       auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn(
           rewriter, moduleOp, "__cudaRegisterFatBinaryEnd", llvmPointerType,
@@ -2045,15 +2086,15 @@ ConvertGPUModuleOp::matchAndRewrite(gpu::GPUModuleOp kernelModule,
       dtorBuilder, ctorloc, llvmPointerPointerType, aoo->getResult(0));
   auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn(
-      rewriter, moduleOp, "__cudaUnregisterFatBinary", llvmPointerType,
+      rewriter, moduleOp, unregisterFatBinaryFuncName, llvmPointerType,
       llvmVoidType);
   if (failed(cudaUnRegisterFatbinFn)) {
     llvm::errs() << " cudamalloc already exists with different types\n";
     return failure();
   }
-
   LLVM::CallOp::create(rewriter, ctorloc, cudaUnRegisterFatbinFn.value(),
                        ValueRange(module));
+
   LLVM::ReturnOp::create(dtorBuilder, ctorloc, ValueRange());
   auto dtorSymbol = FlatSymbolRefAttr::get(dtor);
   {
@@ -2173,10 +2214,13 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       LLVM::ZExtOp::create(rewriter, loc, i64, dynamicSharedMemorySize));
   args.push_back(stream);
 
-  auto launchCall = LLVM::CallOp::create(
-      rewriter, loc, TypeRange(i32), "cudaLaunchKernel",
-      args); // FlatSymbolRefAttr::get(rewriter.getStringAttr("cudaLaunchKernel")),
-             // args);
+  // Create the LLVM call to the runtime's kernel-launch entry point.
+  std::string launchFuncName =
+      (gpuTarget == "rocm") ? "hipLaunchKernel" : "cudaLaunchKernel";
+
+  auto launchCall = LLVM::CallOp::create(rewriter, loc, TypeRange(i32),
+                                         launchFuncName, args);
+
   if (launchOp.getAsyncToken()) {
     // Async launch: make dependent ops use the same stream.
     rewriter.replaceOp(launchOp, {stream});
@@ -2449,7 +2493,6 @@ class ConvertAllocOpToGpuRuntimeCallPattern
     if (backend == "cuda") {
       auto one = LLVM::ConstantOp::create(rewriter, loc, i64,
                                           rewriter.getI64IntegerAttr(1));
-
       auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, ptr1ty, one);
       Type tys[] = {ptrty, i64};
       auto cudaMallocFn =
@@ -2463,8 +2506,29 @@ class ConvertAllocOpToGpuRuntimeCallPattern
         ptr,
         sizeBytes,
       };
+
       LLVM::CallOp::create(rewriter, loc, cudaMallocFn.value(), args);
       allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr);
+    } else if (backend == "rocm") {
+      auto one = LLVM::ConstantOp::create(rewriter, loc, i64,
+                                          rewriter.getI64IntegerAttr(1));
+      // hipMalloc, like cudaMalloc, returns the device pointer through an
+      // out-parameter, so reserve a stack slot for it.
+      auto ptr = LLVM::AllocaOp::create(rewriter, loc, ptrty, ptr1ty, one);
+      Type tys[] = {ptrty, i64};
+      auto hipMallocFn =
+          LLVM::lookupOrCreateFn(rewriter, moduleOp, "hipMalloc", tys, i32);
+
+      if (failed(hipMallocFn)) {
+        llvm::errs() << "hipMalloc already exists with different types\n";
+        return failure();
+      }
+
+      Value args[] = {
+          ptr,
+          sizeBytes,
+      };
+
+      LLVM::CallOp::create(rewriter, loc, hipMallocFn.value(), args);
+      allocatedPtr = LLVM::LoadOp::create(rewriter, loc, ptr1ty, ptr);
     } else if (backend.starts_with("cpu")) {
       Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());
@@ -2601,9 +2665,9 @@ class ConvertOccupancyOp
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
-    if (backend != "cuda")
+    if (backend != "cuda" && backend != "rocm")
       return rewriter.notifyMatchFailure(
-          op, "Occupancy op lowering only supported for CUDA");
+          op, "Occupancy op lowering only supported for CUDA and ROCm");
 
     auto moduleOp = op->getParentOfType<ModuleOp>();
     auto i64 = rewriter.getIntegerType(64);
@@ -2616,12 +2680,20 @@ class ConvertOccupancyOp
     Type tys[] = {ptrty, ptrty, intty, adaptor.getDynamicSMemSize().getType(),
                   adaptor.getFlags().getType()};
 
-    auto cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn =
-        LLVM::lookupOrCreateFn(
-            rewriter, moduleOp,
-            "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", tys, i32);
-    if (failed(cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn)) {
-      llvm::errs() << " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags "
+    std::string occupancyFuncName;
+    if (backend == "cuda") {
+      occupancyFuncName =
+          "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags";
+    } else if (backend == "rocm") {
+      occupancyFuncName =
+          "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags";
+    }
+
+    auto occupancyFn =
+        LLVM::lookupOrCreateFn(rewriter, moduleOp, occupancyFuncName, tys, i32);
+
+    if (failed(occupancyFn)) {
+      llvm::errs() << "occupancyMaxActiveBlocksPerMultiprocessorWithFlags "
                    "already exists with different types\n";
       return failure();
     }
@@ -2637,9 +2709,7 @@ class ConvertOccupancyOp
     auto addr = LLVM::AddressOfOp::create(rewriter, loc, ptrty, funcStubName);
     Value args[] = {ptr, addr, adaptor.getBlockSize(),
                     adaptor.getDynamicSMemSize(), adaptor.getFlags()};
-    LLVM::CallOp::create(
-        rewriter, loc,
-        cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn.value(), args);
+    LLVM::CallOp::create(rewriter, loc, occupancyFn.value(), args);
 
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, intty, ptr);
     return success();
@@ -2662,9 +2732,9 @@ class ConvertGPUKernelAddressOp
   matchAndRewrite(enzymexla::GPUKernelAddressOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    if (backend != "cuda")
+    if (backend != "cuda" && backend != "rocm")
       return rewriter.notifyMatchFailure(
-          op, "KernelAddress lowering only supported for CUDA");
CUDA"); + op, "KernelAddress lowering only supported for CUDA and ROCM"); std::string funcStubName = getFuncStubName(op.getFn().getRootReference().getValue(), @@ -2678,7 +2748,7 @@ class ConvertGPUKernelAddressOp }; /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime -/// call. Currently it supports CUDA, CPU, and XLA. +/// call. Currently it supports CUDA, ROCM, CPU, and XLA. template class ConvertDeallocOpToGpuRuntimeCallPattern : public ConvertOpToGpuRuntimeCallPattern { @@ -2731,6 +2801,20 @@ class ConvertDeallocOpToGpuRuntimeCallPattern ptr, }; LLVM::CallOp::create(rewriter, loc, cudaFreeFn.value(), args); + } else if (backend == "rocm") { + Type tys[] = {ptr1ty}; + auto hipFreeFn = + LLVM::lookupOrCreateFn(rewriter, moduleOp, "hipFree", tys, i32); + + if (failed(hipFreeFn)) { + llvm::errs() << " hipfree already exists with different types\n"; + return failure(); + } + Value args[] = { + ptr, + }; + LLVM::CallOp::create(rewriter, loc, hipFreeFn.value(), args); + } else if (backend.starts_with("cpu")) { FailureOr freeFunc = @@ -2923,8 +3007,9 @@ struct GPUFuncOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; GPUFuncOpLowering(LLVMTypeConverter &converter, unsigned allocaAddrSpace, - StringAttr kernelAttributeName) - : ConvertOpToLLVMPattern(converter), + StringAttr kernelAttributeName, + PatternBenefit benefit = PatternBenefit(1)) + : ConvertOpToLLVMPattern(converter, benefit), allocaAddrSpace(allocaAddrSpace), kernelAttributeName(kernelAttributeName) {} @@ -3631,6 +3716,11 @@ populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns, populateLibDeviceConversionPatterns(typeConverter, patterns, benefit); patterns.add(typeConverter, benefit); + } else if (gpuTarget == "rocm") { + typeConverter.getContext().loadDialect(); + mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, + mlir::gpu::amd::Runtime::HIP, + amdgpu::Chipset()); } } } @@ -3663,7 +3753,6 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { moduleBuilder, fname, LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, streamTy})); resumeOp.setPrivate(); - return resumeOp; } @@ -3992,13 +4081,20 @@ struct ConvertPolygeistToLLVMPass bool hasLaunch = m->walk([](gpu::LaunchFuncOp) { return WalkResult::interrupt(); }).wasInterrupted(); + + std::string launchFuncName; + if (backend == "rocm") { + launchFuncName = "hipLaunchKernel"; + } else { + launchFuncName = "cudaLaunchKernel"; + } if (hasLaunch) { OpBuilder rewriter(m); auto i32 = rewriter.getIntegerType(32); auto i64 = rewriter.getIntegerType(64); auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext()); Type tys[] = {ptrty, i64, i32, i64, i32, ptrty, i64, ptrty}; - (void)LLVM::lookupOrCreateFn(rewriter, m, "cudaLaunchKernel", tys, i32); + (void)LLVM::lookupOrCreateFn(rewriter, m, launchFuncName, tys, i32); } for (auto mod : gmods) { @@ -4014,6 +4110,11 @@ struct ConvertPolygeistToLLVMPass target.addLegalOp(); target.addLegalDialect(); target.addLegalOp(); + } else if (backend == "rocm") { + target.addIllegalDialect(); + target.addLegalOp(); + target.addLegalDialect(); + target.addLegalOp(); } } @@ -4027,7 +4128,7 @@ struct ConvertPolygeistToLLVMPass mod->emitError() << "failed to apply folding"; return signalPassFailure(); } - }; + } LLVMConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); @@ -4201,6 +4302,14 @@ struct ConvertPolygeistToLLVMPass } } }); + m->walk([](LLVM::LLVMFuncOp call) { + if (call.getName() == 
"cudaDeviceSynchronize") { + call->erase(); + } else if (call.getName() == "hipDeviceSynchronize") { + call->erase(); + return; + } + }); } if (StringRef(gpuTarget).starts_with("xla") || gpuTarget == "cpu") { const char *toErase[] = {"cudaGetLastError"}; @@ -4252,7 +4361,6 @@ struct ConvertPolygeistToLLVMPass signalPassFailure(); return; } - { const char *GetDeviceFromHostFuncName = "__reactant$get_device_from_host"; SmallVector toHandle; diff --git a/test/lit_tests/lowering/rocm.mlir b/test/lit_tests/lowering/rocm.mlir new file mode 100644 index 000000000..48bea8ccb --- /dev/null +++ b/test/lit_tests/lowering/rocm.mlir @@ -0,0 +1,45 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(convert-polygeist-to-llvm{backend=rocm})" | FileCheck %s + +module attributes {gpu.container_module} { + llvm.func @test_rocm_launch(%arg0: !llvm.ptr) { + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %c1_i64 = arith.constant 1 : i64 + %stream = llvm.inttoptr %c1_i64 : i64 to !llvm.ptr + %token = "enzymexla.stream2token"(%stream) : (!llvm.ptr) -> !gpu.async.token + gpu.launch_func [%token] @test_module::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c32, %c1, %c1) args(%arg0 : !llvm.ptr) + llvm.return + } + + func.func @test_rocm_alloc() { + %alloc = gpu.alloc() : memref<256xf32, 1> + gpu.dealloc %alloc : memref<256xf32, 1> + return + } + + func.func @test_rocm_memcpy(%src: memref<256xf32>, %dst: memref<256xf32, 1>) { + %c1024 = arith.constant 1024 : index + "enzymexla.memcpy"(%dst, %src, %c1024) : (memref<256xf32, 1>, memref<256xf32>, index) -> () + return + } + + gpu.module @test_module { + gpu.func @test_kernel(%arg0: !llvm.ptr) kernel { + gpu.return + } + } +} + +// CHECK-LABEL: llvm.func @test_rocm_launch +// CHECK: llvm.call @hipLaunchKernel +// CHECK-NOT: cudaLaunchKernel + +// CHECK-LABEL: llvm.func @test_rocm_alloc +// CHECK: llvm.call @hipMalloc +// CHECK: llvm.call @hipFree +// CHECK-NOT: cudaMalloc +// CHECK-NOT: cudaFree + +// CHECK-LABEL: llvm.func @test_rocm_memcpy +// CHECK: llvm.call @hipMemcpy +// CHECK-NOT: cudaMemcpy