@@ -297,54 +297,83 @@ struct LowerTritonPass
         continue;
       }

-      // int32_t threadsPerWarp = 32;
-      // if (innerMod->hasAttrOfType<IntegerAttr>("ttg.threads_per_warp")) {
-      //   threadsPerWarp =
-      //       innerMod->getAttrOfType<IntegerAttr>("ttg.threads_per_warp")
-      //           .getInt();
+      // remove divisibility attributes from the module before lowering to PTX
+      // auto funcOpInterface = dyn_cast<FunctionOpInterface>(
+      //     symbolTable.lookupNearestSymbolFrom(ttCallOp,
+      //                                         ttCallOp.getFnAttr()));
+
+      // if (!funcOpInterface) {
+      //   innerMod->emitError("Failed to find function '")
+      //       << ttCallOp.getFn() << "' in module";
+      //   anyFailed = true;
+      //   continue;
       // }

-      auto ptxOrError =
-          cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir);
-      if (!ptxOrError.ok()) {
-        innerMod->emitError(ptxOrError.status().message());
-        anyFailed = true;
-        continue;
-      }
+      // mlir::StringAttr divAttrName =
+      //     builder.getStringAttr("tt.divisibility");
+      // for (size_t i = 0; i < ttCallOp.getInputs().size(); ++i) {
+      //   funcOpInterface.removeArgAttr(i, divAttrName);
+      // }

-      auto ptx = ptxOrError.value();
-      llvm::errs() << "Compilation result: " << ptx << "\n";
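+      // NOTE: direct PTX compilation is disabled for now; the triton call is
+      // rewritten to an enzymexla.kernel_call below instead.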
+      // auto ptxOrError =
+      //     cuda::LLVMToPTX(innerMod, computeCapability, libdeviceDir);
+      // if (!ptxOrError.ok()) {
+      //   innerMod->emitError(ptxOrError.status().message());
+      //   anyFailed = true;
+      //   continue;
+      // }
+
+      // auto ptx = ptxOrError.value();
+      // llvm::errs() << "Compilation result: " << ptx << "\n";
+
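+      // read the warp size from the ttg.threads_per_warp module attribute,
+      // defaulting to 32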
+      int32_t threadsPerWarp = 32;
+      if (innerMod->hasAttrOfType<IntegerAttr>("ttg.threads_per_warp")) {
+        threadsPerWarp =
+            innerMod->getAttrOfType<IntegerAttr>("ttg.threads_per_warp")
+                .getInt();
+      }

       builder.setInsertionPoint(ttCallOp);

-      // auto sharedMemSizeAttr =
-      //     innerMod->getAttrOfType<IntegerAttr>("ttg.shared");
-      // auto sharedMemSize = sharedMemSizeAttr.getInt();
-      // auto shmemOpType = ttCallOp.getGridx().getType();
-      // auto shmemOp = stablehlo::ConstantOp::create(
-      //     builder, ttCallOp.getLoc(), shmemOpType,
-      //     cast<ElementsAttr>(makeAttr(shmemOpType, sharedMemSize)));
-
-      // auto blockX = stablehlo::ConstantOp::create(
-      //     builder, ttCallOp.getLoc(), shmemOpType,
-      //     cast<ElementsAttr>(makeAttr(shmemOpType, threadsPerWarp *
-      //     numWarps)));
-      // auto blockYZ = stablehlo::ConstantOp::create(
-      //     builder, ttCallOp.getLoc(), shmemOpType,
-      //     cast<ElementsAttr>(makeAttr(shmemOpType, 1)));
-
-      // auto kernelCallOp = enzymexla::KernelCallOp::create(
-      //     builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(),
-      //     ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(),
-      //     ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp,
-      //     ttCallOp.getClusterx(), ttCallOp.getClustery(),
-      //     ttCallOp.getClusterz(), ttCallOp.getInputs(),
-      //     ttCallOp.getBackendConfigAttr(), ttCallOp.getOperandLayoutsAttr(),
-      //     ttCallOp.getResultLayoutsAttr(), ttCallOp.getArgAttrsAttr(),
-      //     ttCallOp.getResAttrsAttr(), ttCallOp.getOutputOperandAliasesAttr(),
-      //     ttCallOp.getXlaSideEffectFreeAttr());
-      // ttCallOp.replaceAllUsesWith(kernelCallOp);
-      // ttCallOp.erase();
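+      // forward the kernel's shared-memory requirement (the ttg.shared module
+      // attribute is assumed present) to the launch as a constant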
+      auto sharedMemSizeAttr =
+          innerMod->getAttrOfType<IntegerAttr>("ttg.shared");
+      auto sharedMemSize = sharedMemSizeAttr.getInt();
+      auto shmemOpType = ttCallOp.getGridx().getType();
+      auto shmemOp = stablehlo::ConstantOp::create(
+          builder, ttCallOp.getLoc(), shmemOpType,
+          cast<ElementsAttr>(makeAttr(shmemOpType, sharedMemSize)));
+
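+      // launch with a 1-D block of threadsPerWarp * numWarps threads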
+      auto blockX = stablehlo::ConstantOp::create(
+          builder, ttCallOp.getLoc(), shmemOpType,
+          cast<ElementsAttr>(makeAttr(shmemOpType, threadsPerWarp * numWarps)));
+      auto blockYZ = stablehlo::ConstantOp::create(
+          builder, ttCallOp.getLoc(), shmemOpType,
+          cast<ElementsAttr>(makeAttr(shmemOpType, 1)));
+
+      SmallVector<mlir::Value> newInputs(ttCallOp.getInputs().begin(),
+                                         ttCallOp.getInputs().end());
+      // the next 2 inputs are unused by the kernel, so pass dummy scalar i8
+      // constants for them
+      auto scratchSpace = stablehlo::ConstantOp::create(
+          builder, ttCallOp.getLoc(),
+          RankedTensorType::get({}, builder.getI8Type()),
+          cast<ElementsAttr>(
+              makeAttr(RankedTensorType::get({}, builder.getI8Type()), 0)));
+      newInputs.push_back(scratchSpace);
+      newInputs.push_back(scratchSpace);
+
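+      // replace the triton call with an equivalent enzymexla.kernel_call
+      // carrying the grid, block, and shared-memory operands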
+      auto kernelCallOp = enzymexla::KernelCallOp::create(
+          builder, ttCallOp.getLoc(), ttCallOp.getResultTypes(),
+          ttCallOp.getFn(), ttCallOp.getGridx(), ttCallOp.getGridy(),
+          ttCallOp.getGridz(), blockX, blockYZ, blockYZ, shmemOp,
+          ttCallOp.getClusterx(), ttCallOp.getClustery(),
+          ttCallOp.getClusterz(), newInputs, ttCallOp.getBackendConfigAttr(),
+          ttCallOp.getOperandLayoutsAttr(), ttCallOp.getResultLayoutsAttr(),
+          ttCallOp.getArgAttrsAttr(), ttCallOp.getResAttrsAttr(),
+          ttCallOp.getOutputOperandAliasesAttr(),
+          ttCallOp.getXlaSideEffectFreeAttr());
+      ttCallOp.replaceAllUsesWith(kernelCallOp);
+      ttCallOp.erase();
     }

     if (anyFailed) {