Skip to content

Commit

Permalink
[mlir] Fix the wrong computation of dynamic strides for lowering Allo…
Browse files Browse the repository at this point in the history
…cOp to LLVM

Leftover change from before the MLIR merge, reviewed at accepted at
tensorflow/mlir#338.
  • Loading branch information
tungld authored and arichardson committed Jan 20, 2020
2 parents 90b5566 + e5957ac commit 0b8b528
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,14 +1054,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Iterate strides in reverse order, compute runningStride and strideValues.
auto nStrides = strides.size();
SmallVector<Value, 4> strideValues(nStrides, nullptr);
for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
int64_t index = nStrides - 1 - indexedStride.index();
for (unsigned i = 0; i < nStrides; ++i) {
int64_t index = nStrides - 1 - i;
if (strides[index] == MemRefType::getDynamicStrideOrOffset())
// Identity layout map is enforced in the match function, so we compute:
// `runningStride *= sizes[index]`
// `runningStride *= sizes[index + 1]`
runningStride =
runningStride
? rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[index])
? rewriter.create<LLVM::MulOp>(loc, runningStride,
sizes[index + 1])
: createIndexConstant(rewriter, loc, 1);
else
runningStride = createIndexConstant(rewriter, loc, strides[index]);
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: %[[st2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mul %{{.*}}, %[[c42]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[M]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[c42]] : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[c42]], %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
Expand Down Expand Up @@ -142,7 +142,7 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[M]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
Expand Down

0 comments on commit 0b8b528

Please sign in to comment.