Skip to content

Commit fd81660

Browse files
committed
feat: lower to kernel_call
1 parent 5115d4f commit fd81660

File tree

1 file changed

+71
-42
lines changed

1 file changed

+71
-42
lines changed

src/enzyme_ad/jax/Passes/LowerTriton.cpp

Lines changed: 71 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -297,54 +297,83 @@ struct LowerTritonPass
297297
continue;
298298
}
299299

300-
// int32_t threadsPerWarp = 32;
301-
// if (innerMod->hasAttrOfType<IntegerAttr>("ttg.threads_per_warp")) {
302-
// threadsPerWarp =
303-
// innerMod->getAttrOfType<IntegerAttr>("ttg.threads_per_warp")
304-
// .getInt();
300+
// remove divisibility attributes from the module before lowering to PTX
301+
// auto funcOpInterface = dyn_cast<FunctionOpInterface>(
302+
// symbolTable.lookupNearestSymbolFrom(ttCallOp,
303+
// ttCallOp.getFnAttr()));
304+
305+
// if (!funcOpInterface) {
306+
// innerMod->emitError("Failed to find function '") << ttCallOp.getFn()
307+
// <<
308+
// "' in module";
309+
// anyFailed = true;
310+
// continue;
305311
// }
306312

307-
auto ptxOrError =
308-
cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir);
309-
if (!ptxOrError.ok()) {
310-
innerMod->emitError(ptxOrError.status().message());
311-
anyFailed = true;
312-
continue;
313-
}
313+
// mlir::StringAttr divAttrName =
314+
// builder.getStringAttr("tt.divisibility"); for (size_t i = 0; i <
315+
// ttCallOp.getInputs().size(); ++i) {
316+
// funcOpInterface.removeArgAttr(i, divAttrName);
317+
// }
314318

315-
auto ptx = ptxOrError.value();
316-
llvm::errs() << "Compilation result: " << ptx << "\n";
319+
// auto ptxOrError =
320+
// cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir);
321+
// if (!ptxOrError.ok()) {
322+
// innerMod->emitError(ptxOrError.status().message());
323+
// anyFailed = true;
324+
// continue;
325+
// }
326+
327+
// auto ptx = ptxOrError.value();
328+
// llvm::errs() << "Compilation result: " << ptx << "\n";
329+
330+
int32_t threadsPerWarp = 32;
331+
if (innerMod->hasAttrOfType<IntegerAttr>("ttg.threads_per_warp")) {
332+
threadsPerWarp =
333+
innerMod->getAttrOfType<IntegerAttr>("ttg.threads_per_warp")
334+
.getInt();
335+
}
317336

318337
builder.setInsertionPoint(ttCallOp);
319338

320-
// auto sharedMemSizeAttr =
321-
// innerMod->getAttrOfType<IntegerAttr>("ttg.shared");
322-
// auto sharedMemSize = sharedMemSizeAttr.getInt();
323-
// auto shmemOpType = ttCallOp.getGridx().getType();
324-
// auto shmemOp = stablehlo::ConstantOp::create(
325-
// builder, ttCallOp.getLoc(), shmemOpType,
326-
// cast<ElementsAttr>(makeAttr(shmemOpType, sharedMemSize)));
327-
328-
// auto blockX = stablehlo::ConstantOp::create(
329-
// builder, ttCallOp.getLoc(), shmemOpType,
330-
// cast<ElementsAttr>(makeAttr(shmemOpType, threadsPerWarp *
331-
// numWarps)));
332-
// auto blockYZ = stablehlo::ConstantOp::create(
333-
// builder, ttCallOp.getLoc(), shmemOpType,
334-
// cast<ElementsAttr>(makeAttr(shmemOpType, 1)));
335-
336-
// auto kernelCallOp = enzymexla::KernelCallOp::create(
337-
// builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(),
338-
// ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(),
339-
// ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp,
340-
// ttCallOp.getClusterx(), ttCallOp.getClustery(),
341-
// ttCallOp.getClusterz(), ttCallOp.getInputs(),
342-
// ttCallOp.getBackendConfigAttr(), ttCallOp.getOperandLayoutsAttr(),
343-
// ttCallOp.getResultLayoutsAttr(), ttCallOp.getArgAttrsAttr(),
344-
// ttCallOp.getResAttrsAttr(), ttCallOp.getOutputOperandAliasesAttr(),
345-
// ttCallOp.getXlaSideEffectFreeAttr());
346-
// ttCallOp.replaceAllUsesWith(kernelCallOp);
347-
// ttCallOp.erase();
339+
auto sharedMemSizeAttr =
340+
innerMod->getAttrOfType<IntegerAttr>("ttg.shared");
341+
auto sharedMemSize = sharedMemSizeAttr.getInt();
342+
auto shmemOpType = ttCallOp.getGridx().getType();
343+
auto shmemOp = stablehlo::ConstantOp::create(
344+
builder, ttCallOp.getLoc(), shmemOpType,
345+
cast<ElementsAttr>(makeAttr(shmemOpType, sharedMemSize)));
346+
347+
auto blockX = stablehlo::ConstantOp::create(
348+
builder, ttCallOp.getLoc(), shmemOpType,
349+
cast<ElementsAttr>(makeAttr(shmemOpType, threadsPerWarp * numWarps)));
350+
auto blockYZ = stablehlo::ConstantOp::create(
351+
builder, ttCallOp.getLoc(), shmemOpType,
352+
cast<ElementsAttr>(makeAttr(shmemOpType, 1)));
353+
354+
SmallVector<mlir::Value> newInputs(ttCallOp.getInputs().begin(),
355+
ttCallOp.getInputs().end());
356+
// we don't use the next 2 inputs
357+
auto scratchSpace = stablehlo::ConstantOp::create(
358+
builder, ttCallOp.getLoc(),
359+
RankedTensorType::get({}, builder.getI8Type()),
360+
cast<ElementsAttr>(
361+
makeAttr(RankedTensorType::get({}, builder.getI8Type()), 0)));
362+
newInputs.push_back(scratchSpace);
363+
newInputs.push_back(scratchSpace);
364+
365+
auto kernelCallOp = enzymexla::KernelCallOp::create(
366+
builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(),
367+
ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(),
368+
ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp,
369+
ttCallOp.getClusterx(), ttCallOp.getClustery(),
370+
ttCallOp.getClusterz(), newInputs, ttCallOp.getBackendConfigAttr(),
371+
ttCallOp.getOperandLayoutsAttr(), ttCallOp.getResultLayoutsAttr(),
372+
ttCallOp.getArgAttrsAttr(), ttCallOp.getResAttrsAttr(),
373+
ttCallOp.getOutputOperandAliasesAttr(),
374+
ttCallOp.getXlaSideEffectFreeAttr());
375+
ttCallOp.replaceAllUsesWith(kernelCallOp);
376+
ttCallOp.erase();
348377
}
349378

350379
if (anyFailed) {

0 commit comments

Comments
 (0)