@@ -358,43 +358,6 @@ impl PhysicalExpr for BinaryExpr {
358358 fn evaluate ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
359359 use arrow:: compute:: kernels:: numeric:: * ;
360360
361- fn check_short_circuit ( arg : & ColumnarValue , op : & Operator ) -> bool {
362- let data_type = arg. data_type ( ) ;
363- match ( data_type, op) {
364- ( DataType :: Boolean , Operator :: And ) => {
365- match arg {
366- ColumnarValue :: Array ( array) => {
367- if let Ok ( array) = as_boolean_array ( & array) {
368- return array. false_count ( ) == array. len ( ) ;
369- }
370- }
371- ColumnarValue :: Scalar ( scalar) => {
372- if let ScalarValue :: Boolean ( Some ( value) ) = scalar {
373- return !value;
374- }
375- }
376- }
377- false
378- }
379- ( DataType :: Boolean , Operator :: Or ) => {
380- match arg {
381- ColumnarValue :: Array ( array) => {
382- if let Ok ( array) = as_boolean_array ( & array) {
383- return array. true_count ( ) == array. len ( ) ;
384- }
385- }
386- ColumnarValue :: Scalar ( scalar) => {
387- if let ScalarValue :: Boolean ( Some ( value) ) = scalar {
388- return * value;
389- }
390- }
391- }
392- false
393- }
394- _ => false ,
395- }
396- }
397-
398361 let lhs = self . left . evaluate ( batch) ?;
399362
400363 // Optimize for short-circuiting `Operator::And` or `Operator::Or` operations and return early.
@@ -848,6 +811,47 @@ impl BinaryExpr {
848811 }
849812}
850813
814+ /// Check if it meets the short-circuit condition
815+ /// 1. For the `AND` operator, if the `lhs` result all are `false`
816+ /// 2. For the `OR` operator, if the `lhs` result all are `true`
817+ /// 3. Otherwise, it does not meet the short-circuit condition
818+ fn check_short_circuit ( arg : & ColumnarValue , op : & Operator ) -> bool {
819+ let data_type = arg. data_type ( ) ;
820+ match ( data_type, op) {
821+ ( DataType :: Boolean , Operator :: And ) => {
822+ match arg {
823+ ColumnarValue :: Array ( array) => {
824+ if let Ok ( array) = as_boolean_array ( & array) {
825+ return array. false_count ( ) == array. len ( ) ;
826+ }
827+ }
828+ ColumnarValue :: Scalar ( scalar) => {
829+ if let ScalarValue :: Boolean ( Some ( value) ) = scalar {
830+ return !value;
831+ }
832+ }
833+ }
834+ false
835+ }
836+ ( DataType :: Boolean , Operator :: Or ) => {
837+ match arg {
838+ ColumnarValue :: Array ( array) => {
839+ if let Ok ( array) = as_boolean_array ( & array) {
840+ return array. true_count ( ) == array. len ( ) ;
841+ }
842+ }
843+ ColumnarValue :: Scalar ( scalar) => {
844+ if let ScalarValue :: Boolean ( Some ( value) ) = scalar {
845+ return * value;
846+ }
847+ }
848+ }
849+ false
850+ }
851+ _ => false ,
852+ }
853+ }
854+
851855fn concat_elements ( left : Arc < dyn Array > , right : Arc < dyn Array > ) -> Result < ArrayRef > {
852856 Ok ( match left. data_type ( ) {
853857 DataType :: Utf8 => Arc :: new ( concat_elements_utf8 (
@@ -4875,4 +4879,39 @@ mod tests {
48754879
48764880 Ok ( ( ) )
48774881 }
4882+
4883+ #[ test]
4884+ fn test_check_short_circuit ( ) {
4885+ use crate :: planner:: logical2physical;
4886+ use datafusion_expr:: col as logical_col;
4887+ use datafusion_expr:: lit;
4888+ let schema = Arc :: new ( Schema :: new ( vec ! [
4889+ Field :: new( "a" , DataType :: Int32 , false ) ,
4890+ Field :: new( "b" , DataType :: Int32 , false ) ,
4891+ ] ) ) ;
4892+ let a_array = Int32Array :: from ( vec ! [ 1 , 3 , 4 , 5 , 6 ] ) ;
4893+ let b_array = Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ;
4894+ let batch = RecordBatch :: try_new (
4895+ Arc :: clone ( & schema) ,
4896+ vec ! [ Arc :: new( a_array) , Arc :: new( b_array) ] ,
4897+ )
4898+ . unwrap ( ) ;
4899+
4900+ // op: AND left: all false
4901+ let left_expr = logical2physical ( & logical_col ( "a" ) . eq ( lit ( 2 ) ) , & schema) ;
4902+ let left_value = left_expr. evaluate ( & batch) . unwrap ( ) ;
4903+ assert ! ( check_short_circuit( & left_value, & Operator :: And ) ) ;
4904+ // op: AND left: not all false
4905+ let left_expr = logical2physical ( & logical_col ( "a" ) . eq ( lit ( 3 ) ) , & schema) ;
4906+ let left_value = left_expr. evaluate ( & batch) . unwrap ( ) ;
4907+ assert ! ( !check_short_circuit( & left_value, & Operator :: And ) ) ;
4908+ // op: OR left: all true
4909+ let left_expr = logical2physical ( & logical_col ( "a" ) . gt ( lit ( 0 ) ) , & schema) ;
4910+ let left_value = left_expr. evaluate ( & batch) . unwrap ( ) ;
4911+ assert ! ( check_short_circuit( & left_value, & Operator :: Or ) ) ;
4912+ // op: OR left: not all true
4913+ let left_expr = logical2physical ( & logical_col ( "a" ) . gt ( lit ( 2 ) ) , & schema) ;
4914+ let left_value = left_expr. evaluate ( & batch) . unwrap ( ) ;
4915+ assert ! ( !check_short_circuit( & left_value, & Operator :: Or ) ) ;
4916+ }
48784917}
0 commit comments