Skip to content

Commit

Permalink
* add a builder to append to an existing layout
Browse files Browse the repository at this point in the history
* use this to infer partial reduction layout of
  vector.contract

Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
  • Loading branch information
manupak committed Jan 22, 2025
1 parent fcfeaf5 commit 4f75304
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,16 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction rhs");
}
NestedLayoutAttr resLayout;
if (auto contractRes = dyn_cast<VectorValue>(contractOp.getResult())) {
resLayout = dyn_cast<NestedLayoutAttr>(signature[contractRes]);
} else {
// Create a zero-d layout because we
// are going to add reduction dims
// back to handle the partial reduction
resLayout = NestedLayoutAttr::get(
contractOp.getContext(), ArrayRef<int64_t>{}, {}, {}, {}, {}, {}, {});
}

Value disLhs = getDistributed(rewriter, contractOp.getLhs(), lhsLayout);
Value disRhs = getDistributed(rewriter, contractOp.getRhs(), rhsLayout);
Expand All @@ -830,18 +840,6 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
vector::ContractionOp localContractOp = doDistributedContraction(
rewriter, loc, ctx, contractOp, disLhs, disRhs, localInit);

SmallVector<int64_t> reducedDims;

// Identify the reduction dimension and apply it for subgroup reduction.
for (auto [index, iteratorType] :
llvm::enumerate(contractOp.getIteratorTypes())) {
if (vector::isReductionIterator(iteratorType)) {
auto map = contractOp.getIndexingMapsArray()[0];
int64_t redIdx = *(map.getResultPosition(getAffineDimExpr(index, ctx)));
reducedDims.push_back(redIdx);
}
}

