From aa9b4d4b80b5327374b92a3849f1cf0f0cf224d6 Mon Sep 17 00:00:00 2001
From: snonk
Date: Wed, 26 Nov 2025 09:57:25 -0600
Subject: [PATCH 1/2] change calling convention

---
 .../jax/Passes/LowerEnzymeXLABLAS.cpp         | 280 +++++++++++++++++-
 1 file changed, 279 insertions(+), 1 deletion(-)

diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
index 0b2979408..670fbfbf3 100644
--- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
+++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
@@ -29,6 +29,284 @@ using namespace mlir::enzyme;
 using namespace mlir::enzymexla;
 using namespace mlir::stablehlo;
 
+struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  std::string backend;
+  int64_t blasIntWidth;
+  SymmOpLowering(std::string backend, int64_t blasIntWidth,
+                MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), backend(backend),
+        blasIntWidth(blasIntWidth) {}
+
+  LogicalResult matchAndRewrite(enzymexla::SymmOp op,
+                                PatternRewriter &rewriter) const override {
+    if (backend == "cpu")
+      return matchAndRewriteCPU(op, rewriter);
+
+    // else if (backend == "cuda")
+    //   return matchAndRewriteCUDA(op, rewriter);
+
+    // else if (backend == "tpu")
+    //   return matchAndRewriteTPU(op, rewriter);
+
+    else
+      return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
+                                                 "\"");
+  }
+
+  LogicalResult matchAndRewriteCPU(enzymexla::SymmOp op,
+                                  PatternRewriter &rewriter) const {
+    llvm::errs() << "1\n";
+
+    auto ctx = op->getContext();
+    LLVMTypeConverter typeConverter(ctx);
+    llvm::errs() << "2\n";
+
+    Value a = op.getOperand(0);
+    Value b = op.getOperand(1);
+    Value c = op.getOperand(2);
+    Value alpha_value = op.getAlpha();
+    Value beta_value = op.getBeta();
+    auto side_value = op.getSide() == enzymexla::LapackSide::left ? 'L' : 'R';
+    auto uplo_value = op.getUplo() == enzymexla::LapackUplo::L ? 'L' : 'U';
+    llvm::errs() << "3\n";
+
+    auto aType = cast<RankedTensorType>(a.getType());
+    auto bType = cast<RankedTensorType>(b.getType());
+    auto cType = cast<RankedTensorType>(c.getType());
+    llvm::errs() << "4\n";
+    if (!aType || !bType || !cType)
+{ llvm::errs() << "operand types not ranked tensor types\n";
+    return rewriter.notifyMatchFailure(op, "operand types not ranked tensor types");
+
+} if (!aType.hasRank() || !bType.hasRank() || !cType.hasRank())
+{ llvm::errs() << "expected ranked tensor types\n";
+      return rewriter.notifyMatchFailure(op, "expected ranked tensor types");
+}
+    if (aType.getRank() != 2 || bType.getRank() > 2 || cType.getRank() > 2)
+{ llvm::errs() << "only 2D matrices supported for symm\n";
+    return rewriter.notifyMatchFailure(op, "only 2D matrices supported for symm");
+}
+
+    llvm::errs() << "passed type checks\n";
+
+    Type elementType = aType.getElementType();
+    auto blasIntType = rewriter.getIntegerType(blasIntWidth);
+    auto intType = RankedTensorType::get({}, blasIntType);
+    auto uint8Type = RankedTensorType::get({}, rewriter.getIntegerType(8, false));
+    auto llvmIntType = typeConverter.convertType(blasIntType);
+    auto llvmElmType = typeConverter.convertType(elementType);
+    auto llvmPtrType = LLVM::LLVMPointerType::get(ctx);
+    auto llvmVoidType = LLVM::LLVMVoidType::get(ctx);
+
+    llvm::errs() << "5\n";
+
+    std::string blasFn;
+    if (auto prefix = lapackPrecisionPrefix(elementType)) {
+      blasFn = "enzymexla_blas_" + *prefix + "symm_";
+    } else {
+      op->emitOpError() << "Unsupported element type: " << elementType;
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+    }
+    std::string blasFnWrapper = blasFn + "wrapper";
+    llvm::errs() << "6\n";
+
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+
+    if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(blasFn)) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(moduleOp.getBody());
+      auto funcType =
+          LLVM::LLVMFunctionType::get(llvmVoidType,
+                                      {
+                                          llvmPtrType, // side
+                                          llvmPtrType, // uplo
+                                          llvmPtrType, // m
+                                          llvmPtrType, // n
+                                          llvmPtrType, // alpha
+                                          llvmPtrType, // A
+                                          llvmPtrType, // lda
+                                          llvmPtrType, // B
+                                          llvmPtrType, // ldb
+                                          llvmPtrType, // beta
+                                          llvmPtrType, // C
+                                          llvmPtrType, // ldc
+                                          llvmIntType,
+                                          llvmIntType
+                                      },
+                                      false);
+      rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), blasFn, funcType,
+                                        LLVM::Linkage::External);
+    }
+
+    llvm::errs() << "7\n";
+
+
+    if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(blasFnWrapper)) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+      auto funcType = LLVM::LLVMFunctionType::get(
+          llvmVoidType,
+          {
+              llvmPtrType, // side
+              llvmPtrType, // uplo
+              llvmPtrType, // m
+              llvmPtrType, // n
+              llvmPtrType, // alpha
+              llvmPtrType, // A
+              llvmPtrType, // lda
+              llvmPtrType, // B
+              llvmPtrType, // ldb
+              llvmPtrType, // beta
+              llvmPtrType, // C
+              llvmPtrType, // ldc
+          },
+          false);
+
+      llvm::errs() << "8\n";
+
+
+      auto funcOp =
+          LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), blasFnWrapper,
+                                   funcType, LLVM::Linkage::Private);
+      rewriter.setInsertionPointToStart(funcOp.addEntryBlock(rewriter));
+
+      SmallVector<Value> args(funcOp.getArguments().begin(),
+                              funcOp.getArguments().end());
+      auto const1 =
+          LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIntType,
+                                   rewriter.getIntegerAttr(llvmIntType, 1));
+      args.push_back(const1);
+      args.push_back(const1);
+
+      auto callOp = LLVM::CallOp::create(rewriter, op.getLoc(), TypeRange{},
+                                         SymbolRefAttr::get(ctx, blasFn), args);
+      LLVM::ReturnOp::create(rewriter, op.getLoc(), ValueRange{});
+    }
+
+    llvm::errs() << "9\n";
+
+
+    static int64_t fn_counter = 0;
+    blasFnWrapper += "_" + std::to_string(fn_counter++);
+
+    SmallVector<bool> isColMajorArr(12, true);
+    SmallVector<int64_t> operandRanks = {0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 2, 0};
+    SmallVector<int64_t> outputRanks = {2};
+    auto operandLayouts = getSHLOLayout(rewriter, operandRanks, isColMajorArr, 2);
+    auto resultLayouts = getSHLOLayout(rewriter, outputRanks, isColMajorArr, 2);
+    llvm::errs() << "12323\n";
+
+
+    SmallVector<Attribute> aliases;
+    aliases.push_back(stablehlo::OutputOperandAliasAttr::get(ctx, {}, 10, {})); /*C*/
+
+    func::FuncOp shloFunc;
+
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+      SmallVector<Type> argTypes = {
+          op.getA().getType(), // A
+          op.getB().getType(), // B
+          op.getC().getType(), // C
+          op.getAlpha().getType(), // alpha
+          op.getBeta().getType(), // beta
+      };
+      SmallVector<Type> retTypes = {op.getC().getType()};
+
+      auto calleeType = rewriter.getFunctionType(argTypes, retTypes);
+      auto shloFunc = func::FuncOp::create(rewriter, op.getLoc(), blasFnWrapper, calleeType);
+      shloFunc.setPrivate();
+
+      auto &entryBlock = *shloFunc.addEntryBlock();
+      rewriter.setInsertionPointToStart(&entryBlock);
+      llvm::errs() << "10\n";
+
+
+      auto A = entryBlock.getArgument(0);
+      auto B = entryBlock.getArgument(1);
+      auto C = entryBlock.getArgument(2);
+      auto alpha = entryBlock.getArgument(3);
+      auto beta = entryBlock.getArgument(4);
+
+      auto side = rewriter.create<stablehlo::ConstantOp>(
+          op.getLoc(), uint8Type,
+          cast<ElementsAttr>(makeAttr(uint8Type, side_value)));
+      auto uplo = rewriter.create<stablehlo::ConstantOp>(
+          op.getLoc(), uint8Type,
+          cast<ElementsAttr>(makeAttr(uint8Type, uplo_value)));
+
+      auto lda = stablehlo::ConvertOp::create(
+          rewriter, op.getLoc(), intType,
+          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), A, 0));
+      auto ldb = stablehlo::ConvertOp::create(
+          rewriter, op.getLoc(), intType,
+          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), B, 0));
+      auto ldc = stablehlo::ConvertOp::create(
+          rewriter, op.getLoc(), intType,
+          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), C, 0));
+      auto mSize = ldc;
+      auto nSize = stablehlo::ConvertOp::create(
+          rewriter, op.getLoc(), intType,
+          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), C, 1));
+      llvm::errs() << "11\n";
+
+
+      auto jitCall = enzymexla::JITCallOp::create(
+          rewriter, op.getLoc(), TypeRange{op.getC().getType()},
+          mlir::FlatSymbolRefAttr::get(ctx, blasFnWrapper), // TODO CHECK blasFnWrapper vs fn
+          ValueRange{side, uplo, mSize, nSize, alpha, A, lda, B, ldb, beta, C, ldc},
+          rewriter.getStringAttr(""),
+          /*operand_layouts=*/operandLayouts,
+          /*result_layouts=*/resultLayouts,
+          /*arg_attrs=*/nullptr,
+          /*res_attrs=*/nullptr,
+          /*output_operand_aliases=*/rewriter.getArrayAttr(aliases),
+          /*xla_side_effect_free=*/rewriter.getUnitAttr());
+
+      func::ReturnOp::create(rewriter, op.getLoc(),
+                             ValueRange{jitCall.getResult(0)}); // could be empty?
+    }
+    llvm::errs() << "12\n";
+
+    assert(op.getA() && "A is null");
+    assert(op.getB() && "B is null");
+    assert(op.getC() && "C is null");
+    assert(op.getAlpha() && "alpha is null");
+    assert(op.getBeta() && "beta is null");
+
+    moduleOp.verify();
+
+    auto callOp = func::CallOp::create(
+        rewriter, op.getLoc(), shloFunc,
+        ValueRange{op.getA(), op.getB(), op.getC(), op.getAlpha(), op.getBeta()});
+    llvm::errs() << "13\n";
+
+
+    auto result = callOp.getResult(0);
+    llvm::errs() << "14\n";
+
+    rewriter.replaceAllUsesWith(op.getResult(), result);
+    // rewriter.eraseOp(op); // remove?
+
+    return success();
+  }
+
+  LogicalResult matchAndRewriteCUDA(enzymexla::SymmOp op,
+                                    PatternRewriter &rewriter) const {
+    return failure();
+  }
+  LogicalResult matchAndRewriteTPU(enzymexla::SymmOp op,
+                                   PatternRewriter &rewriter) const {
+    return failure();
+  }
+};
+
 struct SyrkOpLowering : public OpRewritePattern<enzymexla::SyrkOp> {
 
   using OpRewritePattern::OpRewritePattern;
@@ -325,7 +603,7 @@ struct LowerEnzymeXLABLASPass
     auto context = getOperation()->getContext();
 
     RewritePatternSet patterns(context);
-    patterns.add<SyrkOpLowering>(backend, blasIntWidth, context);
+    patterns.add<SymmOpLowering, SyrkOpLowering>(backend, blasIntWidth, context);
 
     GreedyRewriteConfig config;
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),

From 90f15e01bd89f6f4b2c0da4b64e20b50198de4cc Mon Sep 17 00:00:00 2001
From: snonk
Date: Wed, 3 Dec 2025 01:38:35 -0600
Subject: [PATCH 2/2] resolve errors

---
 .../jax/Passes/LowerEnzymeXLABLAS.cpp         | 216 ++++++++----------
 1 file changed, 91 insertions(+), 125 deletions(-)

diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
index 670fbfbf3..fd0e96895 100644
--- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
+++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
@@ -36,7 +36,7 @@ struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {
   std::string backend;
   int64_t blasIntWidth;
   SymmOpLowering(std::string backend, int64_t blasIntWidth,
-                MLIRContext *context, PatternBenefit benefit = 1)
+                 MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern(context, benefit), backend(backend),
         blasIntWidth(blasIntWidth) {}
 
@@ -57,13 +57,11 @@ struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {
   }
 
   LogicalResult matchAndRewriteCPU(enzymexla::SymmOp op,
-                                  PatternRewriter &rewriter) const {
-    llvm::errs() << "1\n";
+                                   PatternRewriter &rewriter) const {
 
     auto ctx = op->getContext();
     LLVMTypeConverter typeConverter(ctx);
-    llvm::errs() << "2\n";
-
+
     Value a = op.getOperand(0);
     Value b = op.getOperand(1);
     Value c = op.getOperand(2);
@@ -71,38 +69,30 @@ struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {
     Value beta_value = op.getBeta();
     auto side_value = op.getSide() == enzymexla::LapackSide::left ? 'L' : 'R';
     auto uplo_value = op.getUplo() == enzymexla::LapackUplo::L ? 'L' : 'U';
-    llvm::errs() << "3\n";
-
+
     auto aType = cast<RankedTensorType>(a.getType());
     auto bType = cast<RankedTensorType>(b.getType());
     auto cType = cast<RankedTensorType>(c.getType());
-    llvm::errs() << "4\n";
     if (!aType || !bType || !cType)
-{ llvm::errs() << "operand types not ranked tensor types\n";
-    return rewriter.notifyMatchFailure(op, "operand types not ranked tensor types");
+      return rewriter.notifyMatchFailure(
+          op, "operand types not ranked tensor types");
 
-} if (!aType.hasRank() || !bType.hasRank() || !cType.hasRank())
-{ llvm::errs() << "expected ranked tensor types\n";
+    if (!aType.hasRank() || !bType.hasRank() || !cType.hasRank())
       return rewriter.notifyMatchFailure(op, "expected ranked tensor types");
-}
-    if (aType.getRank() != 2 || bType.getRank() > 2 || cType.getRank() > 2)
-{ llvm::errs() << "only 2D matrices supported for symm\n";
-    return rewriter.notifyMatchFailure(op, "only 2D matrices supported for symm");
-}
 
-    llvm::errs() << "passed type checks\n";
+    if (aType.getRank() != 2 || bType.getRank() > 2 || cType.getRank() > 2)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 2D matrices supported for symm");
 
     Type elementType = aType.getElementType();
     auto blasIntType = rewriter.getIntegerType(blasIntWidth);
     auto intType = RankedTensorType::get({}, blasIntType);
-    auto uint8Type = RankedTensorType::get({}, rewriter.getIntegerType(8, false));
+    auto uint8Type =
+        RankedTensorType::get({}, rewriter.getIntegerType(8, false));
     auto llvmIntType = typeConverter.convertType(blasIntType);
-    auto llvmElmType = typeConverter.convertType(elementType);
     auto llvmPtrType = LLVM::LLVMPointerType::get(ctx);
     auto llvmVoidType = LLVM::LLVMVoidType::get(ctx);
 
-    llvm::errs() << "5\n";
-
     std::string blasFn;
     if (auto prefix = lapackPrecisionPrefix(elementType)) {
       blasFn = "enzymexla_blas_" + *prefix + "symm_";
@@ -111,63 +101,51 @@ struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {
       return rewriter.notifyMatchFailure(op, "unsupported element type");
     }
     std::string blasFnWrapper = blasFn + "wrapper";
-    llvm::errs() << "6\n";
 
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(blasFn)) {
      OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPointToStart(moduleOp.getBody());
-      auto funcType =
-          LLVM::LLVMFunctionType::get(llvmVoidType,
-                                      {
-                                          llvmPtrType, // side
-                                          llvmPtrType, // uplo
-                                          llvmPtrType, // m
-                                          llvmPtrType, // n
-                                          llvmPtrType, // alpha
-                                          llvmPtrType, // A
-                                          llvmPtrType, // lda
-                                          llvmPtrType, // B
-                                          llvmPtrType, // ldb
-                                          llvmPtrType, // beta
-                                          llvmPtrType, // C
-                                          llvmPtrType, // ldc
-                                          llvmIntType,
-                                          llvmIntType
-                                      },
-                                      false);
+      auto funcType = LLVM::LLVMFunctionType::get(llvmVoidType,
+                                                  {llvmPtrType, // side
+                                                   llvmPtrType, // uplo
+                                                   llvmPtrType, // m
+                                                   llvmPtrType, // n
+                                                   llvmPtrType, // alpha
+                                                   llvmPtrType, // A
+                                                   llvmPtrType, // lda
+                                                   llvmPtrType, // B
+                                                   llvmPtrType, // ldb
+                                                   llvmPtrType, // beta
+                                                   llvmPtrType, // C
+                                                   llvmPtrType, // ldc
+                                                   llvmIntType, llvmIntType},
+                                                  false);
       rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), blasFn, funcType,
                                         LLVM::Linkage::External);
     }
 
-    llvm::errs() << "7\n";
-
-
     if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(blasFnWrapper)) {
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPointToStart(moduleOp.getBody());
 
-      auto funcType = LLVM::LLVMFunctionType::get(
-          llvmVoidType,
-          {
-              llvmPtrType, // side
-              llvmPtrType, // uplo
-              llvmPtrType, // m
-              llvmPtrType, // n
-              llvmPtrType, // alpha
-              llvmPtrType, // A
-              llvmPtrType, // lda
-              llvmPtrType, // B
-              llvmPtrType, // ldb
-              llvmPtrType, // beta
-              llvmPtrType, // C
-              llvmPtrType, // ldc
-          },
-          false);
-
-      llvm::errs() << "8\n";
-
+      auto funcType =
+          LLVM::LLVMFunctionType::get(llvmVoidType,
+                                      {
+                                          llvmPtrType, // side
+                                          llvmPtrType, // uplo
+                                          llvmPtrType, // m
+                                          llvmPtrType, // n
+                                          llvmPtrType, // alpha
+                                          llvmPtrType, // A
+                                          llvmPtrType, // lda
+                                          llvmPtrType, // B
+                                          llvmPtrType, // ldb
+                                          llvmPtrType, // beta
+                                          llvmPtrType, // C
+                                          llvmPtrType, // ldc
+                                      },
+                                      false);
 
       auto funcOp =
           LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), blasFnWrapper,
@@ -187,46 +165,43 @@ struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {
       LLVM::ReturnOp::create(rewriter, op.getLoc(), ValueRange{});
     }
 
-    llvm::errs() << "9\n";
-
-
     static int64_t fn_counter = 0;
-    blasFnWrapper += "_" + std::to_string(fn_counter++);
+    std::string funcFnName = blasFnWrapper + "_" + std::to_string(fn_counter++);
 
     SmallVector<bool> isColMajorArr(12, true);
-    SmallVector<int64_t> operandRanks = {0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 2, 0};
+    SmallVector<int64_t> operandRanks = {
+        0, 0, 0, 0, 0, 2, 0, op.getB().getType().getRank(), 0, 0, 2, 0};
     SmallVector<int64_t> outputRanks = {2};
-    auto operandLayouts = getSHLOLayout(rewriter, operandRanks, isColMajorArr, 2);
+    auto operandLayouts =
+        getSHLOLayout(rewriter, operandRanks, isColMajorArr, 2);
     auto resultLayouts = getSHLOLayout(rewriter, outputRanks, isColMajorArr, 2);
-    llvm::errs() << "12323\n";
-
 
     SmallVector<Attribute> aliases;
-    aliases.push_back(stablehlo::OutputOperandAliasAttr::get(ctx, {}, 10, {})); /*C*/
+    aliases.push_back(
+        stablehlo::OutputOperandAliasAttr::get(ctx, {}, 10, {})); /*C*/
 
     func::FuncOp shloFunc;
-
+
     {
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPointToStart(moduleOp.getBody());
 
       SmallVector<Type> argTypes = {
-          op.getA().getType(), // A
-          op.getB().getType(), // B
-          op.getC().getType(), // C
-          op.getAlpha().getType(), // alpha
-          op.getBeta().getType(), // beta
+          op.getA().getType(),     // A
+          op.getB().getType(),     // B
+          op.getC().getType(),     // C
+          op.getAlpha().getType(), // alpha
+          op.getBeta().getType(),  // beta
       };
       SmallVector<Type> retTypes = {op.getC().getType()};
 
       auto calleeType = rewriter.getFunctionType(argTypes, retTypes);
-      auto shloFunc = func::FuncOp::create(rewriter, op.getLoc(), blasFnWrapper, calleeType);
+      shloFunc =
+          func::FuncOp::create(rewriter, op.getLoc(), funcFnName, calleeType);
      shloFunc.setPrivate();
 
       auto &entryBlock = *shloFunc.addEntryBlock();
       rewriter.setInsertionPointToStart(&entryBlock);
-      llvm::errs() << "10\n";
-
 
       auto A = entryBlock.getArgument(0);
       auto B = entryBlock.getArgument(1);
@@ -234,62 +209,52 @@ struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {
       auto alpha = entryBlock.getArgument(3);
       auto beta = entryBlock.getArgument(4);
 
-      auto side = rewriter.create<stablehlo::ConstantOp>(
+      auto side = rewriter.create<stablehlo::ConstantOp>(
           op.getLoc(), uint8Type,
           cast<ElementsAttr>(makeAttr(uint8Type, side_value)));
-      auto uplo = rewriter.create<stablehlo::ConstantOp>(
+      auto uplo = rewriter.create<stablehlo::ConstantOp>(
          op.getLoc(), uint8Type,
           cast<ElementsAttr>(makeAttr(uint8Type, uplo_value)));
-
+
       auto lda = stablehlo::ConvertOp::create(
-          rewriter, op.getLoc(), intType,
-          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), A, 0));
+          rewriter, op.getLoc(), intType,
+          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), A, 0));
       auto ldb = stablehlo::ConvertOp::create(
-          rewriter, op.getLoc(), intType,
-          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), B, 0));
+          rewriter, op.getLoc(), intType,
+          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), B, 0));
       auto ldc = stablehlo::ConvertOp::create(
-          rewriter, op.getLoc(), intType,
-          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), C, 0));
+          rewriter, op.getLoc(), intType,
+          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), C, 0));
       auto mSize = ldc;
       auto nSize = stablehlo::ConvertOp::create(
-          rewriter, op.getLoc(), intType,
-          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), C, 1));
-      llvm::errs() << "11\n";
+          rewriter, op.getLoc(), intType,
+          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), C, 1));
 
-
       auto jitCall = enzymexla::JITCallOp::create(
-          rewriter, op.getLoc(), TypeRange{op.getC().getType()},
-          mlir::FlatSymbolRefAttr::get(ctx, blasFnWrapper), // TODO CHECK blasFnWrapper vs fn
-          ValueRange{side, uplo, mSize, nSize, alpha, A, lda, B, ldb, beta, C, ldc},
-          rewriter.getStringAttr(""),
-          /*operand_layouts=*/operandLayouts,
-          /*result_layouts=*/resultLayouts,
-          /*arg_attrs=*/nullptr,
-          /*res_attrs=*/nullptr,
-          /*output_operand_aliases=*/rewriter.getArrayAttr(aliases),
-          /*xla_side_effect_free=*/rewriter.getUnitAttr());
+          rewriter, op.getLoc(), TypeRange{op.getC().getType()},
+          mlir::FlatSymbolRefAttr::get(
+              ctx, blasFnWrapper), // TODO CHECK blasFnWrapper vs fn
+          ValueRange{side, uplo, mSize, nSize, alpha, A, lda, B, ldb, beta, C,
+                     ldc},
+          rewriter.getStringAttr(""),
+          /*operand_layouts=*/operandLayouts,
+          /*result_layouts=*/resultLayouts,
+          /*arg_attrs=*/nullptr,
+          /*res_attrs=*/nullptr,
+          /*output_operand_aliases=*/rewriter.getArrayAttr(aliases),
+          /*xla_side_effect_free=*/rewriter.getUnitAttr());
 
-      func::ReturnOp::create(rewriter, op.getLoc(),
-                             ValueRange{jitCall.getResult(0)}); // could be empty?
+      func::ReturnOp::create(
+          rewriter, op.getLoc(),
+          ValueRange{jitCall.getResult(0)}); // could be empty?
     }
-    llvm::errs() << "12\n";
-
-    assert(op.getA() && "A is null");
-    assert(op.getB() && "B is null");
-    assert(op.getC() && "C is null");
-    assert(op.getAlpha() && "alpha is null");
-    assert(op.getBeta() && "beta is null");
-
-    moduleOp.verify();
-
-    auto callOp = func::CallOp::create(
-        rewriter, op.getLoc(), shloFunc,
-        ValueRange{op.getA(), op.getB(), op.getC(), op.getAlpha(), op.getBeta()});
-    llvm::errs() << "13\n";
-
+    auto callOp =
+        func::CallOp::create(rewriter, op.getLoc(), shloFunc,
+                             ValueRange{op.getA(), op.getB(), op.getC(),
+                                        op.getAlpha(), op.getBeta()});
 
     auto result = callOp.getResult(0);
-    llvm::errs() << "14\n";
 
     rewriter.replaceAllUsesWith(op.getResult(), result);
     // rewriter.eraseOp(op); // remove?
@@ -313,7 +278,7 @@ struct SyrkOpLowering : public OpRewritePattern<enzymexla::SyrkOp> {
   SyrkOpLowering(std::string backend, int64_t blasIntWidth,
                  MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern(context, benefit), backend(backend),
-        blasIntWidth(blasIntWidth){};
+        blasIntWidth(blasIntWidth) {};
 
   LogicalResult matchAndRewrite(enzymexla::SyrkOp op,
                                 PatternRewriter &rewriter) const override {
@@ -603,7 +568,7 @@ struct LowerEnzymeXLABLASPass
     auto context = getOperation()->getContext();
 
     RewritePatternSet patterns(context);
-    patterns.add<SymmOpLowering, SyrkOpLowering>(backend, blasIntWidth, context);
+    patterns.add<SymmOpLowering, SyrkOpLowering>(backend, blasIntWidth,
+                                                 context);
 
     GreedyRewriteConfig config;
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
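Note on the calling convention these patches target: the two trailing llvmIntType parameters added to the enzymexla_blas_*symm_ declaration, and the two constant-1 arguments the generated wrapper appends, match the common Fortran convention of passing a hidden integer length for every character argument (here side and uplo) after the normal argument list. The C++ sketch below only illustrates a host-side symbol with that shape; the assumption that it forwards to a Fortran-style dsymm_, the 32-bit integer width, and the definition itself are not taken from these patches.

// Hypothetical illustration: a C-side definition with the same shape as the
// LLVM declaration emitted by matchAndRewriteCPU: twelve by-reference
// arguments followed by two integer character-length arguments (the wrapper
// passes 1 for both, since 'L'/'R'/'U' are single characters). The integer
// width mirrors blasIntWidth; actual Fortran compilers may use int or size_t
// for the hidden lengths.
#include <cstdint>

extern "C" void dsymm_(const char *side, const char *uplo, const int32_t *m,
                       const int32_t *n, const double *alpha, const double *A,
                       const int32_t *lda, const double *B, const int32_t *ldb,
                       const double *beta, double *C, const int32_t *ldc,
                       int32_t side_len, int32_t uplo_len);

extern "C" void enzymexla_blas_dsymm_(const char *side, const char *uplo,
                                      const int32_t *m, const int32_t *n,
                                      const double *alpha, const double *A,
                                      const int32_t *lda, const double *B,
                                      const int32_t *ldb, const double *beta,
                                      double *C, const int32_t *ldc,
                                      int32_t side_len, int32_t uplo_len) {
  // Forward everything, including the hidden lengths, to the Fortran symbol.
  dsymm_(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc, side_len,
         uplo_len);
}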