[LLVMGPUVectorDistribute][NFC] Refactor vector.contract distribute
Currently, vector.contract distribution is implemented as a standalone
pattern that closely follows vector.multi_reduce. Therefore, we have
to duplicate code/effort whenever we improve either one.

This commit changes vector.contract distribution to handle only the
"contract" part of the op. It then creates a new vector.multi_reduce to
be re-distributed with partial reduction semantics. This allows
improvements to vector.multi_reduce to be reused by vector.contract.
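Conceptually, the split can be illustrated with a minimal standalone C++ sketch (toy data, thread count, and slicing are made up; this is not IREE or MLIR code):

#include <cstdio>
#include <vector>

// Toy model of the refactor: 4 "threads" each own a slice of the K dimension
// of a dot product. Step 1 is the thread-local contract; the leftover
// cross-thread reduction of the partial sums is what the newly created
// vector.multi_reduce is re-distributed to perform.
int main() {
  const int numThreads = 4, kPerThread = 2;
  std::vector<float> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> rhs = {8, 7, 6, 5, 4, 3, 2, 1};

  // Step 1: local contraction on each thread's elements.
  std::vector<float> partial(numThreads, 0.0f);
  for (int t = 0; t < numThreads; ++t)
    for (int k = 0; k < kPerThread; ++k)
      partial[t] += lhs[t * kPerThread + k] * rhs[t * kPerThread + k];

  // Step 2: the remaining reduction across threads completes the contract.
  float result = 0.0f;
  for (float p : partial)
    result += p;
  printf("%f\n", result); // matches the undistributed dot product: 120
  return 0;
}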

Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
manupak committed Jan 22, 2025
1 parent 6933c39 commit fcfeaf5
Showing 1 changed file with 122 additions and 100 deletions.
@@ -442,13 +442,6 @@ struct DistributeMultiReduction final
}

Type elemTy = srcVector.getType().getElementType();
unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
if (elemBitwidth != maxBitsPerShuffle) {
return rewriter.notifyMatchFailure(
multiReduceOp, llvm::formatv("unimplemented: packed shuffle",
elemBitwidth, maxBitsPerShuffle));
}

VectorValue disSrc =
getDistributed(rewriter, srcVector, signature[srcVector]);

@@ -770,24 +763,18 @@ struct DistributeMultiReduction final
int64_t maxBitsPerShuffle;
};

/// The lowering for Contract is performed in three steps (similar to above
/// multi_reduction):
/// 1. Local Contract: Each thread performs operations on its locally
/// distributed elements.
/// 2. Subgroup Reduction: Threads in each subgroup reduce the results from
/// step 1 across threads using a subgroup reduction if distribution occurs
/// along the reduction dimension.
/// 3. Accumulator Reduction: Each thread combines its intermediate results
/// with its held accumulator.
///
/// Currently, reduction across multiple warps is not supported.
/// The distribution of contract is performed by doing a local contraction
/// where each thread performs operations on its locally distributed elements.
/// The resulting vector is then interpreted in the undistributed domain. Since
/// the contraction has only been performed thread-locally, that undistributed
/// vector is a partial reduction. Therefore, a to-be-distributed
/// vector.multi_reduce is added to complete the contraction.
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;

DistributeContract(MLIRContext *context, int64_t subgroupSize,
int64_t maxBitsPerShuffle, int64_t benefit = 1)
: OpDistributionPattern(context, benefit), subgroupSize(subgroupSize),
maxBitsPerShuffle(maxBitsPerShuffle) {}
DistributeContract(MLIRContext *context, int64_t benefit = 1)
: OpDistributionPattern(context, benefit) {}

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
@@ -838,19 +825,20 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
Location loc = contractOp.getLoc();

// Step 1: local contraction
Value localInit = getCombiningIdentityValue(
loc, rewriter, contractOp.getKind(), disAcc.getType());
vector::ContractionOp localContractOp = doDistributedContraction(
rewriter, loc, ctx, contractOp, disLhs, disRhs, disAcc);
rewriter, loc, ctx, contractOp, disLhs, disRhs, localInit);

int64_t rank = lhsLayout.getRank();
SmallVector<bool> reducedDims(rank, false);
SmallVector<int64_t> reducedDims;

// Identify the reduction dimension and apply it for subgroup reduction.
for (auto [index, iteratorType] :
llvm::enumerate(contractOp.getIteratorTypes())) {
if (vector::isReductionIterator(iteratorType)) {
auto map = contractOp.getIndexingMapsArray()[0];
int64_t redIdx = *(map.getResultPosition(getAffineDimExpr(index, ctx)));
reducedDims[redIdx] = true;
reducedDims.push_back(redIdx);
}
}
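To make the indexing-map lookup above concrete, here is a minimal standalone sketch for a matmul-like contraction (the iterator types, map contents, and the vector-based stand-in for AffineMap::getResultPosition are all hypothetical):

#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Stand-in for AffineMap::getResultPosition: each entry of `map` is the
// iterator index that the corresponding LHS dimension reads.
std::optional<int64_t> getResultPosition(const std::vector<int64_t> &map,
                                         int64_t iterator) {
  for (size_t pos = 0; pos < map.size(); ++pos)
    if (map[pos] == iterator)
      return pos;
  return std::nullopt;
}

int main() {
  enum { Parallel, Reduction };
  // Iterators (m, n, k) for C(m, n) += A(m, k) * B(k, n).
  std::vector<int> iteratorTypes = {Parallel, Parallel, Reduction};
  std::vector<int64_t> lhsMap = {0, 2}; // LHS map (m, n, k) -> (m, k)

  std::vector<int64_t> reducedDims;
  for (int64_t idx = 0; idx < (int64_t)iteratorTypes.size(); ++idx)
    if (iteratorTypes[idx] == Reduction)
      if (auto pos = getResultPosition(lhsMap, idx))
        reducedDims.push_back(*pos);

  // The reduction iterator k (index 2) sits at LHS result position 1.
  printf("reduced LHS dim: %lld\n", (long long)reducedDims[0]); // prints 1
  return 0;
}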

@@ -865,46 +853,121 @@

assert(localContractValue && "result should have been a vector");

// Flatten the local result value.
VectorType shaped = localContractValue.getType();
int64_t numElements = shaped.getNumElements();
SmallVector<int64_t> flatShape(1, numElements);
VectorType flatVecType = VectorType::get(flatShape, accElemTy);
VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
localContractValue);

// Step 2: Do subgroup reduction.
FailureOr<VectorValue> threadReduced = doThreadReduction(
rewriter, lhsLayout, flat, contractOp.getKind(), reducedDims);
if (failed(threadReduced)) {
return failure();
NestedLayoutAttr resLayout;
if (auto contractRes = dyn_cast<VectorValue>(contractOp.getResult())) {
resLayout = dyn_cast<NestedLayoutAttr>(signature[contractRes]);
} else {
// Create a zero-d layout because we are going to add the reduction
// dims back to handle the partial reduction.
resLayout = NestedLayoutAttr::get(contractOp.getContext(), {}, {}, {}, {},
{}, {}, {});
}

// Do reduction against accumulator, which needs to be done after thread
// reduction.
VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
loc, shaped, threadReduced.value());

if (!accVector) {
disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
// Shape-cast to re-insert the reduction dimensions as unit dims. The
// reduction dimensions are appended to the result shape as the
// fastest-changing dimensions.
int64_t opRank = contractOp.getIteratorTypes().size();
SmallVector<int64_t> partialReducedDistributedShape(opRank * 3, 0);
SmallVector<int64_t> resDistShape = resLayout.getDistributedShape();
SmallVector<int64_t> lhsDistShape = lhsLayout.getDistributedShape();
for (int64_t tileGroupIdx : llvm::seq<int64_t>(3)) {
int64_t tileGroupOffset = tileGroupIdx * opRank;
for (int64_t dim : llvm::seq<int64_t>(opRank)) {
if (dim < resLayout.getRank()) {
int64_t resTileGroupOffset = tileGroupIdx * resLayout.getRank();
partialReducedDistributedShape[tileGroupOffset + dim] =
resDistShape[resTileGroupOffset + dim];
} else {
// add the thread local reduced dims at the tail.
partialReducedDistributedShape[tileGroupOffset + dim] = 1;
}
}
}
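As a worked example of the shape logic above (plain C++ with made-up sizes standing in for NestedLayoutAttr::getDistributedShape, which lists one entry per rank for each of the batch, outer, and element tile groups):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical numbers: the result layout has rank 1, the contraction has
  // opRank = 2 iterators (one parallel, one reduction).
  int64_t opRank = 2;
  int64_t resRank = 1;
  std::vector<int64_t> resDistShape = {2, 1, 4}; // batch=2, outer=1, element=4

  // Mirror of the tileGroupIdx/dim loop: copy the result layout's entries and
  // append a unit dim per tile group for the reduced iterator.
  std::vector<int64_t> partialReducedDistributedShape(opRank * 3, 0);
  for (int64_t tileGroupIdx = 0; tileGroupIdx < 3; ++tileGroupIdx) {
    int64_t tileGroupOffset = tileGroupIdx * opRank;
    for (int64_t dim = 0; dim < opRank; ++dim) {
      if (dim < resRank)
        partialReducedDistributedShape[tileGroupOffset + dim] =
            resDistShape[tileGroupIdx * resRank + dim];
      else
        partialReducedDistributedShape[tileGroupOffset + dim] = 1;
    }
  }

  // Prints: 2 1 1 1 4 1 -> [batch..., outer..., element...] with a unit
  // reduction dim appended as the fastest-changing dim of each group.
  for (int64_t s : partialReducedDistributedShape)
    printf("%lld ", (long long)s);
  printf("\n");
  return 0;
}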

// Step 3: Accumulator Reduction
Value accReduction = vector::makeArithReduction(
rewriter, loc, contractOp.getKind(), unflattened, disAcc);
auto accReduced = dyn_cast<VectorValue>(accReduction);
if (!accReduced) {
return failure();
}
VectorType partialReducedDistributedType =
VectorType::get(partialReducedDistributedShape,
localContractValue.getType().getElementType());
Value isoRankLocalReduced = rewriter.create<vector::ShapeCastOp>(
loc, partialReducedDistributedType, localContractValue);

SmallVector<int64_t> partialReductionShape;
partialReductionShape.reserve(lhsLayout.getRank());
if (resVector) {
replaceOpWithDistributedValues(rewriter, contractOp, accReduced);
} else {
Value accReducedVal = rewriter.create<vector::ExtractOp>(
loc, accReduction, SmallVector<int64_t>{0});
replaceOpWithDistributedValues(rewriter, contractOp, accReducedVal);
ArrayRef<int64_t> resVectorShape = resVector.getType().getShape();
partialReductionShape.insert(partialReductionShape.end(),
resVectorShape.begin(),
resVectorShape.end());
}
ArrayRef<int64_t> lhsShape = contractOp.getLhs().getType().getShape();
// Note that partial reduction dimensions are at the end.
SmallVector<int64_t> partialReduceDims;
partialReduceDims.reserve(reducedDims.size());
for (int64_t rDim : reducedDims) {
partialReduceDims.push_back(partialReductionShape.size());
partialReductionShape.push_back(lhsShape[rDim]);
}
VectorType unDistributedType = VectorType::get(
partialReductionShape, localContractValue.getType().getElementType());
Value undistrLocalReduced = rewriter.create<IREE::VectorExt::ToSIMDOp>(
loc, unDistributedType, isoRankLocalReduced);
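A worked example of the undistributed partial-reduction shape built above (the result and LHS shapes are hypothetical):

#include <cstdint>
#include <cstdio>
#include <vector>

// The undistributed partial-reduction shape is the result shape with every
// reduced LHS dim appended at the end; partialReduceDims records where those
// appended dims landed.
int main() {
  std::vector<int64_t> resShape = {32};      // contract result (rank 1)
  std::vector<int64_t> lhsShape = {32, 128}; // LHS (m, k)
  std::vector<int64_t> reducedDims = {1};    // k was reduced thread-locally

  std::vector<int64_t> partialReductionShape(resShape);
  std::vector<int64_t> partialReduceDims;
  for (int64_t rDim : reducedDims) {
    partialReduceDims.push_back(partialReductionShape.size());
    partialReductionShape.push_back(lhsShape[rDim]);
  }
  // partialReductionShape == {32, 128}, partialReduceDims == {1}: the new
  // vector.multi_reduce will reduce dim 1 to finish the contraction.
  printf("reduce dim %lld of %lldx%lld\n", (long long)partialReduceDims[0],
         (long long)partialReductionShape[0],
         (long long)partialReductionShape[1]);
  return 0;
}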

// Manually infer the layout of the partial reduction.
IREE::VectorExt::NestedLayoutAttr reductionLayout;
{
SmallVector<int64_t> subgroupTileLens =
llvm::to_vector(resLayout.getSubgroupTile());
SmallVector<int64_t> batchTileLens =
llvm::to_vector(resLayout.getBatchTile());
SmallVector<int64_t> outerTileLens =
llvm::to_vector(resLayout.getOuterTile());
SmallVector<int64_t> threadTileLens =
llvm::to_vector(resLayout.getThreadTile());
SmallVector<int64_t> elementTileLens =
llvm::to_vector(resLayout.getElementTile());
SmallVector<int64_t> subgroupStrides =
llvm::to_vector(resLayout.getSubgroupStrides());
SmallVector<int64_t> threadStrides =
llvm::to_vector(resLayout.getThreadStrides());
for (int64_t rDim : reducedDims) {
// Thread-local reductions have already been carried out; what remains
// is the reduction across threads and subgroups.
subgroupTileLens.push_back(lhsLayout.getSubgroupTile()[rDim]);
batchTileLens.push_back(1);
outerTileLens.push_back(1);
threadTileLens.push_back(lhsLayout.getThreadTile()[rDim]);
elementTileLens.push_back(1);
subgroupStrides.push_back(lhsLayout.getSubgroupStrides()[rDim]);
threadStrides.push_back(lhsLayout.getThreadStrides()[rDim]);
}
reductionLayout = IREE::VectorExt::NestedLayoutAttr::get(
contractOp.getContext(), subgroupTileLens, batchTileLens,
outerTileLens, threadTileLens, elementTileLens, subgroupStrides,
threadStrides);
}
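Concretely (made-up tile sizes, not the NestedLayoutAttr API): for each reduced dim, only the subgroup and thread tiles carry over from the LHS layout, while the batch, outer, and element tiles become unit dims, since that part of the reduction is already done:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Result layout tiles (rank 1, hypothetical values).
  std::vector<int64_t> subgroupTile = {1}, batchTile = {2}, outerTile = {1},
                       threadTile = {16}, elementTile = {4};
  // LHS layout tiles along a reduced dim rDim (hypothetical values).
  int64_t lhsSubgroupTile = 2, lhsThreadTile = 4;

  // Append the reduced dim: only cross-subgroup and cross-thread extents
  // remain to be reduced, so every other tile gets a unit dim.
  subgroupTile.push_back(lhsSubgroupTile);
  batchTile.push_back(1);
  outerTile.push_back(1);
  threadTile.push_back(lhsThreadTile);
  elementTile.push_back(1);

  printf("subgroupTile = [%lld, %lld], threadTile = [%lld, %lld]\n",
         (long long)subgroupTile[0], (long long)subgroupTile[1],
         (long long)threadTile[0], (long long)threadTile[1]);
  return 0;
}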

// Create the partial reduction
auto partialReduction = rewriter.create<vector::MultiDimReductionOp>(
loc, contractOp.getKind(), undistrLocalReduced, acc, partialReduceDims);
{
auto unitAttr = UnitAttr::get(rewriter.getContext());
auto reduceAttrs =
SmallVector<Attribute>(partialReduction->getNumOperands(), unitAttr);
reduceAttrs[0] = reductionLayout;
ArrayAttr reduceResultsAttr =
ArrayAttr::get(rewriter.getContext(), {unitAttr});
if (auto dstLayout =
dyn_cast_or_null<NestedLayoutAttr>(signature[resVector])) {
reduceAttrs[1] = dstLayout;
reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout});
}
ArrayAttr reduceOperandsAttr =
ArrayAttr::get(rewriter.getContext(), reduceAttrs);
setSignatureForRedistribution(rewriter, partialReduction.getOperation(),
reduceOperandsAttr, reduceResultsAttr);
}
rewriter.replaceOp(contractOp, partialReduction);
return success();
}
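For reference, a scalar sketch of the vector.multi_dim_reduction semantics being created here, with an add-reduction over dim 1 and hypothetical shapes: the accumulator is combined with the reduced source, i.e. result[m] = acc[m] + sum over k of source[m][k].

#include <cstdio>
#include <vector>

int main() {
  std::vector<std::vector<float>> source = {{1, 2, 3}, {4, 5, 6}};
  std::vector<float> acc = {10, 20};
  std::vector<float> result(acc); // start from the accumulator
  for (size_t m = 0; m < source.size(); ++m)
    for (float v : source[m])
      result[m] += v; // reduce away dim 1
  printf("%f %f\n", result[0], result[1]); // 16 35
  return 0;
}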

@@ -954,46 +1017,6 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {

return localContractOp;
}

FailureOr<VectorValue> doThreadReduction(RewriterBase &rewriter,
NestedLayoutAttr layout,
VectorValue flat,
vector::CombiningKind kind,
ArrayRef<bool> reductionMask) const {
VectorType flatVecType = flat.getType();
int64_t numElements = flatVecType.getNumElements();
Location loc = flat.getLoc();

auto constOp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(flatVecType));
auto res = llvm::cast<VectorValue>(constOp.getResult());

for (unsigned i = 0; i < numElements; ++i) {
Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);

// Reduce across all reduction dimensions 1-by-1.
for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) {
if (reductionMask[i]) {
int64_t offset = getShuffleOffset(layout, i);
int64_t width = getShuffleWidth(layout, i);
assert(offset <= std::numeric_limits<uint32_t>::max() &&
width <= std::numeric_limits<uint32_t>::max());

extracted = rewriter.create<gpu::SubgroupReduceOp>(
loc, extracted, combiningKindToAllReduce(kind),
/*uniform=*/false, /*cluster_size=*/width,
/*cluster_stride=*/offset);
}
}

res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
}

return res;
}

int64_t subgroupSize;
int64_t maxBitsPerShuffle;
};

struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
@@ -1344,8 +1367,7 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
patterns.add<DistributeContract>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
patterns.add<DistributeContract>(patterns.getContext());
patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
}
