
[LoweringStrategy] Use a more general method to fetch input dims and sizes #1090

Merged: 3 commits into nod-ai:main, Feb 11, 2025

Conversation

@yzhang93 (Contributor) commented Feb 8, 2025

There is a bug in the existing code that gets the M/N/K sizes from a matmul-like op:

const uint64_t M = initShape[0];
const uint64_t N = initShape[1];
const uint64_t K = lhsShape[1];

Apparently, K shouldn't be lhsShape[1] if it's a matmul-transpose-a op; there the LHS has shape K x M, so lhsShape[1] is actually M.
It's also hard to infer the sizes when the inputs of a matmul-like op are transposed and the op is in linalg.generic form, or when the inputs have more dimensions, as in mmt4d ops.

In addition, the indexing maps of matmul-like linalg.generic ops can be transposed during dispatch generation by default because of TransposeGenericOpsPass.
For example, the iterator types
[parallel, parallel, reduction, parallel, parallel, reduction] become
[parallel, parallel, parallel, parallel, reduction, reduction] after dispatch generation.
So if we still take the pack/tile sizes at the dim positions of the former indexing map, we generate wrong sizes for tiling.

This PR uses the upstream method linalg::inferContractionDims to infer the dim indices of M/N/K for all contraction ops.
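
For reference, a minimal sketch of this approach (illustrative only; the helper name getMNKSizes and its exact signature are not from the PR):

// Sketch: infer the contraction dim indices and read M/N/K from the op's
// loop ranges instead of hard-coding positions in the operand shapes.
static FailureOr<std::array<int64_t, 3>> getMNKSizes(linalg::LinalgOp linalgOp) {
  FailureOr<linalg::ContractionDimensions> maybeDims =
      linalg::inferContractionDims(linalgOp);
  if (failed(maybeDims)) {
    linalgOp.emitOpError("failed to infer contraction dims");
    return failure();
  }
  linalg::ContractionDimensions contractionDims = *maybeDims;
  if (contractionDims.m.empty() || contractionDims.n.empty() ||
      contractionDims.k.empty()) {
    linalgOp.emitOpError("failed to fetch m/n/k dims");
    return failure();
  }
  // Loop ranges are indexed by loop position, so the inferred indices remain
  // valid under operand transposition and extra (batch/packed) dimensions.
  auto loopRanges = linalgOp.getStaticLoopRanges();
  return std::array<int64_t, 3>{loopRanges[contractionDims.m.back()],
                                loopRanges[contractionDims.n.back()],
                                loopRanges[contractionDims.k.back()]};
}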

@Abhishek-Varma (Contributor) left a comment:

Nice! A few comments to address.

@newling (Contributor) left a comment:

Great improvement!

Mostly nits, but a question about using ContractionDimensions's inferred batch dimension(s).

@@ -842,6 +902,8 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
uint32_t numCols) {
assert(!getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(genericOp) &&
"expected lowering_config is not set");
unsigned numLoops = genericOp.getNumLoops();
assert(numLoops <= 7 && "expected input number of loops no more than 7");
Contributor:

Are these asserts needed? Where is this assumption of <= 7 used?

@yzhang93 (author):

Kind of. I was thinking of restricting the contraction ops to only 2D inputs like matmul ops or 4D inputs like mmt4d (or their batch versions), because the current pack strategy is specially designed for those ops. However, there are rarely input ops with more than 7 loops in real applications, so I can delete this assert for now.

SmallVector<unsigned, 2> nDims = contractionDims.n;
SmallVector<unsigned, 2> kDims = contractionDims.k;
if (mDims.empty() || nDims.empty() || kDims.empty()) {
return linalgOp.emitOpError("failed to fetch m/n/k dims.");
Contributor:

When might this be empty, given that linalg::inferContractionDims did not return failure?

@yzhang93 (author):

I'm not fully sure about this, but I've seen this check every time this function is called in IREE, e.g. https://github.com/iree-org/iree/blob/624a9fae4f8dd03becfaae69110411c84c4ed6e8/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp#L51.

I think even if linalg::inferContractionDims returns success, it doesn't guarantee that all of m, n, and k are non-empty. We want to make sure all these dimensions are non-empty before proceeding.

Contributor:

At the link above, it's doing the check just before calling back() on each of the vectors, so at a point where it knows that the intrinsic it's targeting has all 3. Your implicit assumption here is that you're doing some kind of matmul, but maybe we'll want to reuse this for matvec or something where there is no n dimension.

Perhaps add a comment, change the class name, or move the checks later.

@@ -84,6 +84,52 @@ FailureOr<std::array<uint32_t, 3>> getPackedSize(linalg::LinalgOp linalgOp,
return instructionSize;
}

struct InputDimsAndSizes {
SmallVector<unsigned, 2> mDims;
Contributor:
Personal preference: Use SmallVector<T> instead of SmallVector<T, 2>.

@yzhang93 (author):

I'd prefer to use the same type as upstream. In addition, it doesn't seem to work to use SmallVector<unsigned> when writing

SmallVector<unsigned> mDims = contractionDims.m;

It gives the error: no viable conversion from 'SmallVector<[...], 2>' to 'SmallVector<[...], (default) 12>'

Contributor:
Ok, I didn't realize this is the way upstream does it, that's fine then.

FWIW you could do

SmallVector mDims {contractionDims.m.begin(), contractionDims.m.end()};
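
(An explicitly typed variant of that suggestion, as a sketch, would use SmallVector's iterator-pair constructor:)

// Illustrative alternative: copy the inferred dims into a default-sized SmallVector.
SmallVector<unsigned> mDims(contractionDims.m.begin(), contractionDims.m.end());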

@yzhang93 force-pushed the refactor_get_shapes_dims branch 2 times, most recently from a33465c to b46089c, on February 10, 2025 22:08.
@jtuyls (Collaborator) left a comment:

LGTM

@MaheshRavishankar (Collaborator) commented:

Nice!

tileSizeLevel1.insert(tileSizeLevel1.begin(), 0);
tileSizeLevel2.insert(tileSizeLevel2.begin(), 0);
tileSizeLevel3.insert(tileSizeLevel3.begin(), 0);
tileSizeLevel0[batchDims[0]] = 1;
Contributor:
Could assert that batchDims is not empty, like the others.
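
For illustration, the suggested guard might look like this (a sketch, not necessarily the exact code added in the PR; the assert message is hypothetical):

// Guard the indexing on the following line, as suggested above.
assert(!batchDims.empty() && "expected batch dims to be non-empty");
tileSizeLevel0[batchDims[0]] = 1;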

@yzhang93 (author):
Done.


@yzhang93 force-pushed the refactor_get_shapes_dims branch from b46089c to 1226376 on February 11, 2025 01:19.
@Abhishek-Varma (Contributor) left a comment:

LGTM!

@yzhang93 merged commit 1f8ec3e into nod-ai:main on Feb 11, 2025. 7 checks passed.