VectorValue localContractValue;
if (accVector) {
localContractValue = dyn_cast<VectorValue>(localContractOp.getResult());
Expand All @@ -853,103 +851,61 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {

assert(localContractValue && "result should have been a vector");

NestedLayoutAttr resLayout;
if (auto contractRes = dyn_cast<VectorValue>(contractOp.getResult())) {
resLayout = dyn_cast<NestedLayoutAttr>(signature[contractRes]);
} else {
// Create a zero-d layout because we
// are going to add reduction dims
// back to handle the partial reduction
resLayout = NestedLayoutAttr::get(contractOp.getContext(), {}, {}, {}, {},
{}, {}, {});
}

// Shapecast to re-insert reduction dimensions as unit dims.
// We append the result shape with reduction dimensions as
// the fastest changing dimensions.
int64_t opRank = contractOp.getIteratorTypes().size();
SmallVector<int64_t> partialReducedDistributedShape(opRank * 3, 0);
SmallVector<int64_t> resDistShape = resLayout.getDistributedShape();
SmallVector<int64_t> lhsDistShape = lhsLayout.getDistributedShape();
for (int64_t tileGroupIdx : llvm::seq<int64_t>(3)) {
int64_t tileGroupOffset = tileGroupIdx * opRank;
for (int64_t dim : llvm::seq<int64_t>(opRank)) {
if (dim < resLayout.getRank()) {
int64_t resTileGroupOffset = tileGroupIdx * resLayout.getRank();
partialReducedDistributedShape[tileGroupOffset + dim] =
resDistShape[resTileGroupOffset + dim];
} else {
// add the thread local reduced dims at the tail.
partialReducedDistributedShape[tileGroupOffset + dim] = 1;
}
// Identify the reduction dimension and apply it for subgroup reduction.
auto lhsMap = contractOp.getIndexingMapsArray()[0];
SmallVector<int64_t> reductionSubGroupTile;
SmallVector<int64_t> reductionSubGroupStrides;
SmallVector<int64_t> reductionThreadTile;
SmallVector<int64_t> reductionThreadStrides;
SmallVector<int64_t> partialReductionDims;
for (auto [index, iteratorType] :
llvm::enumerate(contractOp.getIteratorTypes())) {
if (vector::isReductionIterator(iteratorType)) {
int64_t redLhsIdx =
*(lhsMap.getResultPosition(getAffineDimExpr(index, ctx)));
partialReductionDims.push_back(resLayout.getRank() +
reductionSubGroupTile.size());
reductionSubGroupTile.push_back(lhsLayout.getSubgroupTile()[redLhsIdx]);
reductionSubGroupStrides.push_back(
lhsLayout.getSubgroupStrides()[redLhsIdx]);
reductionThreadTile.push_back(lhsLayout.getThreadTile()[redLhsIdx]);
reductionThreadStrides.push_back(
lhsLayout.getThreadStrides()[redLhsIdx]);
}
}
SmallVector<int64_t> unitBroadcastTile(reductionThreadTile.size(), 1);

// Manually infer the layout of partial reduction
// We do this by appending the reduction dims on
// subgroup and thread tiles to the layout of the
// result.
IREE::VectorExt::NestedLayoutAttr reductionLayout =
IREE::VectorExt::NestedLayoutAttr::get(
contractOp.getContext(),
/*source=*/resLayout,
/*appendSubGroupLens=*/reductionSubGroupTile,
/*appendBatchLens=*/unitBroadcastTile,
/*appendOuterLens=*/unitBroadcastTile,
/*appendThreadLens=*/reductionThreadTile,
/*appendElementLens=*/unitBroadcastTile,
/*appendSubgroupStrides=*/reductionSubGroupStrides,
/*appendThreadStrides=*/reductionThreadStrides);

VectorType partialReducedDistributedType =
VectorType::get(partialReducedDistributedShape,
VectorType::get(reductionLayout.getDistributedShape(),
localContractValue.getType().getElementType());
Value isoRankLocalReduced = rewriter.create<vector::ShapeCastOp>(
Value shapeCasted = rewriter.create<vector::ShapeCastOp>(
loc, partialReducedDistributedType, localContractValue);

SmallVector<int64_t> partialReductionShape;
partialReductionShape.reserve(lhsLayout.getRank());
if (resVector) {
ArrayRef<int64_t> resVectorShape = resVector.getType().getShape();
partialReductionShape.insert(partialReductionShape.end(),
resVectorShape.begin(),
resVectorShape.end());
}
ArrayRef<int64_t> lhsShape = contractOp.getLhs().getType().getShape();
// Note that partial reduction dimensions are at the end.
SmallVector<int64_t> partialReduceDims;
partialReduceDims.reserve(reducedDims.size());
for (int64_t rDim : reducedDims) {
partialReduceDims.push_back(partialReductionShape.size());
partialReductionShape.push_back(lhsShape[rDim]);
}
VectorType unDistributedType = VectorType::get(
partialReductionShape, localContractValue.getType().getElementType());
VectorType unDistributedType =
VectorType::get(reductionLayout.getUndistributedShape(),
localContractValue.getType().getElementType());
Value undistrLocalReduced = rewriter.create<IREE::VectorExt::ToSIMDOp>(
loc, unDistributedType, isoRankLocalReduced);

// Manually infer the layout of partial reduction
IREE::VectorExt::NestedLayoutAttr reductionLayout;
{
SmallVector<int64_t> subgroupTileLens =
llvm::to_vector(resLayout.getSubgroupTile());
SmallVector<int64_t> batchTileLens =
llvm::to_vector(resLayout.getBatchTile());
SmallVector<int64_t> outerTileLens =
llvm::to_vector(resLayout.getOuterTile());
SmallVector<int64_t> threadTileLens =
llvm::to_vector(resLayout.getThreadTile());
SmallVector<int64_t> elementTileLens =
llvm::to_vector(resLayout.getElementTile());
SmallVector<int64_t> subgroupStrides =
llvm::to_vector(resLayout.getSubgroupStrides());
SmallVector<int64_t> threadStrides =
llvm::to_vector(resLayout.getThreadStrides());
for (int64_t rDim : reducedDims) {
// thread-local reductions have already been carried out.
// What is remaining is reductions across threads and
// subgroups.
subgroupTileLens.push_back(lhsLayout.getSubgroupTile()[rDim]);
batchTileLens.push_back(1);
outerTileLens.push_back(1);
threadTileLens.push_back(lhsLayout.getThreadTile()[rDim]);
elementTileLens.push_back(1);
subgroupStrides.push_back(lhsLayout.getSubgroupStrides()[rDim]);
threadStrides.push_back(lhsLayout.getThreadStrides()[rDim]);
}
reductionLayout = IREE::VectorExt::NestedLayoutAttr::get(
contractOp.getContext(), subgroupTileLens, batchTileLens,
outerTileLens, threadTileLens, elementTileLens, subgroupStrides,
threadStrides);
}
loc, unDistributedType, shapeCasted);

// Create the partial reduction
auto partialReduction = rewriter.create<vector::MultiDimReductionOp>(
loc, contractOp.getKind(), undistrLocalReduced, acc, partialReduceDims);
loc, contractOp.getKind(), undistrLocalReduced, acc,
partialReductionDims);
{
auto unitAttr = UnitAttr::get(rewriter.getContext());
auto reduceAttrs =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ SmallVector<int64_t> NestedLayoutAttr::getUndistributedPackedShape() const {
return shape;
}

/// Computes the shape of the undistributed vector described by this layout.
///
/// For every dimension `d` in `[0, getRank())` the returned extent is the
/// product of the subgroup, batch, outer, thread and element tile sizes at
/// index `d`.
SmallVector<int64_t> NestedLayoutAttr::getUndistributedShape() const {
  // Fetch each per-level tile list once rather than on every iteration.
  ArrayRef<int64_t> subgroupLens = getSubgroupTile();
  ArrayRef<int64_t> batchLens = getBatchTile();
  ArrayRef<int64_t> outerLens = getOuterTile();
  ArrayRef<int64_t> threadLens = getThreadTile();
  ArrayRef<int64_t> elementLens = getElementTile();

  const int64_t numDims = getRank();
  SmallVector<int64_t> undistributedShape;
  undistributedShape.reserve(numDims);
  for (int64_t dim = 0; dim < numDims; ++dim) {
    // The full extent of a dimension is the product of all five nesting
    // levels for that dimension.
    undistributedShape.push_back(subgroupLens[dim] * batchLens[dim] *
                                 outerLens[dim] * threadLens[dim] *
                                 elementLens[dim]);
  }
  return undistributedShape;
}

// Gets the rank of the undistributed vector for this layout.
int64_t NestedLayoutAttr::getRank() const {
// The layout requires that all size lists are the same length and match
Expand Down Expand Up @@ -198,6 +211,42 @@ NestedLayoutAttr NestedLayoutAttr::get(
normalizedThreadStrides);
}

/// Returns a new vector holding the entries of `tileLens` followed by the
/// entries of `appendLens`. Neither input is modified.
static SmallVector<int64_t> appendDims(ArrayRef<int64_t> tileLens,
                                       ArrayRef<int64_t> appendLens) {
  SmallVector<int64_t> combined;
  // Size the buffer up front so both appends land in a single allocation.
  combined.reserve(tileLens.size() + appendLens.size());
  combined.append(tileLens.begin(), tileLens.end());
  combined.append(appendLens.begin(), appendLens.end());
  return combined;
}

/// Builds a layout by extending `source` with additional trailing entries.
///
/// Each `append*` list is concatenated after the matching tile-length or
/// stride list read from `source`; the seven resulting lists are then passed
/// to the `NestedLayoutAttr::get` overload that takes the complete lists.
NestedLayoutAttr NestedLayoutAttr::get(MLIRContext *context,
                                       NestedLayoutAttr source,
                                       ArrayRef<int64_t> appendSubGroupLens,
                                       ArrayRef<int64_t> appendBatchLens,
                                       ArrayRef<int64_t> appendOuterLens,
                                       ArrayRef<int64_t> appendThreadLens,
                                       ArrayRef<int64_t> appendElementLens,
                                       ArrayRef<int64_t> appendSubgroupStrides,
                                       ArrayRef<int64_t> appendThreadStrides) {
  // The temporaries produced by appendDims live until the end of the full
  // call expression, so they can be handed directly to the delegated builder.
  return NestedLayoutAttr::get(
      context,
      /*subgroupTile=*/
      appendDims(source.getSubgroupTile(), appendSubGroupLens),
      /*batchTile=*/
      appendDims(source.getBatchTile(), appendBatchLens),
      /*outerTile=*/
      appendDims(source.getOuterTile(), appendOuterLens),
      /*threadTile=*/
      appendDims(source.getThreadTile(), appendThreadLens),
      /*elementTile=*/
      appendDims(source.getElementTile(), appendElementLens),
      /*subgroupStrides=*/
      appendDims(source.getSubgroupStrides(), appendSubgroupStrides),
      /*threadStrides=*/
      appendDims(source.getThreadStrides(), appendThreadStrides));
}

LogicalResult NestedLayoutAttr::verify(
llvm::function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> subgroupTile, ArrayRef<int64_t> batchTile,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,15 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
"ArrayRef<int64_t>":$threadTile,
"ArrayRef<int64_t>":$elementTile,
"ArrayRef<int64_t>":$subgroupStrides,
"ArrayRef<int64_t>":$threadStrides)>
"ArrayRef<int64_t>":$threadStrides)>,
AttrBuilder<(ins "NestedLayoutAttr":$source,
"ArrayRef<int64_t>":$appendSubGroupLens,
"ArrayRef<int64_t>":$appendBatchLens,
"ArrayRef<int64_t>":$appendOuterLens,
"ArrayRef<int64_t>":$appendThreadLens,
"ArrayRef<int64_t>":$appendElementLens,
"ArrayRef<int64_t>":$appendSubgroupStrides,
"ArrayRef<int64_t>":$appendThreadStrides)>
];

let extraClassDeclaration = [{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def VectorLayoutInterface : AttrInterface<"VectorLayoutInterface"> {
/*methodName=*/"project",
/*args=*/(ins "::llvm::ArrayRef<bool>":$droppedDims)
>,
InterfaceMethod<
/*description=*/"Get the expected undistributed shape for the given vector type.",
/*retTy=*/"SmallVector<int64_t>",
/*methodName=*/"getUndistributedShape",
/*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Get the distributed shape for the given vector type.",
/*retTy=*/"SmallVector<int64_t>",
Expand Down

0 comments on commit 4f75304

Please sign in to comment.