Skip to content

Commit e873f6d

Browse files
committed
chore: cleanup
1 parent 57938eb commit e873f6d

File tree

3 files changed

Lines changed: 0 additions & 194 deletions

3 files changed

Lines changed: 0 additions & 194 deletions

src/enzyme_ad/jax/BUILD

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -929,8 +929,6 @@ cc_library(
929929
"@llvm-project//mlir:ControlFlowToLLVM",
930930
"@llvm-project//mlir:ControlFlowToSCF",
931931
"@llvm-project//mlir:DLTIDialect",
932-
"@llvm-project//mlir:ExecutionEngine",
933-
"@llvm-project//mlir:ExecutionEngineUtils",
934932
"@llvm-project//mlir:FromLLVMIRTranslation",
935933
"@llvm-project//mlir:FromLLVMIRTranslationRegistration",
936934
"@llvm-project//mlir:FuncDialect",
@@ -949,7 +947,6 @@ cc_library(
949947
"@llvm-project//mlir:InliningUtils",
950948
"@llvm-project//mlir:LLVMCommonConversion",
951949
"@llvm-project//mlir:LLVMDialect",
952-
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
953950
"@llvm-project//mlir:LinalgTransforms",
954951
"@llvm-project//mlir:MathDialect",
955952
"@llvm-project//mlir:MathToLLVM",
@@ -962,9 +959,7 @@ cc_library(
962959
"@llvm-project//mlir:NVGPUDialect",
963960
"@llvm-project//mlir:NVGPUToNVVM",
964961
"@llvm-project//mlir:NVVMDialect",
965-
"@llvm-project//mlir:NVVMTarget",
966962
"@llvm-project//mlir:NVVMToLLVM",
967-
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
968963
"@llvm-project//mlir:OpenMPDialect",
969964
"@llvm-project//mlir:OpenMPToLLVM",
970965
"@llvm-project//mlir:Parser",
@@ -979,7 +974,6 @@ cc_library(
979974
"@llvm-project//mlir:SCFUtils",
980975
"@llvm-project//mlir:SideEffectInterfaces",
981976
"@llvm-project//mlir:Support",
982-
"@llvm-project//mlir:TargetLLVM",
983977
"@llvm-project//mlir:TensorDialect",
984978
"@llvm-project//mlir:ToLLVMIRTranslation",
985979
"@llvm-project//mlir:ToLLVMIRTranslationRegistration",

src/enzyme_ad/jax/Passes/LowerTriton.cpp

Lines changed: 0 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1717
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
1818
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
19-
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
20-
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
21-
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
2219
#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h"
2320
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
2421
#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h"
@@ -36,27 +33,6 @@
3633

3734
#include "src/enzyme_ad/jax/Utils.h"
3835

39-
#include "absl/status/status.h"
40-
#include "absl/status/statusor.h"
41-
#include "absl/strings/str_replace.h"
42-
43-
#include "mlir/ExecutionEngine/OptUtils.h"
44-
#include "llvm/ADT/STLExtras.h"
45-
#include "llvm/IR/LLVMContext.h"
46-
#include "llvm/IR/LegacyPassManager.h"
47-
#include "llvm/IR/Module.h"
48-
#include "llvm/IRReader/IRReader.h"
49-
#include "llvm/Linker/Linker.h"
50-
#include "llvm/MC/TargetRegistry.h"
51-
#include "llvm/Support/CodeGen.h"
52-
#include "llvm/Support/LogicalResult.h"
53-
#include "llvm/Support/SourceMgr.h"
54-
#include "llvm/Support/TargetSelect.h"
55-
#include "llvm/Support/raw_ostream.h"
56-
#include "llvm/Target/TargetMachine.h"
57-
#include "llvm/Target/TargetOptions.h"
58-
#include "llvm/TargetParser/Triple.h"
59-
6036
#include "llvm/ADT/DenseMap.h"
6137
#include "llvm/ADT/SmallVector.h"
6238

@@ -100,134 +76,6 @@ void collectTritonKernels(
10076
return;
10177
}
10278

103-
namespace cuda {
104-
105-
namespace fs = std::filesystem;
106-
107-
absl::StatusOr<std::unique_ptr<llvm::TargetMachine>>
108-
CreateTargetMachine(llvm::Module *module, absl::string_view arch_name,
109-
bool enable_fp_fusion, absl::string_view features) {
110-
// Based on createTargetMachine() in triton/python/src/llvm.cc
111-
std::string error;
112-
const auto *target =
113-
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
114-
if (target == nullptr) {
115-
return absl::InternalError(
116-
absl::StrFormat("Failed to lookup LLVM target based on triple %s: %s",
117-
module->getTargetTriple().str(), error));
118-
}
119-
llvm::TargetOptions opt;
120-
if (enable_fp_fusion) {
121-
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
122-
}
123-
opt.NoInfsFPMath = false;
124-
opt.NoNaNsFPMath = true;
125-
opt.TrapUnreachable = true;
126-
opt.MCOptions.AsmVerbose = true;
127-
opt.MCOptions.PreserveAsmComments = true;
128-
return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine(
129-
module->getTargetTriple(), arch_name, features, opt, llvm::Reloc::PIC_,
130-
std::nullopt, llvm::CodeGenOptLevel::Aggressive));
131-
}
132-
133-
absl::Status LinkLibdevice(llvm::Module *module, std::string libdevice_dir) {
134-
auto libdevice_path = (fs::path(libdevice_dir) / "libdevice.10.bc").string();
135-
136-
llvm::LLVMContext &ctx = module->getContext();
137-
llvm::SMDiagnostic err;
138-
std::unique_ptr<llvm::Module> libdevice_module =
139-
llvm::parseIRFile(libdevice_path, err, ctx);
140-
if (!libdevice_module) {
141-
return absl::InternalError(
142-
absl::StrFormat("Failed to parse libdevice IR file at %s: %s",
143-
libdevice_path, err.getMessage()));
144-
}
145-
146-
llvm::Linker linker(*module);
147-
if (linker.linkInModule(std::move(libdevice_module),
148-
llvm::Linker::Flags::LinkOnlyNeeded)) {
149-
return absl::InternalError("Failed to link libdevice");
150-
}
151-
152-
return absl::OkStatus();
153-
}
154-
155-
absl::StatusOr<std::string> LLVMToPTX(mlir::ModuleOp module,
156-
absl::string_view arch_name,
157-
std::string libdevice_dir) {
158-
// Based on translateLLVMIRToASM() in triton/python/src/llvm.cc
159-
mlir::DialectRegistry registry;
160-
mlir::registerBuiltinDialectTranslation(registry);
161-
mlir::registerLLVMDialectTranslation(registry);
162-
mlir::registerNVVMDialectTranslation(registry);
163-
module.getContext()->appendDialectRegistry(registry);
164-
165-
llvm::LLVMContext llvmContext;
166-
std::unique_ptr<llvm::Module> llvmModule =
167-
mlir::translateModuleToLLVMIR(module, llvmContext);
168-
if (!llvmModule) {
169-
return absl::InternalError("Failed to emit LLVM IR");
170-
}
171-
172-
auto cc = absl::StrReplaceAll(arch_name, {{".", ""}}); // "8.0" -> "80"
173-
auto proc = absl::StrCat("sm_", cc, cc == "90" ? "a" : "");
174-
// We cap the ISA at 8.4 to align with Triton.
175-
// See get_features() in triton/third_party/nvidia/backend/compiler.py.
176-
auto features = cc >= "84" ? "+ptx84" : "+ptx" + cc;
177-
llvmModule->setTargetTriple(llvm::Triple("nvptx64-nvidia-cuda"));
178-
static absl::once_flag init_target_once;
179-
absl::call_once(init_target_once, []() {
180-
LLVMInitializeNVPTXTarget();
181-
LLVMInitializeNVPTXTargetInfo();
182-
LLVMInitializeNVPTXTargetMC();
183-
LLVMInitializeNVPTXAsmPrinter();
184-
});
185-
186-
auto machineOrStatus =
187-
CreateTargetMachine(llvmModule.get(), proc,
188-
/*enable_fp_fusion=*/false, features);
189-
if (!machineOrStatus.ok()) {
190-
return machineOrStatus.status();
191-
}
192-
auto machine = std::move(machineOrStatus.value());
193-
194-
llvmModule->setDataLayout(machine->createDataLayout());
195-
196-
auto needsLibdevice =
197-
llvm::any_of(llvmModule->functions(), [](const auto &f) {
198-
return !f.isIntrinsic() && f.isDeclaration() &&
199-
f.getName().starts_with("__nv_");
200-
});
201-
if (needsLibdevice) {
202-
auto linkStatus = LinkLibdevice(llvmModule.get(), libdevice_dir);
203-
if (!linkStatus.ok()) {
204-
return linkStatus;
205-
}
206-
}
207-
208-
auto transformer = mlir::makeOptimizingTransformer(
209-
/*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/machine.get());
210-
if (auto error = transformer(llvmModule.get()); error) {
211-
return absl::InternalError("Failed to optimize LLVM IR");
212-
}
213-
214-
std::string result;
215-
{
216-
llvm::raw_string_ostream stream(result);
217-
llvm::buffer_ostream bstream(stream);
218-
llvm::legacy::PassManager pm;
219-
machine->addPassesToEmitFile(pm, bstream, nullptr,
220-
llvm::CodeGenFileType::AssemblyFile,
221-
/*DisableVerify=*/false);
222-
if (!pm.run(*llvmModule)) {
223-
return absl::InternalError("Failed to compile LLVM IR to PTX");
224-
}
225-
}
226-
return result;
227-
}
228-
229-
} // namespace cuda
230-
23179
struct LowerTritonPass
23280
: public mlir::enzyme::impl::LowerTritonPassBase<LowerTritonPass> {
23381
using Base::Base;
@@ -297,36 +145,6 @@ struct LowerTritonPass
297145
continue;
298146
}
299147

300-
// remove divisibility attributes from the module before lowering to PTX
301-
// auto funcOpInterface = dyn_cast<FunctionOpInterface>(
302-
// symbolTable.lookupNearestSymbolFrom(ttCallOp,
303-
// ttCallOp.getFnAttr()));
304-
305-
// if (!funcOpInterface) {
306-
// innerMod->emitError("Failed to find function '") << ttCallOp.getFn()
307-
// <<
308-
// "' in module";
309-
// anyFailed = true;
310-
// continue;
311-
// }
312-
313-
// mlir::StringAttr divAttrName =
314-
// builder.getStringAttr("tt.divisibility"); for (size_t i = 0; i <
315-
// ttCallOp.getInputs().size(); ++i) {
316-
// funcOpInterface.removeArgAttr(i, divAttrName);
317-
// }
318-
319-
// auto ptxOrError =
320-
// cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir);
321-
// if (!ptxOrError.ok()) {
322-
// innerMod->emitError(ptxOrError.status().message());
323-
// anyFailed = true;
324-
// continue;
325-
// }
326-
327-
// auto ptx = ptxOrError.value();
328-
// llvm::errs() << "Compilation result: " << ptx << "\n";
329-
330148
int32_t threadsPerWarp = 32;
331149
if (innerMod->hasAttrOfType<IntegerAttr>("ttg.threads_per_warp")) {
332150
threadsPerWarp =

src/enzyme_ad/jax/Passes/Passes.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,12 +1096,6 @@ def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> {
10961096
"ROCDL::ROCDLDialect",
10971097
];
10981098
let options = [
1099-
Option<
1100-
/*C++ variable name=*/"libdeviceDir",
1101-
/*CLI argument=*/"libdevice_dir",
1102-
/*type=*/"std::string",
1103-
/*default=*/"\"\"",
1104-
/*description=*/"Path to the libdevice directory">,
11051099
Option<
11061100
/*C++ variable name=*/"backend",
11071101
/*CLI argument=*/"backend",

0 commit comments

Comments
 (0)