[LoweringStrategy] Use a more general method to fetch input dims and sizes #1090
Conversation
Force-pushed from 4ff5e29 to bba961e.
Nice! A few comments to address.
Great improvement!
Mostly nits, but a question about using ContractionDimensions's inferred batch dimension(s).
```diff
@@ -842,6 +902,8 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
                                    uint32_t numCols) {
   assert(!getLoweringConfig<IREE::Codegen::LoweringConfigAttr>(genericOp) &&
          "expected lowering_config is not set");
   unsigned numLoops = genericOp.getNumLoops();
   assert(numLoops <= 7 && "expected input number of loops no more than 7");
```
Are these asserts needed? Where is this assumption of <= 7 used?
Kind of. I was thinking of restricting the contraction ops to 2-D inputs like matmul ops or 4-D inputs like mmt4d (or their batch versions), because the current pack strategy is specially designed for those ops. However, I think there are barely any input ops with more than 7 loops in real applications. I can delete this assert for now.
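For reference, a hedged sketch (with a hypothetical helper name `hasSupportedLoopCount`; not the PR code) of what restricting to those specific op shapes might look like instead of the blanket `<= 7` assert:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// Hypothetical alternative to `assert(numLoops <= 7)`: only accept the loop
// counts of the ops the pack strategy targets (matmul: 3, batch_matmul: 4,
// mmt4d: 6, batch_mmt4d: 7).
static bool hasSupportedLoopCount(linalg::LinalgOp linalgOp) {
  unsigned numLoops = linalgOp.getNumLoops();
  return numLoops == 3 || numLoops == 4 || numLoops == 6 || numLoops == 7;
}
```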
```cpp
  SmallVector<unsigned, 2> nDims = contractionDims.n;
  SmallVector<unsigned, 2> kDims = contractionDims.k;
  if (mDims.empty() || nDims.empty() || kDims.empty()) {
    return linalgOp.emitOpError("failed to fetch m/n/k dims.");
```
When might these be empty after you've already not received a failure from `linalg::inferContractionDims`?
I'm not fully sure about this, but I've seen this check every time this function is called in IREE, e.g. https://github.com/iree-org/iree/blob/624a9fae4f8dd03becfaae69110411c84c4ed6e8/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp#L51. I think even if `linalg::inferContractionDims` returns success, it doesn't guarantee that all of the m, n, k dims are non-empty. We want to make sure all these dimensions are non-empty in order to proceed.
At the link above, it's doing the check just before calling `back()` on each of the vectors, so at a point where it knows that the intrinsic it's targeting has all 3. Your implicit assumption here is that you're doing some kind of matmul, but maybe we'll want to reuse this for matvec or something where there is no n dimension.
Perhaps add a comment, change the class name, or move the checks later.
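As a rough sketch of the pattern being discussed (hypothetical helper name `verifyContractionDims`; not the PR's exact code), the emptiness check sits on top of a successful `linalg::inferContractionDims` call, since success alone does not guarantee that each of m/n/k is populated:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// Sketch: infer contraction dims, then verify m/n/k are all present before
// using them. As noted in the review, an op with no n dimension (e.g. a
// matvec-like contraction) would fail this stricter check.
static LogicalResult verifyContractionDims(linalg::LinalgOp linalgOp) {
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(linalgOp);
  if (failed(dims))
    return linalgOp.emitOpError("failed to infer contraction dims");
  if (dims->m.empty() || dims->n.empty() || dims->k.empty())
    return linalgOp.emitOpError("failed to fetch m/n/k dims.");
  return success();
}
```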
```diff
@@ -84,6 +84,52 @@ FailureOr<std::array<uint32_t, 3>> getPackedSize(linalg::LinalgOp linalgOp,
   return instructionSize;
 }

 struct InputDimsAndSizes {
   SmallVector<unsigned, 2> mDims;
```
Personal preference: use `SmallVector<T>` instead of `SmallVector<T, 2>`.
I'd prefer to use the same type as upstream. In addition, a plain `SmallVector<unsigned>` doesn't seem to work here: writing
`SmallVector<unsigned> mDims = contractionDims.m;`
produces the error `no viable conversion from 'SmallVector<[...], 2>' to 'SmallVector<[...], (default) 12>'`.
Ok, I didn't realize this is the way upstream does it, that's fine then. FWIW you could do
`SmallVector mDims {contractionDims.m.begin(), contractionDims.m.end()};`
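For reference, a small illustrative sketch of the two spellings discussed (the parameter `inferredM` stands in for `contractionDims.m`, which upstream declares as `SmallVector<unsigned, 2>`; this is not the PR code):

```cpp
#include "llvm/ADT/SmallVector.h"

using llvm::SmallVector;

void copyMDims(const SmallVector<unsigned, 2> &inferredM) {
  // Matching the upstream member type: a plain copy works.
  SmallVector<unsigned, 2> mDims = inferredM;

  // SmallVector<unsigned> uses a different (default) inline size, so there is
  // no implicit conversion; construct from the iterator range instead.
  SmallVector<unsigned> mDimsDefault(inferredM.begin(), inferredM.end());

  (void)mDims;
  (void)mDimsDefault;
}
```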
Force-pushed from a33465c to b46089c.
LGTM
Nice!
```cpp
  tileSizeLevel1.insert(tileSizeLevel1.begin(), 0);
  tileSizeLevel2.insert(tileSizeLevel2.begin(), 0);
  tileSizeLevel3.insert(tileSizeLevel3.begin(), 0);
  tileSizeLevel0[batchDims[0]] = 1;
```
Could assert that batchDims is not empty, like the others.
Done.
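A minimal sketch of the guard being discussed (the function name `setBatchTileSize` is hypothetical; `contractionDims` and `tileSizeLevel0` are taken from the surrounding diff context):

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// Sketch: tile the batch dimension to 1 at the outermost tiling level,
// asserting that batchDims is non-empty before indexing into it, as requested.
static void setBatchTileSize(const linalg::ContractionDimensions &contractionDims,
                             SmallVector<int64_t> &tileSizeLevel0) {
  SmallVector<unsigned, 2> batchDims = contractionDims.batch;
  assert(!batchDims.empty() && "expected batch dims to be non-empty");
  tileSizeLevel0[batchDims[0]] = 1;
}
```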
Force-pushed from b46089c to 1226376.
LGTM!
There is a bug in the existing code that gets the M/N/K sizes from a matmul-like op: K shouldn't be `lhsShape[1]` if the op is a matmul-transpose-a. It's hard to infer the shapes if the input matmul-like ops are transposed and in `linalg.generic` form, or if the inputs have a higher number of dimensions, such as mmt4d ops.

In addition, the indexing maps of matmul-like `linalg.generic` ops can be transposed during dispatch generation by default because of `TransposeGenericOpsPass`. For example, the iterator types
`parallel, parallel, reduction, parallel, parallel, reduction`
will become
`parallel, parallel, parallel, parallel, reduction, reduction`
after dispatch generation. So if we still put the pack sizes/tile sizes at the dims of the former indexing map, we generate the wrong sizes for tiling.

This PR uses the upstream method `linalg::inferContractionDims` to infer the dim indices of M/N/K for all contraction ops.
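As a rough illustration of that approach (a hedged sketch with made-up names such as `setTileSizesSketch` and `tileM`/`tileN`/`tileK`; not the actual KernelDispatch.cpp code), the inferred dim indices can be used to place tile sizes at the right loop positions regardless of how the loops are ordered:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// Sketch: derive the M/N/K loop indices from the op itself, so transposed or
// higher-rank (e.g. mmt4d) contractions still get tile sizes at the right dims
// instead of assuming fixed loop positions.
static LogicalResult setTileSizesSketch(linalg::LinalgOp linalgOp,
                                        SmallVector<int64_t> &tileSizes,
                                        int64_t tileM, int64_t tileN,
                                        int64_t tileK) {
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(linalgOp);
  if (failed(dims) || dims->m.empty() || dims->n.empty() || dims->k.empty())
    return linalgOp.emitOpError("failed to infer m/n/k dims");
  tileSizes.assign(linalgOp.getNumLoops(), 0);
  // Use the outermost m/n/k loop indices reported by inferContractionDims.
  tileSizes[dims->m[0]] = tileM;
  tileSizes[dims->n[0]] = tileN;
  tileSizes[dims->k[0]] = tileK;
  return success();
}
```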