Skip to content

Commit

Permalink
additional tests, cosmetics
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Jan 24, 2025
1 parent 09f350d commit 55cab2a
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

namespace mlir::iree_compiler::AMDAIE {

// Functions in namespace `detail` are not intended to be used outside of this
// file, but are exposed in the .h file for testing purposes.
namespace detail {

void matchStridesOfUnitDims(MLIRContext *ctx, ArrayRef<OpFoldResult> sizesX,
Expand Down Expand Up @@ -57,8 +59,6 @@ std::optional<int64_t> getGlobalOffsetDifference(
assert(offsetsX.size() == offsetsY.size() &&
"expected same number of offsets for X and Y");

int64_t globalOffsetDifference{0};

// In this function we're computing the constant globalOffsetDifference:
//
// sum_{d} offsetsA[d] * stridesA[d] -
Expand All @@ -67,14 +67,12 @@ std::optional<int64_t> getGlobalOffsetDifference(
// If all values in offsetsA, offsetsB, stridesA, stridesB are constant,
// this is straightforward. If not, we need all the non-constant terms to
// cancel. In the maps below, we store the terms with non-constants, and then
// check that they've all cancelled at the end. In `valToConst` we store terms
// where one of offset and stride is constant, and the other is not. In
// valPairs, we keep track of all the terms where neither the stride nor the
// offset is constant.
// check at the end of this function that they've all cancelled. In
// `valToConst` we store terms where one of offset and stride is constant, and
// the other is not. In `valPairs`, we keep track of all the terms where
// neither the stride nor the offset is constant.
DenseMap<Value, int64_t> valToConst;
DenseMap<std::pair<Value, Value>, int64_t> valPairs;

auto incrementValConst = [&](Value v, int64_t signedStride) {
auto incrementValToConst = [&](Value v, int64_t signedStride) {
auto iter = valToConst.find(v);
if (iter == valToConst.end()) {
valToConst[v] = signedStride;
Expand All @@ -83,14 +81,14 @@ std::optional<int64_t> getGlobalOffsetDifference(
}
};

auto incrementValVal = [&](Value v0, Value v1, int64_t sign) {
DenseMap<std::pair<Value, Value>, int64_t> valPairs;
auto incrementValPairs = [&](Value v0, Value v1, int64_t sign) {
std::pair<Value, Value> p0(v0, v1);
auto iter0 = valPairs.find(p0);
if (iter0 != valPairs.end()) {
iter0->second += sign;
return;
}

std::pair<Value, Value> p1(v1, v0);
auto iter1 = valPairs.find(p1);
if (iter1 != valPairs.end()) {
Expand All @@ -100,6 +98,8 @@ std::optional<int64_t> getGlobalOffsetDifference(
valPairs.insert({p0, sign});
};

int64_t globalOffsetDifference{0};

// Add the term `offset * stride * sign` to the global offset difference,
// triaging the different combinations of constant/non-constant.
auto updateGlobalOffsetDifference = [&](OpFoldResult offset,
Expand All @@ -110,31 +110,28 @@ std::optional<int64_t> getGlobalOffsetDifference(
Value vStride = dyn_cast<Value>(stride);

if (!cOffset.has_value() && !cStride.has_value()) {
incrementValVal(vOffset, vStride, sign);
incrementValPairs(vOffset, vStride, sign);
} else if (cOffset.has_value() && cStride.has_value()) {
globalOffsetDifference += sign * cOffset.value() * cStride.value();
} else if (cOffset.has_value()) {
incrementValConst(cast<Value>(stride), sign * cOffset.value());
incrementValToConst(cast<Value>(stride), sign * cOffset.value());
} else if (cStride.has_value()) {
incrementValConst(cast<Value>(offset), sign * cStride.value());
incrementValToConst(cast<Value>(offset), sign * cStride.value());
}
};

for (uint32_t i = 0; i < offsetsX.size(); ++i) {
// If offsets and strides are the same, the contribution to the global
// offset difference is zero, so we can skip this dimension.
if (offsetsX[i] == offsetsY[i] && stridesX[i] == stridesY[i]) continue;
updateGlobalOffsetDifference(offsetsX[i], stridesX[i], 1);
updateGlobalOffsetDifference(offsetsY[i], stridesY[i], -1);
}

// If the non-constant terms did not all cancel, the global offset difference
// cannot be determined to be constant.
for (auto [offset, stride] : valToConst) {
if (stride != 0) return std::nullopt;
if (llvm::any_of(valToConst, [](auto x) { return x.second != 0; })) {
return std::nullopt;
}
for (auto [valPair, valPairCount] : valPairs) {
if (valPairCount != 0) return std::nullopt;
if (llvm::any_of(valPairs, [](auto x) { return x.second != 0; })) {
return std::nullopt;
}

return globalOffsetDifference;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ void matchStridesOfUnitDims(MLIRContext *ctx, ArrayRef<OpFoldResult> sizesX,
/// This function is useful when determining if the access pattern A, followed
/// by the access pattern B, can be merged into a single access pattern.
///
/// The reason for this API design, as opposed to the more intuitive design of
/// having a function that computes the global offset of a single access
/// pattern, say `getGlobalOffset`, and then computing the difference as
/// `getGlobalOffset(A) - getGlobalOffset(B)`, is that the global offset
/// difference might be constant even when the individual global offsets are
/// not, and determining that the non-constant terms cancel is easier with this
/// API.
///
/// \return global_offset(X) - global_offset(Y).
///
/// Background info: offsets, sizes, and strides define an access pattern into
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -686,9 +686,10 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}


// CHECK-LABEL: @with_complex_variable_offset
// CHECK: amdaie.npu.circular_dma_cpy_nd
// CHECK-SAME: [0, 0] [2, 1000] [0, 1]
// CHECK-SAME: [0, %arg0, %arg1, %arg2] [2, 10, 10, 10] [0, 400, 20, 1])
// CHECK: amdaie.npu.circular_dma_cpy_nd
// CHECK-SAME: [0, 0] [2, 1000] [0, 1]
// CHECK-SAME: [0, %arg0, %arg1, %arg2] [2, 10, 10, 10] [0, 400, 20, 1])
// CHECK-NOT: amdaie.npu.circular_dma_cpy_nd
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @with_complex_variable_offset(%arg0: index, %arg1 : index, %arg2 : index, %arg3: !amdaie.logicalobjectfifo<memref<bf16>>) {
Expand All @@ -708,9 +709,10 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}


// CHECK-LABEL: @with_mixed_offsets
// CHECK: amdaie.npu.circular_dma_cpy_nd
// CHECK-SAME: [0, 0] [2, 1000] [0, 1]
// CHECK-SAME: [0, 0, %arg1, %arg2] [2, 10, 10, 10] [400000, 400, %arg0, 1])
// CHECK: amdaie.npu.circular_dma_cpy_nd
// CHECK-SAME: [0, 0] [2, 1000] [0, 1]
// CHECK-SAME: [0, 0, %arg1, %arg2] [2, 10, 10, 10] [400000, 400, %arg0, 1])
// CHECK-NOT: amdaie.npu.circular_dma_cpy_nd
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @with_mixed_offsets(%arg0: index, %arg1 : index, %arg2 : index, %arg3: !amdaie.logicalobjectfifo<memref<bf16>>) {
Expand All @@ -725,3 +727,45 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
return
}
}

// -----

// The second access patterns of the two DMA ops below have global offsets
// `%arg0 * 1` and `%arg1 * 1` respectively. Their difference, `%arg0 - %arg1`,
// is not a constant, so both ops are expected to still be present, back to
// back, in the output.
// CHECK-LABEL: @with_nonconst_offset_difference
// CHECK: amdaie.npu.circular_dma_cpy_nd
// CHECK-NEXT: amdaie.npu.circular_dma_cpy_nd
// CHECK-NEXT: amdaie.end
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @with_nonconst_offset_difference(%arg0: index, %arg1 : index, %arg3: !amdaie.logicalobjectfifo<memref<bf16>>) {
amdaie.workgroup {
%0 = amdaie.connection(%arg3, %arg3) : (!amdaie.logicalobjectfifo<memref<bf16>>, !amdaie.logicalobjectfifo<memref<bf16>>)
amdaie.controlcode {
%1 = amdaie.npu.circular_dma_cpy_nd %0([0] [1000] [1], [%arg0, 0, 0] [1, 1, 10] [1, 1, 1])
%2 = amdaie.npu.circular_dma_cpy_nd %0([0] [1000] [1], [0, %arg1, 0] [1, 1, 10] [1, 1, 1])
amdaie.end
}
}
return
}
}

// -----

// Here both the offset and the stride of the middle dimension are
// non-constant. The global offsets of the second access patterns are
// `%arg0 * %arg0` and `%arg1 * %arg0`, and these product terms do not cancel,
// so both ops are expected to still be present, back to back, in the output.
// CHECK-LABEL: @with_nonconst_offset_product_difference
// CHECK: amdaie.npu.circular_dma_cpy_nd
// CHECK-NEXT: amdaie.npu.circular_dma_cpy_nd
// CHECK-NEXT: amdaie.end
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @with_nonconst_offset_product_difference(%arg0: index, %arg1 : index, %arg3: !amdaie.logicalobjectfifo<memref<bf16>>) {
amdaie.workgroup {
%0 = amdaie.connection(%arg3, %arg3) : (!amdaie.logicalobjectfifo<memref<bf16>>, !amdaie.logicalobjectfifo<memref<bf16>>)
amdaie.controlcode {
%1 = amdaie.npu.circular_dma_cpy_nd %0([0] [1000] [1], [0, %arg0, 0] [1, 1, 10] [1, %arg0, 1])
%2 = amdaie.npu.circular_dma_cpy_nd %0([0] [1000] [1], [0, %arg1, 0] [1, 1, 10] [1, %arg0, 1])
amdaie.end
}
}
return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,30 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
return
}
}

// -----

// TODO(newling) add another canonicalizer to
// (1) bring stride 0 to the first (outermost) dimension, and
// (2) for size-1 dimensions, make the stride constant if possible.
// With these the source side would be
// [0, %arg0, %arg1, %arg2] [2, 1, 1, 1000] [0, 10, 5, 1]
// CHECK-LABEL: @with_different_indices_sliced
// CHECK: amdaie.npu.circular_dma_cpy_nd
// CHECK-SAME: [0, 0] [2, 1000] [0, 1]
// CHECK-SAME: [%arg0, 0, 5, %arg2] [1, 2, 1, 1000] [10, 0, %arg1, 1])
// CHECK-NOT: amdaie.npu.circular_dma_cpy_nd
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @with_different_indices_sliced(%arg0: index, %arg1 : index, %arg2 : index, %arg3: !amdaie.logicalobjectfifo<memref<bf16>>) {
amdaie.workgroup {
%0 = amdaie.connection(%arg3, %arg3) : (!amdaie.logicalobjectfifo<memref<bf16>>, !amdaie.logicalobjectfifo<memref<bf16>>)
amdaie.controlcode {
// `%arg0` appears as the offset of dimension 0 in the first op and of
// dimension 1 in the second, but both dimensions have stride 10. Both second
// access patterns therefore have the same global offset
// (`10 * %arg0 + 5 * %arg1 + %arg2`), and the two ops are expected to be
// combined into a single op.
%1 = amdaie.npu.circular_dma_cpy_nd %0([0] [1000] [1], [%arg0, 0, 5, %arg2] [1, 1, 1, 1000] [10, 10, %arg1, 1])
%2 = amdaie.npu.circular_dma_cpy_nd %0([0] [1000] [1], [0, %arg0, 5, %arg2] [1, 1, 1, 1000] [10, 10, %arg1, 1])
amdaie.end
}
}
return
}
}

0 comments on commit 55cab2a

Please sign in to comment.