17
17
18
18
namespace mlir ::iree_compiler::AMDAIE {
19
19
20
- namespace {
20
+ namespace detail {
21
21
22
- // / Update the strides and offsets of `X` to match the strides of `Y` if it is
23
- // / possible to do so without changing the underlying access pattern of `X`. For
24
- // / example if
25
- // /
26
- // / X has access pattern (offset: [5] sizes: [1] strides: [6]) and
27
- // / Y has access pattern (offset: [a] sizes: [b] strides: [3])
28
- // /
29
- // / Then the access pattern for X can be changed to have access pattern
30
- // / (offset: [10] sizes: [1] strides: [3]) so that its stride matches Y's.
31
- // /
32
- // / For this transformation to be possible in dimension `d` is it necessary that
33
- // /
34
- // / 1) the size in dimension `d` of `X` is 1, and
35
- // / 2) the updated offset in `d` of `X` (i.e. offset * strideX / strideY)
36
- // / is an integer.
37
- // /
38
- // / As another example, if we have:
39
- // /
40
- // / X with access pattern (offset: [4, 8] sizes: [1, 1] strides: [3, 8])
41
- // / Y with access pattern (offset: [a, b] sizes: [d, e] strides: [6, 2])
42
- // /
43
- // / then X can be transformed to have access pattern
44
- // / (offset: [2, 32] sizes: [1, 1] strides: [6, 2])
45
22
void matchStridesOfUnitDims (MLIRContext *ctx, ArrayRef<OpFoldResult> sizesX,
46
23
SmallVector<OpFoldResult> &stridesX,
47
24
SmallVector<OpFoldResult> &offsetsX,
@@ -70,33 +47,6 @@ void matchStridesOfUnitDims(MLIRContext *ctx, ArrayRef<OpFoldResult> sizesX,
70
47
}
71
48
}
72
49
73
- // / This function computes the difference between the global offsets of two
74
- // / access patterns. If it is not constant, i.e. if the difference contains
75
- // / an MLIR value which is not a constant, then nullopt is returned.
76
- // /
77
- // / This function is useful when determining if the access pattern A, followed
78
- // / by the access pattern B, can be merged into a single access pattern.
79
- // /
80
- // / \return global_offset(X) - global_offset(Y).
81
- // /
82
- // / Background info: offsets, sizes, and strides define an access pattern into
83
- // / an array, where the i'th element accessed, for 0 <= i < prod_{d<D} sizes[d],
84
- // / is at index
85
- // /
86
- // / sum_{d<D} (l(d,i) + offset[d]) * stride[d] (1)
87
- // /
88
- // / where l(d,i) is the component of global index `i` in dimension `d`:
89
- // /
90
- // / i = sum_{d<D} l(d,i) * size[d] (2)
91
- // /
92
- // / Equation (1) can be rewritten with a global offset as
93
- // /
94
- // / global_offset + sum_{d<D} l(d,i) * stride[d] (3)
95
- // /
96
- // / where the global offset is
97
- // /
98
- // / global_offset = sum_{d<D} offset[d] * stride[d].
99
- // /
100
50
std::optional<int64_t > getGlobalOffsetDifference (
101
51
ArrayRef<OpFoldResult> offsetsX, ArrayRef<OpFoldResult> stridesX,
102
52
ArrayRef<OpFoldResult> offsetsY, ArrayRef<OpFoldResult> stridesY) {
@@ -108,9 +58,23 @@ std::optional<int64_t> getGlobalOffsetDifference(
108
58
" expected same number of offsets for X and Y" );
109
59
110
60
int64_t globalOffsetDifference{0 };
61
+
62
+ // In this function we're computing the constant globalOffsetDifference:
63
+ //
64
+ // sum_{d} offsetsA[d] * stridesA[d] -
65
+ // sum_{d} offsetsB[d] * stridesB[d] .
66
+ //
67
+ // If all values in offsetsA, offsetsB, stridesA, stridesB are constant,
68
+ // this is straightforward. If not, we need all the non-constant terms to
69
+ // cancel. In the maps below, we store the terms with non-constants, and then
70
+ // check that they've all cancelled at the end. In `valToConst` we store terms
71
+ // where one of offset and stride is constant, and the other is not. In
72
+ // valPairs, we keep track of all the terms where neither the stride nor the
73
+ // offset is constant.
111
74
DenseMap<Value, int64_t > valToConst;
75
+ DenseMap<std::pair<Value, Value>, int64_t > valPairs;
112
76
113
- auto increment = [&](Value v, int64_t signedStride) {
77
+ auto incrementValConst = [&](Value v, int64_t signedStride) {
114
78
auto iter = valToConst.find (v);
115
79
if (iter == valToConst.end ()) {
116
80
valToConst[v] = signedStride;
@@ -119,43 +83,65 @@ std::optional<int64_t> getGlobalOffsetDifference(
119
83
}
120
84
};
121
85
86
+ auto incrementValVal = [&](Value v0, Value v1, int64_t sign) {
87
+ std::pair<Value, Value> p0 (v0, v1);
88
+ auto iter0 = valPairs.find (p0);
89
+ if (iter0 != valPairs.end ()) {
90
+ iter0->second += sign;
91
+ return ;
92
+ }
93
+
94
+ std::pair<Value, Value> p1 (v1, v0);
95
+ auto iter1 = valPairs.find (p1);
96
+ if (iter1 != valPairs.end ()) {
97
+ iter1->second += sign;
98
+ return ;
99
+ }
100
+ valPairs.insert ({p0, sign});
101
+ };
102
+
103
+ // Add the term `offset * stride * sign` to the global offset different,
104
+ // triaging the different combinations of constant/non-constant.
122
105
auto updateGlobalOffsetDifference = [&](OpFoldResult offset,
123
106
OpFoldResult stride, int64_t sign) {
124
- std::optional<int64_t > o = getConstantIntValue (offset);
125
- std::optional<int64_t > s = getConstantIntValue (stride);
126
- if (!o.has_value () && !s.has_value ()) {
127
- // The case where both the stride and offset are non-constant can be
128
- // handled, but it'll add more complexity so I'm ignoring for now.
129
- return false ;
130
- } else if (o.has_value () && s.has_value ()) {
131
- globalOffsetDifference += sign * o.value () * s.value ();
132
- } else if (o.has_value ()) {
133
- increment (cast<Value>(stride), sign * o.value ());
134
- } else if (s.has_value ()) {
135
- increment (cast<Value>(offset), sign * s.value ());
107
+ std::optional<int64_t > cOffset = getConstantIntValue (offset);
108
+ std::optional<int64_t > cStride = getConstantIntValue (stride);
109
+ Value vOffset = dyn_cast<Value>(offset);
110
+ Value vStride = dyn_cast<Value>(stride);
111
+
112
+ if (!cOffset.has_value () && !cStride.has_value ()) {
113
+ incrementValVal (vOffset, vStride, sign);
114
+ } else if (cOffset.has_value () && cStride.has_value ()) {
115
+ globalOffsetDifference += sign * cOffset.value () * cStride.value ();
116
+ } else if (cOffset.has_value ()) {
117
+ incrementValConst (cast<Value>(stride), sign * cOffset.value ());
118
+ } else if (cStride.has_value ()) {
119
+ incrementValConst (cast<Value>(offset), sign * cStride.value ());
136
120
}
137
- return true ;
138
121
};
139
122
140
123
for (uint32_t i = 0 ; i < offsetsX.size (); ++i) {
141
124
// If offsets and strides are the same, the contribution to the global
142
125
// offset difference is zero, so we can skip this dimension.
143
126
if (offsetsX[i] == offsetsY[i] && stridesX[i] == stridesY[i]) continue ;
144
-
145
- if (updateGlobalOffsetDifference (offsetsX[i], stridesX[i], 1 ) == false )
146
- return std::nullopt;
147
- if (updateGlobalOffsetDifference (offsetsY[i], stridesY[i], -1 ) == false )
148
- return std::nullopt;
127
+ updateGlobalOffsetDifference (offsetsX[i], stridesX[i], 1 );
128
+ updateGlobalOffsetDifference (offsetsY[i], stridesY[i], -1 );
149
129
}
150
130
131
+ // The cases where the non-constant terms did not all cancel, and so the
132
+ // global offset difference could not be determined to be constant.
151
133
for (auto [offset, stride] : valToConst) {
152
- // There is a non-constant offset with a stride that is not zero.
153
- // This means that the global offset difference is not a constant.
154
134
if (stride != 0 ) return std::nullopt;
155
135
}
136
+ for (auto [valPair, valPairCount] : valPairs) {
137
+ if (valPairCount != 0 ) return std::nullopt;
138
+ }
156
139
157
140
return globalOffsetDifference;
158
141
}
142
+ } // namespace detail
143
+
144
+ namespace {
159
145
160
146
// / Consider 2 access patterns X and Y, where the access pattern for Y has one
161
147
// / more dimension than the access pattern for X. This function inserts a
@@ -210,8 +196,8 @@ bool mergeInFirst(MLIRContext *ctx, SmallVector<OpFoldResult> &offsetsA,
210
196
// dimensions of size 1, which is being ignored for now).
211
197
return false ;
212
198
}
213
- matchStridesOfUnitDims (ctx, sizesA, stridesA, offsetsA, stridesB);
214
- matchStridesOfUnitDims (ctx, sizesB, stridesB, offsetsB, stridesA);
199
+ detail:: matchStridesOfUnitDims (ctx, sizesA, stridesA, offsetsA, stridesB);
200
+ detail:: matchStridesOfUnitDims (ctx, sizesB, stridesB, offsetsB, stridesA);
215
201
216
202
// Check that strides and sizes are compatible for merging.
217
203
if (stridesA != stridesB) return false ;
@@ -221,7 +207,7 @@ bool mergeInFirst(MLIRContext *ctx, SmallVector<OpFoldResult> &offsetsA,
221
207
}
222
208
223
209
std::optional<int64_t > maybeOffsetDifference =
224
- getGlobalOffsetDifference (offsetsB, stridesB, offsetsA, stridesA);
210
+ detail:: getGlobalOffsetDifference (offsetsB, stridesB, offsetsA, stridesA);
225
211
226
212
// The case where the global offset difference is not constant is difficult to
227
213
// handle, unless we can prove that it is non-negative. Leaving this edge case
@@ -278,6 +264,8 @@ LogicalResult combineAccessPatterns(
278
264
assert (offsetsB.size () == stridesB.size () &&
279
265
" expected same number of source offsets and strides" );
280
266
267
+ // Ensure that OpFoldResults are Attributes when they can be. Specifally
268
+ // this will replace arith.constant values with attributes.
281
269
auto simplified =
282
270
[&](ArrayRef<OpFoldResult> input) -> SmallVector<OpFoldResult> {
283
271
SmallVector<OpFoldResult> x (input.begin (), input.end ());
0 commit comments