Skip to content

Commit

Permalink
[mlir] Guard sccp pass from crashing with different source type (#120…
Browse files Browse the repository at this point in the history
…656)

Vector::BroadCastOp expects the identical element type in folding. It
causes the crash if the different source type is given to the SCCP pass.
We need to guard the pass from crashing if the nonidentical element type
is given, but still compatible. (e.g. index vs integer type)

llvm/llvm-project#120193
  • Loading branch information
Lewuathe authored Dec 25, 2024
1 parent 34f7000 commit d3846ec
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
12 changes: 10 additions & 2 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2523,8 +2523,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, adaptor.getSource());
if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
if (vectorType.getElementType() != attr.getType())
return {};
return DenseElementsAttr::get(vectorType, attr);
}
if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
if (vectorType.getElementType() != attr.getType())
return {};
return DenseElementsAttr::get(vectorType, attr);
}
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
return {};
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Transforms/sccp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,12 @@ func.func @op_with_region() -> (i32) {
^b:
return %1 : i32
}

// CHECK-LABEL: no_crash_with_different_source_type
func.func @no_crash_with_different_source_type() {
// CHECK: llvm.mlir.constant(0 : index) : i64
%0 = llvm.mlir.constant(0 : index) : i64
// CHECK: vector.broadcast %[[CST:.*]] : i64 to vector<128xi64>
%1 = vector.broadcast %0 : i64 to vector<128xi64>
llvm.return
}

0 comments on commit d3846ec

Please sign in to comment.