From d3846eca2061e6e9a8d654551153f7362c27b59a Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Wed, 25 Dec 2024 12:19:52 +0900 Subject: [PATCH] [mlir] Guard sccp pass from crashing with different source type (#120656) Vector::BroadCastOp expects the identical element type in folding. It causes the crash if the different source type is given to the SCCP pass. We need to guard the pass from crashing if the nonidentical element type is given, but still compatible. (e.g. index vs integer type) https://github.com/llvm/llvm-project/issues/120193 --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++++++-- mlir/test/Transforms/sccp.mlir | 9 +++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 491b5f44b722b..ae1cf95732336 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2523,8 +2523,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { if (!adaptor.getSource()) return {}; auto vectorType = getResultVectorType(); - if (llvm::isa(adaptor.getSource())) - return DenseElementsAttr::get(vectorType, adaptor.getSource()); + if (auto attr = llvm::dyn_cast(adaptor.getSource())) { + if (vectorType.getElementType() != attr.getType()) + return {}; + return DenseElementsAttr::get(vectorType, attr); + } + if (auto attr = llvm::dyn_cast(adaptor.getSource())) { + if (vectorType.getElementType() != attr.getType()) + return {}; + return DenseElementsAttr::get(vectorType, attr); + } if (auto attr = llvm::dyn_cast(adaptor.getSource())) return DenseElementsAttr::get(vectorType, attr.getSplatValue()); return {}; diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir index dcae052c29c24..c78c8594c0ba5 100644 --- a/mlir/test/Transforms/sccp.mlir +++ b/mlir/test/Transforms/sccp.mlir @@ -246,3 +246,12 @@ func.func @op_with_region() -> (i32) { ^b: return %1 : i32 } + +// CHECK-LABEL: no_crash_with_different_source_type +func.func @no_crash_with_different_source_type() { + // CHECK: llvm.mlir.constant(0 : index) : i64 + %0 = llvm.mlir.constant(0 : index) : i64 + // CHECK: vector.broadcast %[[CST:.*]] : i64 to vector<128xi64> + %1 = vector.broadcast %0 : i64 to vector<128xi64> + llvm.return +}