|
16 | 16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
17 | 17 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
18 | 18 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
19 | | -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
20 | | -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
21 | | -#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" |
22 | 19 | #include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" |
23 | 20 | #include "nvidia/include/Dialect/NVWS/IR/Dialect.h" |
24 | 21 | #include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h" |
|
36 | 33 |
|
37 | 34 | #include "src/enzyme_ad/jax/Utils.h" |
38 | 35 |
|
39 | | -#include "absl/status/status.h" |
40 | | -#include "absl/status/statusor.h" |
41 | | -#include "absl/strings/str_replace.h" |
42 | | - |
43 | | -#include "mlir/ExecutionEngine/OptUtils.h" |
44 | | -#include "llvm/ADT/STLExtras.h" |
45 | | -#include "llvm/IR/LLVMContext.h" |
46 | | -#include "llvm/IR/LegacyPassManager.h" |
47 | | -#include "llvm/IR/Module.h" |
48 | | -#include "llvm/IRReader/IRReader.h" |
49 | | -#include "llvm/Linker/Linker.h" |
50 | | -#include "llvm/MC/TargetRegistry.h" |
51 | | -#include "llvm/Support/CodeGen.h" |
52 | | -#include "llvm/Support/LogicalResult.h" |
53 | | -#include "llvm/Support/SourceMgr.h" |
54 | | -#include "llvm/Support/TargetSelect.h" |
55 | | -#include "llvm/Support/raw_ostream.h" |
56 | | -#include "llvm/Target/TargetMachine.h" |
57 | | -#include "llvm/Target/TargetOptions.h" |
58 | | -#include "llvm/TargetParser/Triple.h" |
59 | | - |
60 | 36 | #include "llvm/ADT/DenseMap.h" |
61 | 37 | #include "llvm/ADT/SmallVector.h" |
62 | 38 |
|
@@ -100,134 +76,6 @@ void collectTritonKernels( |
100 | 76 | return; |
101 | 77 | } |
102 | 78 |
|
103 | | -namespace cuda { |
104 | | - |
105 | | -namespace fs = std::filesystem; |
106 | | - |
107 | | -absl::StatusOr<std::unique_ptr<llvm::TargetMachine>> |
108 | | -CreateTargetMachine(llvm::Module *module, absl::string_view arch_name, |
109 | | - bool enable_fp_fusion, absl::string_view features) { |
110 | | - // Based on createTargetMachine() in triton/python/src/llvm.cc |
111 | | - std::string error; |
112 | | - const auto *target = |
113 | | - llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); |
114 | | - if (target == nullptr) { |
115 | | - return absl::InternalError( |
116 | | - absl::StrFormat("Failed to lookup LLVM target based on triple %s: %s", |
117 | | - module->getTargetTriple().str(), error)); |
118 | | - } |
119 | | - llvm::TargetOptions opt; |
120 | | - if (enable_fp_fusion) { |
121 | | - opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; |
122 | | - } |
123 | | - opt.NoInfsFPMath = false; |
124 | | - opt.NoNaNsFPMath = true; |
125 | | - opt.TrapUnreachable = true; |
126 | | - opt.MCOptions.AsmVerbose = true; |
127 | | - opt.MCOptions.PreserveAsmComments = true; |
128 | | - return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine( |
129 | | - module->getTargetTriple(), arch_name, features, opt, llvm::Reloc::PIC_, |
130 | | - std::nullopt, llvm::CodeGenOptLevel::Aggressive)); |
131 | | -} |
132 | | - |
133 | | -absl::Status LinkLibdevice(llvm::Module *module, std::string libdevice_dir) { |
134 | | - auto libdevice_path = (fs::path(libdevice_dir) / "libdevice.10.bc").string(); |
135 | | - |
136 | | - llvm::LLVMContext &ctx = module->getContext(); |
137 | | - llvm::SMDiagnostic err; |
138 | | - std::unique_ptr<llvm::Module> libdevice_module = |
139 | | - llvm::parseIRFile(libdevice_path, err, ctx); |
140 | | - if (!libdevice_module) { |
141 | | - return absl::InternalError( |
142 | | - absl::StrFormat("Failed to parse libdevice IR file at %s: %s", |
143 | | - libdevice_path, err.getMessage())); |
144 | | - } |
145 | | - |
146 | | - llvm::Linker linker(*module); |
147 | | - if (linker.linkInModule(std::move(libdevice_module), |
148 | | - llvm::Linker::Flags::LinkOnlyNeeded)) { |
149 | | - return absl::InternalError("Failed to link libdevice"); |
150 | | - } |
151 | | - |
152 | | - return absl::OkStatus(); |
153 | | -} |
154 | | - |
155 | | -absl::StatusOr<std::string> LLVMToPTX(mlir::ModuleOp module, |
156 | | - absl::string_view arch_name, |
157 | | - std::string libdevice_dir) { |
158 | | - // Based on translateLLVMIRToASM() in triton/python/src/llvm.cc |
159 | | - mlir::DialectRegistry registry; |
160 | | - mlir::registerBuiltinDialectTranslation(registry); |
161 | | - mlir::registerLLVMDialectTranslation(registry); |
162 | | - mlir::registerNVVMDialectTranslation(registry); |
163 | | - module.getContext()->appendDialectRegistry(registry); |
164 | | - |
165 | | - llvm::LLVMContext llvmContext; |
166 | | - std::unique_ptr<llvm::Module> llvmModule = |
167 | | - mlir::translateModuleToLLVMIR(module, llvmContext); |
168 | | - if (!llvmModule) { |
169 | | - return absl::InternalError("Failed to emit LLVM IR"); |
170 | | - } |
171 | | - |
172 | | - auto cc = absl::StrReplaceAll(arch_name, {{".", ""}}); // "8.0" -> "80" |
173 | | - auto proc = absl::StrCat("sm_", cc, cc == "90" ? "a" : ""); |
174 | | - // We cap the ISA at 8.4 to align with Triton. |
175 | | - // See get_features() in triton/third_party/nvidia/backend/compiler.py. |
176 | | - auto features = cc >= "84" ? "+ptx84" : "+ptx" + cc; |
177 | | - llvmModule->setTargetTriple(llvm::Triple("nvptx64-nvidia-cuda")); |
178 | | - static absl::once_flag init_target_once; |
179 | | - absl::call_once(init_target_once, []() { |
180 | | - LLVMInitializeNVPTXTarget(); |
181 | | - LLVMInitializeNVPTXTargetInfo(); |
182 | | - LLVMInitializeNVPTXTargetMC(); |
183 | | - LLVMInitializeNVPTXAsmPrinter(); |
184 | | - }); |
185 | | - |
186 | | - auto machineOrStatus = |
187 | | - CreateTargetMachine(llvmModule.get(), proc, |
188 | | - /*enable_fp_fusion=*/false, features); |
189 | | - if (!machineOrStatus.ok()) { |
190 | | - return machineOrStatus.status(); |
191 | | - } |
192 | | - auto machine = std::move(machineOrStatus.value()); |
193 | | - |
194 | | - llvmModule->setDataLayout(machine->createDataLayout()); |
195 | | - |
196 | | - auto needsLibdevice = |
197 | | - llvm::any_of(llvmModule->functions(), [](const auto &f) { |
198 | | - return !f.isIntrinsic() && f.isDeclaration() && |
199 | | - f.getName().starts_with("__nv_"); |
200 | | - }); |
201 | | - if (needsLibdevice) { |
202 | | - auto linkStatus = LinkLibdevice(llvmModule.get(), libdevice_dir); |
203 | | - if (!linkStatus.ok()) { |
204 | | - return linkStatus; |
205 | | - } |
206 | | - } |
207 | | - |
208 | | - auto transformer = mlir::makeOptimizingTransformer( |
209 | | - /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/machine.get()); |
210 | | - if (auto error = transformer(llvmModule.get()); error) { |
211 | | - return absl::InternalError("Failed to optimize LLVM IR"); |
212 | | - } |
213 | | - |
214 | | - std::string result; |
215 | | - { |
216 | | - llvm::raw_string_ostream stream(result); |
217 | | - llvm::buffer_ostream bstream(stream); |
218 | | - llvm::legacy::PassManager pm; |
219 | | - machine->addPassesToEmitFile(pm, bstream, nullptr, |
220 | | - llvm::CodeGenFileType::AssemblyFile, |
221 | | - /*DisableVerify=*/false); |
222 | | - if (!pm.run(*llvmModule)) { |
223 | | - return absl::InternalError("Failed to compile LLVM IR to PTX"); |
224 | | - } |
225 | | - } |
226 | | - return result; |
227 | | -} |
228 | | - |
229 | | -} // namespace cuda |
230 | | - |
231 | 79 | struct LowerTritonPass |
232 | 80 | : public mlir::enzyme::impl::LowerTritonPassBase<LowerTritonPass> { |
233 | 81 | using Base::Base; |
@@ -297,36 +145,6 @@ struct LowerTritonPass |
297 | 145 | continue; |
298 | 146 | } |
299 | 147 |
|
300 | | - // remove divisibility attributes from the module before lowering to PTX |
301 | | - // auto funcOpInterface = dyn_cast<FunctionOpInterface>( |
302 | | - // symbolTable.lookupNearestSymbolFrom(ttCallOp, |
303 | | - // ttCallOp.getFnAttr())); |
304 | | - |
305 | | - // if (!funcOpInterface) { |
306 | | - // innerMod->emitError("Failed to find function '") << ttCallOp.getFn() |
307 | | - // << |
308 | | - // "' in module"; |
309 | | - // anyFailed = true; |
310 | | - // continue; |
311 | | - // } |
312 | | - |
313 | | - // mlir::StringAttr divAttrName = |
314 | | - // builder.getStringAttr("tt.divisibility"); for (size_t i = 0; i < |
315 | | - // ttCallOp.getInputs().size(); ++i) { |
316 | | - // funcOpInterface.removeArgAttr(i, divAttrName); |
317 | | - // } |
318 | | - |
319 | | - // auto ptxOrError = |
320 | | - // cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir); |
321 | | - // if (!ptxOrError.ok()) { |
322 | | - // innerMod->emitError(ptxOrError.status().message()); |
323 | | - // anyFailed = true; |
324 | | - // continue; |
325 | | - // } |
326 | | - |
327 | | - // auto ptx = ptxOrError.value(); |
328 | | - // llvm::errs() << "Compilation result: " << ptx << "\n"; |
329 | | - |
330 | 148 | int32_t threadsPerWarp = 32; |
331 | 149 | if (innerMod->hasAttrOfType<IntegerAttr>("ttg.threads_per_warp")) { |
332 | 150 | threadsPerWarp = |
|
0 commit comments