[LLVMGPUVectorDistribute] Refactor vector.contract distribute #19631
Conversation
Value isoRankLocalReduced = rewriter.create<vector::ShapeCastOp>(
    loc, partialReducedDistributedType, localContractValue);
Can we instead do a broadcast + transpose? Broadcast + transpose is preferred for unit dimensions because it is easier to unroll / reason about.
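For reference, a minimal sketch of that alternative, assuming the unit dims can be prepended by the broadcast and then moved into place; `unitExpandedType` and `perm` are illustrative names, not taken from this PR:

// Sketch only (not the PR's code): the broadcast prepends the unit dims, then
// a transpose moves them to the positions the shape_cast would have produced.
// `unitExpandedType` is the broadcast result type and `perm` the permutation
// placing the prepended unit dims; both are assumed to be computed by the
// surrounding pattern.
Value broadcasted = rewriter.create<vector::BroadcastOp>(
    loc, unitExpandedType, localContractValue);
Value isoRankLocalReduced = rewriter.create<vector::TransposeOp>(
    loc, broadcasted, perm);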
SmallVector<int64_t> subgroupTileLens =
    llvm::to_vector(resLayout.getSubgroupTile());
SmallVector<int64_t> batchTileLens =
    llvm::to_vector(resLayout.getBatchTile());
SmallVector<int64_t> outerTileLens =
    llvm::to_vector(resLayout.getOuterTile());
SmallVector<int64_t> threadTileLens =
    llvm::to_vector(resLayout.getThreadTile());
SmallVector<int64_t> elementTileLens =
    llvm::to_vector(resLayout.getElementTile());
SmallVector<int64_t> subgroupStrides =
    llvm::to_vector(resLayout.getSubgroupStrides());
SmallVector<int64_t> threadStrides =
    llvm::to_vector(resLayout.getThreadStrides());
for (int64_t rDim : reducedDims) {
  // Thread-local reductions have already been carried out.
  // What remains are the reductions across threads and subgroups.
  subgroupTileLens.push_back(lhsLayout.getSubgroupTile()[rDim]);
  batchTileLens.push_back(1);
  outerTileLens.push_back(1);
  threadTileLens.push_back(lhsLayout.getThreadTile()[rDim]);
  elementTileLens.push_back(1);
  subgroupStrides.push_back(lhsLayout.getSubgroupStrides()[rDim]);
  threadStrides.push_back(lhsLayout.getThreadStrides()[rDim]);
}
reductionLayout = IREE::VectorExt::NestedLayoutAttr::get(
    contractOp.getContext(), subgroupTileLens, batchTileLens,
    outerTileLens, threadTileLens, elementTileLens, subgroupStrides,
    threadStrides);
Can we add a "broadcast" function to the layout which takes a dimension and sizes for a single dimension? This would make it easier to write these transformations.
@Groverkss I spent a few cycles here today... I've simplified the implementation a bit by adding a builder for NestedLayoutAttr that can take an existing layout and append a few tiling dims. This API is sufficient to manually infer the partial reduction layout. The rest of the rewrite can then be derived from it altogether -- when you have some time, PTAL.
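For illustration, one possible shape for such a builder, packaging the per-dimension appends shown above; the name and signature are hypothetical and the actual API added in this PR may differ:

// Hypothetical helper: append reduction dims (tile sizes and strides taken
// from the LHS layout) to an existing result layout. Mirrors the push_back
// loop above; not necessarily the PR's actual signature.
static IREE::VectorExt::NestedLayoutAttr
appendReductionDims(IREE::VectorExt::NestedLayoutAttr res,
                    IREE::VectorExt::NestedLayoutAttr lhs,
                    ArrayRef<int64_t> reducedDims, MLIRContext *ctx) {
  SmallVector<int64_t> sgTile = llvm::to_vector(res.getSubgroupTile());
  SmallVector<int64_t> batchTile = llvm::to_vector(res.getBatchTile());
  SmallVector<int64_t> outerTile = llvm::to_vector(res.getOuterTile());
  SmallVector<int64_t> threadTile = llvm::to_vector(res.getThreadTile());
  SmallVector<int64_t> elemTile = llvm::to_vector(res.getElementTile());
  SmallVector<int64_t> sgStrides = llvm::to_vector(res.getSubgroupStrides());
  SmallVector<int64_t> tStrides = llvm::to_vector(res.getThreadStrides());
  for (int64_t rDim : reducedDims) {
    sgTile.push_back(lhs.getSubgroupTile()[rDim]);
    batchTile.push_back(1);
    outerTile.push_back(1);
    threadTile.push_back(lhs.getThreadTile()[rDim]);
    elemTile.push_back(1);
    sgStrides.push_back(lhs.getSubgroupStrides()[rDim]);
    tStrides.push_back(lhs.getThreadStrides()[rDim]);
  }
  return IREE::VectorExt::NestedLayoutAttr::get(
      ctx, sgTile, batchTile, outerTile, threadTile, elemTile, sgStrides,
      tStrides);
}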
Currently, vector.contract distribution is implemented as a standalone distribution closely following vector.multi_reduce. Therefore, we have to duplicate code/effort when we improve either one. This commit changes vector.contract to distribute just the "contract" part of it, then creates a new vector.multi_reduce to be re-distributed with partial-reduction semantics, allowing improvements to vector.multi_reduce to be reused by vector.contract. Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
* Use this to infer the partial reduction layout of vector.contract. Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
Currently, vector.contract distribution is implemented as a standalone distribution closely following vector.multi_reduce. Therefore, we have to duplicate code/effort when we improve either one.
This commit changes vector.contract to distribute just the "contract" part of it. It then creates a new vector.multi_reduce to be re-distributed with partial-reduction semantics, allowing improvements to vector.multi_reduce to be reused by vector.contract.
Closes #19620
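For illustration, a minimal sketch of that flow in the rewriter, under the assumption that the appended reduction dims sit at the tail of the partially reduced value; `partialReduced`, `acc`, and `reducedDims` are illustrative names, not the PR's actual variables:

// Sketch only: reduce the trailing appended dims with a new
// vector.multi_reduction; the existing multi_reduction distribution then
// handles the cross-thread / cross-subgroup partial reduction.
auto partialType = cast<VectorType>(partialReduced.getType());
int64_t rank = partialType.getRank();
int64_t numReduced = static_cast<int64_t>(reducedDims.size());
SmallVector<bool> reductionMask(rank, false);
for (int64_t i = rank - numReduced; i < rank; ++i)
  reductionMask[i] = true;
Value result = rewriter.create<vector::MultiDimReductionOp>(
    loc, partialReduced, acc, reductionMask, contractOp.getKind());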