diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 8558f0b4f394..49e8256b10d9 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -175,9 +175,10 @@ static StringRef getABI(IREE::HAL::ExecutableTargetAttr targetAttr) { static void dumpModuleToPath(StringRef path, StringRef baseName, StringRef suffix, StringRef extension, - llvm::Module &module) { + llvm::Module &module, StringRef header = {}) { llvm::SmallVector data; llvm::raw_svector_ostream ostream(data); + ostream << header; module.print(ostream, nullptr); dumpDataToPath(path, baseName, suffix, extension, StringRef(data.data(), data.size())); @@ -295,7 +296,8 @@ class ROCMTargetBackend final : public TargetBackend { static void optimizeModule(llvm::Module &module, llvm::TargetMachine &targetMachine, ArrayRef passPlugins, - bool slpVectorization) { + bool slpVectorization, + std::string &outPassesString) { llvm::LoopAnalysisManager lam; llvm::FunctionAnalysisManager fam; llvm::CGSCCAnalysisManager cgam; @@ -336,7 +338,11 @@ class ROCMTargetBackend final : public TargetBackend { mpm.addPass(llvm::VerifierPass()); mpm.addPass(pb.buildPerModuleDefaultPipeline(ol)); mpm.addPass(llvm::VerifierPass()); - + llvm::raw_string_ostream os(outPassesString); + mpm.printPipeline(os, [&pic](StringRef className) { + auto passName = pic.getPassNameForClassName(className); + return passName.empty() ? className : passName; + }); mpm.run(module, mam); } @@ -566,12 +572,19 @@ class ROCMTargetBackend final : public TargetBackend { } // Run LLVM optimization passes. + std::string passesString; optimizeModule(*llvmModule, *targetMachine, options.passPlugins, - options.slpVectorization); + options.slpVectorization, passesString); if (!serializationOptions.dumpIntermediatesPath.empty()) { + std::string header = llvm::formatv(R"TXT( +; To reproduce the .optimized.ll from the .linked.ll, run: +; opt --passes='{}' + +)TXT", + passesString); dumpModuleToPath(serializationOptions.dumpIntermediatesPath, serializationOptions.dumpBaseName, variantOp.getName(), - ".optimized.ll", *llvmModule); + ".optimized.ll", *llvmModule, header); } if (failed(validateFinalizedModule(variantOp, *llvmModule))) {