From d3846eca2061e6e9a8d654551153f7362c27b59a Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe@gmail.com>
Date: Wed, 25 Dec 2024 12:19:52 +0900
Subject: [PATCH] [mlir] Guard sccp pass from crashing with different source
 type (#120656)

Vector::BroadCastOp expects the identical element type in folding. It
causes the crash if the different source type is given to the SCCP pass.
We need to guard the pass from crashing if the nonidentical element type
is given, but still compatible. (e.g. index vs integer type)

https://github.com/llvm/llvm-project/issues/120193
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++++++--
 mlir/test/Transforms/sccp.mlir           |  9 +++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 491b5f44b722b..ae1cf95732336 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2523,8 +2523,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getSource())
     return {};
   auto vectorType = getResultVectorType();
-  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
-    return DenseElementsAttr::get(vectorType, adaptor.getSource());
+  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
+    if (vectorType.getElementType() != attr.getType())
+      return {};
+    return DenseElementsAttr::get(vectorType, attr);
+  }
+  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
+    if (vectorType.getElementType() != attr.getType())
+      return {};
+    return DenseElementsAttr::get(vectorType, attr);
+  }
   if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
   return {};
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index dcae052c29c24..c78c8594c0ba5 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -246,3 +246,12 @@ func.func @op_with_region() -> (i32) {
 ^b:
   return %1 : i32
 }
+
+// CHECK-LABEL: no_crash_with_different_source_type
+func.func @no_crash_with_different_source_type() {
+  // CHECK: llvm.mlir.constant(0 : index) : i64
+  %0 = llvm.mlir.constant(0 : index) : i64
+  // CHECK: vector.broadcast %[[CST:.*]] : i64 to vector<128xi64>
+  %1 = vector.broadcast %0 : i64 to vector<128xi64>
+  llvm.return
+}