Skip to content

Commit 68f4973

Browse files
committed
improve testing
1 parent 2088904 commit 68f4973

File tree

4 files changed

+581
-447
lines changed

4 files changed

+581
-447
lines changed

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.cpp

Lines changed: 63 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,8 @@
1717

1818
namespace mlir::iree_compiler::AMDAIE {
1919

20-
namespace {
20+
namespace detail {
2121

22-
/// Update the strides and offsets of `X` to match the strides of `Y` if it is
23-
/// possible to do so without changing the underlying access pattern of `X`. For
24-
/// example if
25-
///
26-
/// X has access pattern (offset: [5] sizes: [1] strides: [6]) and
27-
/// Y has access pattern (offset: [a] sizes: [b] strides: [3])
28-
///
29-
/// Then the access pattern for X can be changed to have access pattern
30-
/// (offset: [10] sizes: [1] strides: [3]) so that its stride matches Y's.
31-
///
32-
/// For this transformation to be possible in dimension `d` is it necessary that
33-
///
34-
/// 1) the size in dimension `d` of `X` is 1, and
35-
/// 2) the updated offset in `d` of `X` (i.e. offset * strideX / strideY)
36-
/// is an integer.
37-
///
38-
/// As another example, if we have:
39-
///
40-
/// X with access pattern (offset: [4, 8] sizes: [1, 1] strides: [3, 8])
41-
/// Y with access pattern (offset: [a, b] sizes: [d, e] strides: [6, 2])
42-
///
43-
/// then X can be transformed to have access pattern
44-
/// (offset: [2, 32] sizes: [1, 1] strides: [6, 2])
4522
void matchStridesOfUnitDims(MLIRContext *ctx, ArrayRef<OpFoldResult> sizesX,
4623
SmallVector<OpFoldResult> &stridesX,
4724
SmallVector<OpFoldResult> &offsetsX,
@@ -70,33 +47,6 @@ void matchStridesOfUnitDims(MLIRContext *ctx, ArrayRef<OpFoldResult> sizesX,
7047
}
7148
}
7249

73-
/// This function computes the difference between the global offsets of two
74-
/// access patterns. If it is not constant, i.e. if the difference contains
75-
/// an MLIR value which is not a constant, then nullopt is returned.
76-
///
77-
/// This function is useful when determining if the access pattern A, followed
78-
/// by the access pattern B, can be merged into a single access pattern.
79-
///
80-
/// \return global_offset(X) - global_offset(Y).
81-
///
82-
/// Background info: offsets, sizes, and strides define an access pattern into
83-
/// an array, where the i'th element accessed, for 0 <= i < prod_{d<D} sizes[d],
84-
/// is at index
85-
///
86-
/// sum_{d<D} (l(d,i) + offset[d]) * stride[d] (1)
87-
///
88-
/// where l(d,i) is the component of global index `i` in dimension `d`:
89-
///
90-
/// i = sum_{d<D} l(d,i) * size[d] (2)
91-
///
92-
/// Equation (1) can be rewritten with a global offset as
93-
///
94-
/// global_offset + sum_{d<D} l(d,i) * stride[d] (3)
95-
///
96-
/// where the global offset is
97-
///
98-
/// global_offset = sum_{d<D} offset[d] * stride[d].
99-
///
10050
std::optional<int64_t> getGlobalOffsetDifference(
10151
ArrayRef<OpFoldResult> offsetsX, ArrayRef<OpFoldResult> stridesX,
10252
ArrayRef<OpFoldResult> offsetsY, ArrayRef<OpFoldResult> stridesY) {
@@ -108,9 +58,23 @@ std::optional<int64_t> getGlobalOffsetDifference(
10858
"expected same number of offsets for X and Y");
10959

11060
int64_t globalOffsetDifference{0};
61+
62+
// In this function we're computing the constant globalOffsetDifference:
63+
//
64+
// sum_{d} offsetsA[d] * stridesA[d] -
65+
// sum_{d} offsetsB[d] * stridesB[d] .
66+
//
67+
// If all values in offsetsA, offsetsB, stridesA, stridesB are constant,
68+
// this is straightforward. If not, we need all the non-constant terms to
69+
// cancel. In the maps below, we store the terms with non-constants, and then
70+
// check that they've all cancelled at the end. In `valToConst` we store terms
71+
// where one of offset and stride is constant, and the other is not. In
72+
// valPairs, we keep track of all the terms where neither the stride nor the
73+
// offset is constant.
11174
DenseMap<Value, int64_t> valToConst;
75+
DenseMap<std::pair<Value, Value>, int64_t> valPairs;
11276

113-
auto increment = [&](Value v, int64_t signedStride) {
77+
auto incrementValConst = [&](Value v, int64_t signedStride) {
11478
auto iter = valToConst.find(v);
11579
if (iter == valToConst.end()) {
11680
valToConst[v] = signedStride;
@@ -119,43 +83,65 @@ std::optional<int64_t> getGlobalOffsetDifference(
11983
}
12084
};
12185

86+
auto incrementValVal = [&](Value v0, Value v1, int64_t sign) {
87+
std::pair<Value, Value> p0(v0, v1);
88+
auto iter0 = valPairs.find(p0);
89+
if (iter0 != valPairs.end()) {
90+
iter0->second += sign;
91+
return;
92+
}
93+
94+
std::pair<Value, Value> p1(v1, v0);
95+
auto iter1 = valPairs.find(p1);
96+
if (iter1 != valPairs.end()) {
97+
iter1->second += sign;
98+
return;
99+
}
100+
valPairs.insert({p0, sign});
101+
};
102+
103+
// Add the term `offset * stride * sign` to the global offset different,
104+
// triaging the different combinations of constant/non-constant.
122105
auto updateGlobalOffsetDifference = [&](OpFoldResult offset,
123106
OpFoldResult stride, int64_t sign) {
124-
std::optional<int64_t> o = getConstantIntValue(offset);
125-
std::optional<int64_t> s = getConstantIntValue(stride);
126-
if (!o.has_value() && !s.has_value()) {
127-
// The case where both the stride and offset are non-constant can be
128-
// handled, but it'll add more complexity so I'm ignoring for now.
129-
return false;
130-
} else if (o.has_value() && s.has_value()) {
131-
globalOffsetDifference += sign * o.value() * s.value();
132-
} else if (o.has_value()) {
133-
increment(cast<Value>(stride), sign * o.value());
134-
} else if (s.has_value()) {
135-
increment(cast<Value>(offset), sign * s.value());
107+
std::optional<int64_t> cOffset = getConstantIntValue(offset);
108+
std::optional<int64_t> cStride = getConstantIntValue(stride);
109+
Value vOffset = dyn_cast<Value>(offset);
110+
Value vStride = dyn_cast<Value>(stride);
111+
112+
if (!cOffset.has_value() && !cStride.has_value()) {
113+
incrementValVal(vOffset, vStride, sign);
114+
} else if (cOffset.has_value() && cStride.has_value()) {
115+
globalOffsetDifference += sign * cOffset.value() * cStride.value();
116+
} else if (cOffset.has_value()) {
117+
incrementValConst(cast<Value>(stride), sign * cOffset.value());
118+
} else if (cStride.has_value()) {
119+
incrementValConst(cast<Value>(offset), sign * cStride.value());
136120
}
137-
return true;
138121
};
139122

140123
for (uint32_t i = 0; i < offsetsX.size(); ++i) {
141124
// If offsets and strides are the same, the contribution to the global
142125
// offset difference is zero, so we can skip this dimension.
143126
if (offsetsX[i] == offsetsY[i] && stridesX[i] == stridesY[i]) continue;
144-
145-
if (updateGlobalOffsetDifference(offsetsX[i], stridesX[i], 1) == false)
146-
return std::nullopt;
147-
if (updateGlobalOffsetDifference(offsetsY[i], stridesY[i], -1) == false)
148-
return std::nullopt;
127+
updateGlobalOffsetDifference(offsetsX[i], stridesX[i], 1);
128+
updateGlobalOffsetDifference(offsetsY[i], stridesY[i], -1);
149129
}
150130

131+
// The cases where the non-constant terms did not all cancel, and so the
132+
// global offset difference could not be determined to be constant.
151133
for (auto [offset, stride] : valToConst) {
152-
// There is a non-constant offset with a stride that is not zero.
153-
// This means that the global offset difference is not a constant.
154134
if (stride != 0) return std::nullopt;
155135
}
136+
for (auto [valPair, valPairCount] : valPairs) {
137+
if (valPairCount != 0) return std::nullopt;
138+
}
156139

157140
return globalOffsetDifference;
158141
}
142+
} // namespace detail
143+
144+
namespace {
159145

160146
/// Consider 2 access patterns X and Y, where the access pattern for Y has one
161147
/// more dimension than the access pattern for X. This function inserts a
@@ -210,8 +196,8 @@ bool mergeInFirst(MLIRContext *ctx, SmallVector<OpFoldResult> &offsetsA,
210196
// dimensions of size 1, which is being ignored for now).
211197
return false;
212198
}
213-
matchStridesOfUnitDims(ctx, sizesA, stridesA, offsetsA, stridesB);
214-
matchStridesOfUnitDims(ctx, sizesB, stridesB, offsetsB, stridesA);
199+
detail::matchStridesOfUnitDims(ctx, sizesA, stridesA, offsetsA, stridesB);
200+
detail::matchStridesOfUnitDims(ctx, sizesB, stridesB, offsetsB, stridesA);
215201

216202
// Check that strides and sizes are compatible for merging.
217203
if (stridesA != stridesB) return false;
@@ -221,7 +207,7 @@ bool mergeInFirst(MLIRContext *ctx, SmallVector<OpFoldResult> &offsetsA,
221207
}
222208

223209
std::optional<int64_t> maybeOffsetDifference =
224-
getGlobalOffsetDifference(offsetsB, stridesB, offsetsA, stridesA);
210+
detail::getGlobalOffsetDifference(offsetsB, stridesB, offsetsA, stridesA);
225211

226212
// The case where the global offset difference is not constant is difficult to
227213
// handle, unless we can prove that it is non-negative. Leaving this edge case
@@ -278,6 +264,8 @@ LogicalResult combineAccessPatterns(
278264
assert(offsetsB.size() == stridesB.size() &&
279265
"expected same number of source offsets and strides");
280266

267+
// Ensure that OpFoldResults are Attributes when they can be. Specifally
268+
// this will replace arith.constant values with attributes.
281269
auto simplified =
282270
[&](ArrayRef<OpFoldResult> input) -> SmallVector<OpFoldResult> {
283271
SmallVector<OpFoldResult> x(input.begin(), input.end());

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,68 @@ struct RetrieveScaleAndBias
6969
}
7070
};
7171

72-
/// Check whether two access patterns are equal in value, starting from
73-
/// specified indices.
74-
bool areAccessPatternsEqualFromIndices(ArrayRef<OpFoldResult> offsetsA,
75-
ArrayRef<OpFoldResult> sizesA,
76-
ArrayRef<OpFoldResult> stridesA,
77-
ArrayRef<OpFoldResult> offsetsB,
78-
ArrayRef<OpFoldResult> sizesB,
79-
ArrayRef<OpFoldResult> stridesB,
80-
size_t indexA = 0, size_t indexB = 0);
72+
namespace detail {
73+
74+
/// Update the strides and offsets of `X` to match the strides of `Y` if it is
75+
/// possible to do so without changing the underlying access pattern of `X`. For
76+
/// example if
77+
///
78+
/// X has access pattern (offset: [5] sizes: [1] strides: [6]) and
79+
/// Y has access pattern (offset: [a] sizes: [b] strides: [3])
80+
///
81+
/// Then the access pattern for X can be changed to have access pattern
82+
/// (offset: [10] sizes: [1] strides: [3]) so that its stride matches Y's.
83+
///
84+
/// For this transformation to be possible in dimension `d` is it necessary that
85+
///
86+
/// 1) the size in dimension `d` of `X` is 1, and
87+
/// 2) the updated offset in `d` of `X` (i.e. offset * strideX / strideY)
88+
/// is an integer.
89+
///
90+
/// As another example, if we have:
91+
///
92+
/// X with access pattern (offset: [4, 8] sizes: [1, 1] strides: [3, 8])
93+
/// Y with access pattern (offset: [a, b] sizes: [d, e] strides: [6, 2])
94+
///
95+
/// then X can be transformed to have access pattern
96+
/// (offset: [2, 32] sizes: [1, 1] strides: [6, 2])
97+
void matchStridesOfUnitDims(MLIRContext *ctx, ArrayRef<OpFoldResult> sizesX,
98+
SmallVector<OpFoldResult> &stridesX,
99+
SmallVector<OpFoldResult> &offsetsX,
100+
ArrayRef<OpFoldResult> stridesY);
101+
102+
/// This function computes the difference between the global offsets of two
103+
/// access patterns. If it is not constant, i.e. if the difference contains
104+
/// an MLIR value which is not a constant, then nullopt is returned.
105+
///
106+
/// This function is useful when determining if the access pattern A, followed
107+
/// by the access pattern B, can be merged into a single access pattern.
108+
///
109+
/// \return global_offset(X) - global_offset(Y).
110+
///
111+
/// Background info: offsets, sizes, and strides define an access pattern into
112+
/// an array, where the i'th element accessed, for 0 <= i < prod_{d<D} sizes[d],
113+
/// is at index
114+
///
115+
/// sum_{d<D} (l(d,i) + offset[d]) * stride[d] (1)
116+
///
117+
/// where l(d,i) is the component of global index `i` in dimension `d`:
118+
///
119+
/// i = sum_{d<D} l(d,i) * size[d] (2)
120+
///
121+
/// Equation (1) can be rewritten with a global offset as
122+
///
123+
/// global_offset + sum_{d<D} l(d,i) * stride[d] (3)
124+
///
125+
/// where the global offset is
126+
///
127+
/// global_offset = sum_{d<D} offset[d] * stride[d].
128+
///
129+
std::optional<int64_t> getGlobalOffsetDifference(
130+
ArrayRef<OpFoldResult> offsetsX, ArrayRef<OpFoldResult> stridesX,
131+
ArrayRef<OpFoldResult> offsetsY, ArrayRef<OpFoldResult> stridesY);
132+
133+
} // namespace detail
81134

82135
/// Combine two access patterns into a single one. Assumes that access pattern A
83136
/// belongs to a strided op which is ordered before the strided op B. Takes a

0 commit comments

Comments
 (0)