From ee6c1e9430e6b74b645c5cdd9ee44f6c7a419624 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 28 Dec 2023 12:01:35 +0800 Subject: [PATCH] refactor(expr): don't fallback to evaluation by row on error (#14174) Signed-off-by: Runji Wang --- src/expr/core/src/error.rs | 45 ++++++- src/expr/core/src/expr/build.rs | 19 ++- src/expr/core/src/expr/wrapper/mod.rs | 1 + src/expr/core/src/expr/wrapper/non_strict.rs | 120 ++++--------------- src/expr/core/src/expr/wrapper/strict.rs | 83 +++++++++++++ src/expr/impl/src/scalar/arithmetic_op.rs | 4 +- src/expr/macro/src/gen.rs | 42 +++++-- 7 files changed, 188 insertions(+), 126 deletions(-) create mode 100644 src/expr/core/src/expr/wrapper/strict.rs diff --git a/src/expr/core/src/error.rs b/src/expr/core/src/error.rs index 63e43180e6d0..851ba673a863 100644 --- a/src/expr/core/src/error.rs +++ b/src/expr/core/src/error.rs @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::array::ArrayError; +use std::fmt::Display; + +use risingwave_common::array::{ArrayError, ArrayRef}; use risingwave_common::error::{ErrorCode, RwError}; use risingwave_common::types::DataType; use risingwave_pb::PbFieldNotFound; @@ -20,7 +22,7 @@ use thiserror::Error; use thiserror_ext::AsReport; /// A specialized Result type for expression operations. -pub type Result = std::result::Result; +pub type Result = std::result::Result; pub struct ContextUnavailable(&'static str); @@ -39,6 +41,10 @@ impl From for ExprError { /// The error type for expression operations. #[derive(Error, Debug)] pub enum ExprError { + /// A collection of multiple errors in batch evaluation. + #[error("multiple errors:\n{1}")] + Multiple(ArrayRef, MultiExprError), + // Ideally "Unsupported" errors are caught by frontend. But when the match arms between // frontend and backend are inconsistent, we do not panic with `unreachable!`. #[error("Unsupported function: {0}")] @@ -135,3 +141,38 @@ impl From for ExprError { )) } } + +/// A collection of multiple errors. +#[derive(Error, Debug)] +pub struct MultiExprError(Box<[ExprError]>); + +impl MultiExprError { + /// Returns the first error. + pub fn into_first(self) -> ExprError { + self.0.into_vec().into_iter().next().expect("first error") + } +} + +impl Display for MultiExprError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (i, e) in self.0.iter().enumerate() { + writeln!(f, "{i}: {e}")?; + } + Ok(()) + } +} + +impl From> for MultiExprError { + fn from(v: Vec) -> Self { + Self(v.into_boxed_slice()) + } +} + +impl IntoIterator for MultiExprError { + type IntoIter = std::vec::IntoIter; + type Item = ExprError; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_vec().into_iter() + } +} diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index 1e40022fbe17..5b08f2173cec 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -21,7 +21,7 @@ use risingwave_pb::expr::ExprNode; use super::expr_some_all::SomeAllExpression; use super::expr_udf::UdfExpression; -use super::non_strict::NonStrictNoFallback; +use super::strict::Strict; use super::wrapper::checked::Checked; use super::wrapper::non_strict::NonStrict; use super::wrapper::EvalErrorReport; @@ -34,7 +34,8 @@ use crate::{bail, ExprError, Result}; /// Build an expression from protobuf. pub fn build_from_prost(prost: &ExprNode) -> Result { - ExprBuilder::new_strict().build(prost) + let expr = ExprBuilder::new_strict().build(prost)?; + Ok(Strict::new(expr).boxed()) } /// Build an expression from protobuf in non-strict mode. @@ -76,15 +77,11 @@ where /// Attach wrappers to an expression. #[expect(clippy::let_and_return)] - fn wrap(&self, expr: impl Expression + 'static, no_fallback: bool) -> BoxedExpression { + fn wrap(&self, expr: impl Expression + 'static) -> BoxedExpression { let checked = Checked(expr); let may_non_strict = if let Some(error_report) = &self.error_report { - if no_fallback { - NonStrictNoFallback::new(checked, error_report.clone()).boxed() - } else { - NonStrict::new(checked, error_report.clone()).boxed() - } + NonStrict::new(checked, error_report.clone()).boxed() } else { checked.boxed() }; @@ -95,9 +92,7 @@ where /// Build an expression with `build_inner` and attach some wrappers. fn build(&self, prost: &ExprNode) -> Result { let expr = self.build_inner(prost)?; - // no fallback to row-based evaluation for UDF - let no_fallback = matches!(prost.get_rex_node().unwrap(), RexNode::Udf(_)); - Ok(self.wrap(expr, no_fallback)) + Ok(self.wrap(expr)) } /// Build an expression from protobuf. @@ -216,7 +211,7 @@ pub fn build_func_non_strict( error_report: impl EvalErrorReport + 'static, ) -> Result { let expr = build_func(func, ret_type, children)?; - let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr, false)); + let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr)); Ok(wrapped) } diff --git a/src/expr/core/src/expr/wrapper/mod.rs b/src/expr/core/src/expr/wrapper/mod.rs index 16988a050ad8..b3864876059e 100644 --- a/src/expr/core/src/expr/wrapper/mod.rs +++ b/src/expr/core/src/expr/wrapper/mod.rs @@ -14,5 +14,6 @@ pub(crate) mod checked; pub(crate) mod non_strict; +pub(crate) mod strict; pub use non_strict::{EvalErrorReport, LogReport}; diff --git a/src/expr/core/src/expr/wrapper/non_strict.rs b/src/expr/core/src/expr/wrapper/non_strict.rs index e1ed69a3e359..fa819d810814 100644 --- a/src/expr/core/src/expr/wrapper/non_strict.rs +++ b/src/expr/core/src/expr/wrapper/non_strict.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use auto_impl::auto_impl; use risingwave_common::array::{ArrayRef, DataChunk}; use risingwave_common::log::LogSuppresser; -use risingwave_common::row::{OwnedRow, Row}; +use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; use thiserror_ext::AsReport; @@ -60,8 +60,7 @@ impl EvalErrorReport for LogReport { } /// A wrapper of [`Expression`] that evaluates in a non-strict way. Basically... -/// - When an error occurs during chunk-level evaluation, recompute in row-based execution and pad -/// with NULL for each failed row. +/// - When an error occurs during chunk-level evaluation, pad with NULL for each failed row. /// - Report all error occurred during row-level evaluation to the [`EvalErrorReport`]. pub(crate) struct NonStrict { inner: E, @@ -88,31 +87,6 @@ where pub fn new(inner: E, report: R) -> Self { Self { inner, report } } - - /// Evaluate expression in row-based execution with `eval_row_infallible`. - async fn eval_chunk_infallible_by_row(&self, input: &DataChunk) -> ArrayRef { - let mut array_builder = self.return_type().create_array_builder(input.capacity()); - for row in input.rows_with_holes() { - if let Some(row) = row { - let datum = self.eval_row_infallible(&row.into_owned_row()).await; // TODO: use `Row` trait - array_builder.append(&datum); - } else { - array_builder.append_null(); - } - } - array_builder.finish().into() - } - - /// Evaluate expression on a single row, report error and return NULL if failed. - async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum { - match self.inner.eval_row(input).await { - Ok(datum) => datum, - Err(error) => { - self.report.report(error); - None // NULL - } - } - } } // TODO: avoid the overhead of extra boxing. @@ -129,75 +103,14 @@ where async fn eval(&self, input: &DataChunk) -> Result { Ok(match self.inner.eval(input).await { Ok(array) => array, - Err(_e) => self.eval_chunk_infallible_by_row(input).await, - }) - } - - async fn eval_v2(&self, input: &DataChunk) -> Result { - Ok(match self.inner.eval_v2(input).await { - Ok(value) => value, - Err(_e) => self.eval_chunk_infallible_by_row(input).await.into(), - }) - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - Ok(self.eval_row_infallible(input).await) - } - - fn eval_const(&self) -> Result { - self.inner.eval_const() // do not handle error - } - - fn input_ref_index(&self) -> Option { - self.inner.input_ref_index() - } -} - -/// Similar to [`NonStrict`] wrapper, but does not fallback to row-based evaluation when an error occurs. -pub(crate) struct NonStrictNoFallback { - inner: E, - report: R, -} - -impl std::fmt::Debug for NonStrictNoFallback -where - E: std::fmt::Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("NonStrictNoFallback") - .field("inner", &self.inner) - .field("report", &std::any::type_name::()) - .finish() - } -} - -impl NonStrictNoFallback -where - E: Expression, - R: EvalErrorReport, -{ - pub fn new(inner: E, report: R) -> Self { - Self { inner, report } - } -} - -// TODO: avoid the overhead of extra boxing. -#[async_trait] -impl Expression for NonStrictNoFallback -where - E: Expression, - R: EvalErrorReport, -{ - fn return_type(&self) -> DataType { - self.inner.return_type() - } - - async fn eval(&self, input: &DataChunk) -> Result { - Ok(match self.inner.eval(input).await { - Ok(array) => array, - Err(error) => { - self.report.report(error); - // no fallback and return NULL for each row + Err(ExprError::Multiple(array, errors)) => { + for error in errors { + self.report.report(error); + } + array + } + Err(e) => { + self.report.report(e); let mut builder = self.return_type().create_array_builder(input.capacity()); builder.append_n_null(input.capacity()); builder.finish().into() @@ -207,9 +120,15 @@ where async fn eval_v2(&self, input: &DataChunk) -> Result { Ok(match self.inner.eval_v2(input).await { - Ok(value) => value, - Err(error) => { - self.report.report(error); + Ok(array) => array, + Err(ExprError::Multiple(array, errors)) => { + for error in errors { + self.report.report(error); + } + array.into() + } + Err(e) => { + self.report.report(e); ValueImpl::Scalar { value: None, capacity: input.capacity(), @@ -218,6 +137,7 @@ where }) } + /// Evaluate expression on a single row, report error and return NULL if failed. async fn eval_row(&self, input: &OwnedRow) -> Result { Ok(match self.inner.eval_row(input).await { Ok(datum) => datum, diff --git a/src/expr/core/src/expr/wrapper/strict.rs b/src/expr/core/src/expr/wrapper/strict.rs new file mode 100644 index 000000000000..5eab4beecd76 --- /dev/null +++ b/src/expr/core/src/expr/wrapper/strict.rs @@ -0,0 +1,83 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_trait::async_trait; +use risingwave_common::array::{ArrayRef, DataChunk}; +use risingwave_common::row::OwnedRow; +use risingwave_common::types::{DataType, Datum}; + +use crate::error::Result; +use crate::expr::{Expression, ValueImpl}; +use crate::ExprError; + +/// A wrapper of [`Expression`] that only keeps the first error if multiple errors are returned. +pub(crate) struct Strict { + inner: E, +} + +impl std::fmt::Debug for Strict +where + E: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Strict") + .field("inner", &self.inner) + .finish() + } +} + +impl Strict +where + E: Expression, +{ + pub fn new(inner: E) -> Self { + Self { inner } + } +} + +#[async_trait] +impl Expression for Strict +where + E: Expression, +{ + fn return_type(&self) -> DataType { + self.inner.return_type() + } + + async fn eval(&self, input: &DataChunk) -> Result { + match self.inner.eval(input).await { + Err(ExprError::Multiple(_, errors)) => Err(errors.into_first()), + res => res, + } + } + + async fn eval_v2(&self, input: &DataChunk) -> Result { + match self.inner.eval_v2(input).await { + Err(ExprError::Multiple(_, errors)) => Err(errors.into_first()), + res => res, + } + } + + async fn eval_row(&self, input: &OwnedRow) -> Result { + self.inner.eval_row(input).await + } + + fn eval_const(&self) -> Result { + self.inner.eval_const() + } + + fn input_ref_index(&self) -> Option { + self.inner.input_ref_index() + } +} diff --git a/src/expr/impl/src/scalar/arithmetic_op.rs b/src/expr/impl/src/scalar/arithmetic_op.rs index f12bf6dc5e64..88187f65d8c8 100644 --- a/src/expr/impl/src/scalar/arithmetic_op.rs +++ b/src/expr/impl/src/scalar/arithmetic_op.rs @@ -154,8 +154,8 @@ where } #[function("abs(decimal) -> decimal")] -pub fn decimal_abs(decimal: Decimal) -> Result { - Ok(Decimal::abs(&decimal)) +pub fn decimal_abs(decimal: Decimal) -> Decimal { + Decimal::abs(&decimal) } fn err_pow_zero_negative() -> ExprError { diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 2ecac59994e8..89b1fa199180 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -326,8 +326,18 @@ impl FunctionAttr { _ if self.ret == "void" => quote! { { #output; Option::::None } }, ReturnTypeKind::T => quote! { Some(#output) }, ReturnTypeKind::Option => output, - ReturnTypeKind::Result => quote! { Some(#output?) }, - ReturnTypeKind::ResultOption => quote! { #output? }, + ReturnTypeKind::Result => quote! { + match #output { + Ok(x) => Some(x), + Err(e) => { errors.push(e); None } + } + }, + ReturnTypeKind::ResultOption => quote! { + match #output { + Ok(x) => x, + Err(e) => { errors.push(e); None } + } + }, }; // if user function accepts non-option arguments, we assume the function // returns null on null input, so we need to unwrap the inputs before calling. @@ -382,7 +392,7 @@ impl FunctionAttr { let fn_name = format_ident!("{}", batch_fn); quote! { let c = #fn_name(#(#arrays),*); - Ok(Arc::new(c.into())) + Arc::new(c.into()) } } else if (types::is_primitive(&self.ret) || self.ret == "boolean") && user_fn.is_pure() @@ -396,14 +406,14 @@ impl FunctionAttr { std::iter::repeat_with(|| #fn_name()).take(input.capacity()) Bitmap::ones(input.capacity()), ); - Ok(Arc::new(c.into())) + Arc::new(c.into()) }, 1 => quote! { let c = #ret_array_type::from_iter_bitmap( a0.raw_iter().map(|a| #fn_name(a)), a0.null_bitmap().clone() ); - Ok(Arc::new(c.into())) + Arc::new(c.into()) }, 2 => quote! { // allow using `zip` for performance @@ -414,7 +424,7 @@ impl FunctionAttr { .map(|(a, b)| #fn_name #generic(a, b)), a0.null_bitmap() & a1.null_bitmap(), ); - Ok(Arc::new(c.into())) + Arc::new(c.into()) }, n => todo!("SIMD optimization for {n} arguments"), } @@ -449,7 +459,7 @@ impl FunctionAttr { #append_output } } - Ok(Arc::new(builder.finish().into())) + Arc::new(builder.finish().into()) } }; @@ -465,7 +475,7 @@ impl FunctionAttr { use risingwave_common::util::iter_util::ZipEqFast; use risingwave_expr::expr::{Context, BoxedExpression}; - use risingwave_expr::Result; + use risingwave_expr::{ExprError, Result}; use risingwave_expr::codegen::*; #check_children @@ -492,7 +502,13 @@ impl FunctionAttr { let #arrays: &#arg_arrays = #array_refs.as_ref().into(); )* #eval_variadic - #eval + let mut errors = vec![]; + let array = { #eval }; + if errors.is_empty() { + Ok(array) + } else { + Err(ExprError::Multiple(array, errors.into())) + } } async fn eval_row(&self, input: &OwnedRow) -> Result { #( @@ -500,7 +516,13 @@ impl FunctionAttr { let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap()); )* #eval_row_variadic - Ok(#row_output) + let mut errors: Vec = vec![]; + let output = #row_output; + if let Some(err) = errors.into_iter().next() { + Err(err.into()) + } else { + Ok(output) + } } }