Skip to content

Commit

Permalink
find BlockArgument from gemm output going through all view-like opera…
Browse files Browse the repository at this point in the history
…tions (#1690)
  • Loading branch information
dhernandez0 committed Oct 29, 2024
1 parent 04d29b9 commit 28e2301
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion mlir/lib/Dialect/Rock/utility/loweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,23 @@ FailureOr<memref::AllocOp> mlir::rock::findMemrefAlloc(Value value) {
return findAlloc<memref::AllocOp>(value);
}

FailureOr<BlockArgument> findBlockArgument(Value value) {
auto maybeBlockArg = dyn_cast_or_null<BlockArgument>(value);
while (!maybeBlockArg) {
// Keep going until the operation that defines the value is a
// view-like operation
if (auto viewOp =
dyn_cast_or_null<ViewLikeOpInterface>(value.getDefiningOp())) {
value = viewOp.getViewSource();
} else {
return failure();
}
maybeBlockArg = dyn_cast_or_null<BlockArgument>(value);
}

return maybeBlockArg;
}

std::optional<int64_t> mlir::rock::computeConstDiff(Value l, Value u) {
IntegerAttr clb, cub;
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
Expand Down Expand Up @@ -826,9 +843,10 @@ mlir::rock::traceGemmOutputToArgs(Value matC, func::FuncOp func,
auto funcArgs = func.getArguments();
// check if matC is a kernel argument
for (auto arg : funcArgs) {
if (matC == arg)
if (findBlockArgument(matC) == arg)
args.push_back(arg);
}
assert(args.empty() || args.size() == 1);
if (!args.empty())
return args;

Expand Down

0 comments on commit 28e2301

Please sign in to comment.