Skip to content

Commit 5de0f36

Browse files
committed
add test
1 parent b3f51e5 commit 5de0f36

File tree

1 file changed

+76
-37
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+76
-37
lines changed

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -358,43 +358,6 @@ impl PhysicalExpr for BinaryExpr {
358358
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
359359
use arrow::compute::kernels::numeric::*;
360360

361-
fn check_short_circuit(arg: &ColumnarValue, op: &Operator) -> bool {
362-
let data_type = arg.data_type();
363-
match (data_type, op) {
364-
(DataType::Boolean, Operator::And) => {
365-
match arg {
366-
ColumnarValue::Array(array) => {
367-
if let Ok(array) = as_boolean_array(&array) {
368-
return array.false_count() == array.len();
369-
}
370-
}
371-
ColumnarValue::Scalar(scalar) => {
372-
if let ScalarValue::Boolean(Some(value)) = scalar {
373-
return !value;
374-
}
375-
}
376-
}
377-
false
378-
}
379-
(DataType::Boolean, Operator::Or) => {
380-
match arg {
381-
ColumnarValue::Array(array) => {
382-
if let Ok(array) = as_boolean_array(&array) {
383-
return array.true_count() == array.len();
384-
}
385-
}
386-
ColumnarValue::Scalar(scalar) => {
387-
if let ScalarValue::Boolean(Some(value)) = scalar {
388-
return *value;
389-
}
390-
}
391-
}
392-
false
393-
}
394-
_ => false,
395-
}
396-
}
397-
398361
let lhs = self.left.evaluate(batch)?;
399362

400363
// Optimize for short-circuiting `Operator::And` or `Operator::Or` operations and return early.
@@ -848,6 +811,47 @@ impl BinaryExpr {
848811
}
849812
}
850813

814+
/// Check if it meets the short-circuit condition
815+
/// 1. For the `AND` operator, if the `lhs` result all are `false`
816+
/// 2. For the `OR` operator, if the `lhs` result all are `true`
817+
/// 3. Otherwise, it does not meet the short-circuit condition
818+
fn check_short_circuit(arg: &ColumnarValue, op: &Operator) -> bool {
819+
let data_type = arg.data_type();
820+
match (data_type, op) {
821+
(DataType::Boolean, Operator::And) => {
822+
match arg {
823+
ColumnarValue::Array(array) => {
824+
if let Ok(array) = as_boolean_array(&array) {
825+
return array.false_count() == array.len();
826+
}
827+
}
828+
ColumnarValue::Scalar(scalar) => {
829+
if let ScalarValue::Boolean(Some(value)) = scalar {
830+
return !value;
831+
}
832+
}
833+
}
834+
false
835+
}
836+
(DataType::Boolean, Operator::Or) => {
837+
match arg {
838+
ColumnarValue::Array(array) => {
839+
if let Ok(array) = as_boolean_array(&array) {
840+
return array.true_count() == array.len();
841+
}
842+
}
843+
ColumnarValue::Scalar(scalar) => {
844+
if let ScalarValue::Boolean(Some(value)) = scalar {
845+
return *value;
846+
}
847+
}
848+
}
849+
false
850+
}
851+
_ => false,
852+
}
853+
}
854+
851855
fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
852856
Ok(match left.data_type() {
853857
DataType::Utf8 => Arc::new(concat_elements_utf8(
@@ -4875,4 +4879,39 @@ mod tests {
48754879

48764880
Ok(())
48774881
}
4882+
4883+
#[test]
4884+
fn test_check_short_circuit() {
4885+
use crate::planner::logical2physical;
4886+
use datafusion_expr::col as logical_col;
4887+
use datafusion_expr::lit;
4888+
let schema = Arc::new(Schema::new(vec![
4889+
Field::new("a", DataType::Int32, false),
4890+
Field::new("b", DataType::Int32, false),
4891+
]));
4892+
let a_array = Int32Array::from(vec![1, 3, 4, 5, 6]);
4893+
let b_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
4894+
let batch = RecordBatch::try_new(
4895+
Arc::clone(&schema),
4896+
vec![Arc::new(a_array), Arc::new(b_array)],
4897+
)
4898+
.unwrap();
4899+
4900+
// op: AND left: all false
4901+
let left_expr = logical2physical(&logical_col("a").eq(lit(2)), &schema);
4902+
let left_value = left_expr.evaluate(&batch).unwrap();
4903+
assert!(check_short_circuit(&left_value, &Operator::And));
4904+
// op: AND left: not all false
4905+
let left_expr = logical2physical(&logical_col("a").eq(lit(3)), &schema);
4906+
let left_value = left_expr.evaluate(&batch).unwrap();
4907+
assert!(!check_short_circuit(&left_value, &Operator::And));
4908+
// op: OR left: all true
4909+
let left_expr = logical2physical(&logical_col("a").gt(lit(0)), &schema);
4910+
let left_value = left_expr.evaluate(&batch).unwrap();
4911+
assert!(check_short_circuit(&left_value, &Operator::Or));
4912+
// op: OR left: not all true
4913+
let left_expr = logical2physical(&logical_col("a").gt(lit(2)), &schema);
4914+
let left_value = left_expr.evaluate(&batch).unwrap();
4915+
assert!(!check_short_circuit(&left_value, &Operator::Or));
4916+
}
48784917
}

0 commit comments

Comments
 (0)