Skip to content

Commit

Permalink
refactor(expr): don't fallback to evaluation by row on error (#14174)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <wangrunji0408@163.com>
  • Loading branch information
wangrunji0408 authored Dec 28, 2023
1 parent d8ab569 commit ee6c1e9
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 126 deletions.
45 changes: 43 additions & 2 deletions src/expr/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::array::ArrayError;
use std::fmt::Display;

use risingwave_common::array::{ArrayError, ArrayRef};
use risingwave_common::error::{ErrorCode, RwError};
use risingwave_common::types::DataType;
use risingwave_pb::PbFieldNotFound;
use thiserror::Error;
use thiserror_ext::AsReport;

/// A specialized Result type for expression operations.
pub type Result<T> = std::result::Result<T, ExprError>;
pub type Result<T, E = ExprError> = std::result::Result<T, E>;

pub struct ContextUnavailable(&'static str);

Expand All @@ -39,6 +41,10 @@ impl From<ContextUnavailable> for ExprError {
/// The error type for expression operations.
#[derive(Error, Debug)]
pub enum ExprError {
/// A collection of multiple errors in batch evaluation.
#[error("multiple errors:\n{1}")]
Multiple(ArrayRef, MultiExprError),

// Ideally "Unsupported" errors are caught by frontend. But when the match arms between
// frontend and backend are inconsistent, we do not panic with `unreachable!`.
#[error("Unsupported function: {0}")]
Expand Down Expand Up @@ -135,3 +141,38 @@ impl From<PbFieldNotFound> for ExprError {
))
}
}

/// A collection of multiple errors.
///
/// Stores the errors as a boxed slice of [`ExprError`] and is carried as the second field of
/// `ExprError::Multiple`. It is built from a `Vec<ExprError>` through the `From` impl, which does
/// not check for emptiness; note that `into_first` panics on an empty collection.
///
/// `Display` is implemented by hand and prints one `index: error` line per entry.
#[derive(Error, Debug)]
pub struct MultiExprError(Box<[ExprError]>);

impl MultiExprError {
    /// Consumes the collection and returns its first error, dropping all the others.
    ///
    /// # Panics
    ///
    /// Panics with the message `first error` if the collection holds no errors.
    pub fn into_first(self) -> ExprError {
        // Convert the boxed slice into a `Vec` so the elements can be moved out by value.
        let errors: Vec<ExprError> = self.0.into();
        errors.into_iter().next().expect("first error")
    }
}

impl Display for MultiExprError {
    /// Writes every contained error on its own line as `<index>: <error>`, where the index is
    /// the zero-based position in the collection. Each line, including the last one, ends with
    /// a newline; an empty collection writes nothing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0
            .iter()
            .enumerate()
            // Stops at, and returns, the first formatting failure.
            .try_for_each(|(index, error)| writeln!(f, "{index}: {error}"))
    }
}

impl From<Vec<ExprError>> for MultiExprError {
fn from(v: Vec<ExprError>) -> Self {
Self(v.into_boxed_slice())
}
}

impl IntoIterator for MultiExprError {
    type IntoIter = std::vec::IntoIter<ExprError>;
    type Item = ExprError;

    /// Consumes the collection and yields the errors by value, in their stored order.
    fn into_iter(self) -> Self::IntoIter {
        // Go through `Vec` to obtain an owning iterator over the boxed slice's elements.
        Vec::from(self.0).into_iter()
    }
}
19 changes: 7 additions & 12 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_pb::expr::ExprNode;

use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::non_strict::NonStrictNoFallback;
use super::strict::Strict;
use super::wrapper::checked::Checked;
use super::wrapper::non_strict::NonStrict;
use super::wrapper::EvalErrorReport;
Expand All @@ -34,7 +34,8 @@ use crate::{bail, ExprError, Result};

/// Build an expression from protobuf.
pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
ExprBuilder::new_strict().build(prost)
let expr = ExprBuilder::new_strict().build(prost)?;
Ok(Strict::new(expr).boxed())
}

/// Build an expression from protobuf in non-strict mode.
Expand Down Expand Up @@ -76,15 +77,11 @@ where

/// Attach wrappers to an expression.
#[expect(clippy::let_and_return)]
fn wrap(&self, expr: impl Expression + 'static, no_fallback: bool) -> BoxedExpression {
fn wrap(&self, expr: impl Expression + 'static) -> BoxedExpression {
let checked = Checked(expr);

let may_non_strict = if let Some(error_report) = &self.error_report {
if no_fallback {
NonStrictNoFallback::new(checked, error_report.clone()).boxed()
} else {
NonStrict::new(checked, error_report.clone()).boxed()
}
NonStrict::new(checked, error_report.clone()).boxed()
} else {
checked.boxed()
};
Expand All @@ -95,9 +92,7 @@ where
/// Build an expression with `build_inner` and attach some wrappers.
fn build(&self, prost: &ExprNode) -> Result<BoxedExpression> {
let expr = self.build_inner(prost)?;
// no fallback to row-based evaluation for UDF
let no_fallback = matches!(prost.get_rex_node().unwrap(), RexNode::Udf(_));
Ok(self.wrap(expr, no_fallback))
Ok(self.wrap(expr))
}

/// Build an expression from protobuf.
Expand Down Expand Up @@ -216,7 +211,7 @@ pub fn build_func_non_strict(
error_report: impl EvalErrorReport + 'static,
) -> Result<NonStrictExpression> {
let expr = build_func(func, ret_type, children)?;
let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr, false));
let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr));

Ok(wrapped)
}
Expand Down
1 change: 1 addition & 0 deletions src/expr/core/src/expr/wrapper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@

pub(crate) mod checked;
pub(crate) mod non_strict;
pub(crate) mod strict;

pub use non_strict::{EvalErrorReport, LogReport};
120 changes: 20 additions & 100 deletions src/expr/core/src/expr/wrapper/non_strict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use async_trait::async_trait;
use auto_impl::auto_impl;
use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::log::LogSuppresser;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use thiserror_ext::AsReport;

Expand Down Expand Up @@ -60,8 +60,7 @@ impl EvalErrorReport for LogReport {
}

/// A wrapper of [`Expression`] that evaluates in a non-strict way. Basically...
/// - When an error occurs during chunk-level evaluation, recompute in row-based execution and pad
/// with NULL for each failed row.
/// - When an error occurs during chunk-level evaluation, pad with NULL for each failed row.
/// - Report all errors that occurred during row-level evaluation to the [`EvalErrorReport`].
pub(crate) struct NonStrict<E, R> {
inner: E,
Expand All @@ -88,31 +87,6 @@ where
pub fn new(inner: E, report: R) -> Self {
Self { inner, report }
}

/// Evaluate expression in row-based execution with `eval_row_infallible`.
async fn eval_chunk_infallible_by_row(&self, input: &DataChunk) -> ArrayRef {
let mut array_builder = self.return_type().create_array_builder(input.capacity());
for row in input.rows_with_holes() {
if let Some(row) = row {
let datum = self.eval_row_infallible(&row.into_owned_row()).await; // TODO: use `Row` trait
array_builder.append(&datum);
} else {
array_builder.append_null();
}
}
array_builder.finish().into()
}

/// Evaluate expression on a single row, report error and return NULL if failed.
async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum {
match self.inner.eval_row(input).await {
Ok(datum) => datum,
Err(error) => {
self.report.report(error);
None // NULL
}
}
}
}

// TODO: avoid the overhead of extra boxing.
Expand All @@ -129,75 +103,14 @@ where
async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
Ok(match self.inner.eval(input).await {
Ok(array) => array,
Err(_e) => self.eval_chunk_infallible_by_row(input).await,
})
}

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
Ok(match self.inner.eval_v2(input).await {
Ok(value) => value,
Err(_e) => self.eval_chunk_infallible_by_row(input).await.into(),
})
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
Ok(self.eval_row_infallible(input).await)
}

fn eval_const(&self) -> Result<Datum> {
self.inner.eval_const() // do not handle error
}

fn input_ref_index(&self) -> Option<usize> {
self.inner.input_ref_index()
}
}

/// Similar to [`NonStrict`] wrapper, but does not fallback to row-based evaluation when an error occurs.
pub(crate) struct NonStrictNoFallback<E, R> {
inner: E,
report: R,
}

impl<E, R> std::fmt::Debug for NonStrictNoFallback<E, R>
where
E: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NonStrictNoFallback")
.field("inner", &self.inner)
.field("report", &std::any::type_name::<R>())
.finish()
}
}

impl<E, R> NonStrictNoFallback<E, R>
where
E: Expression,
R: EvalErrorReport,
{
pub fn new(inner: E, report: R) -> Self {
Self { inner, report }
}
}

// TODO: avoid the overhead of extra boxing.
#[async_trait]
impl<E, R> Expression for NonStrictNoFallback<E, R>
where
E: Expression,
R: EvalErrorReport,
{
fn return_type(&self) -> DataType {
self.inner.return_type()
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
Ok(match self.inner.eval(input).await {
Ok(array) => array,
Err(error) => {
self.report.report(error);
// no fallback and return NULL for each row
Err(ExprError::Multiple(array, errors)) => {
for error in errors {
self.report.report(error);
}
array
}
Err(e) => {
self.report.report(e);
let mut builder = self.return_type().create_array_builder(input.capacity());
builder.append_n_null(input.capacity());
builder.finish().into()
Expand All @@ -207,9 +120,15 @@ where

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
Ok(match self.inner.eval_v2(input).await {
Ok(value) => value,
Err(error) => {
self.report.report(error);
Ok(array) => array,
Err(ExprError::Multiple(array, errors)) => {
for error in errors {
self.report.report(error);
}
array.into()
}
Err(e) => {
self.report.report(e);
ValueImpl::Scalar {
value: None,
capacity: input.capacity(),
Expand All @@ -218,6 +137,7 @@ where
})
}

/// Evaluate expression on a single row, report error and return NULL if failed.
async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
Ok(match self.inner.eval_row(input).await {
Ok(datum) => datum,
Expand Down
83 changes: 83 additions & 0 deletions src/expr/core/src/expr/wrapper/strict.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use async_trait::async_trait;
use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};

use crate::error::Result;
use crate::expr::{Expression, ValueImpl};
use crate::ExprError;

/// A wrapper of [`Expression`] that only keeps the first error if multiple errors are returned.
///
/// When the inner expression's `eval` or `eval_v2` fails with `ExprError::Multiple`, the
/// wrapper replaces it with the first contained error and discards the accompanying array.
/// Every other result, and every other method, is forwarded to the inner expression unchanged.
pub(crate) struct Strict<E> {
/// The wrapped expression that all calls are delegated to.
inner: E,
}

impl<E> std::fmt::Debug for Strict<E>
where
    E: std::fmt::Debug,
{
    /// Formats the wrapper as a struct named `Strict` with its single `inner` field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("Strict");
        repr.field("inner", &self.inner);
        repr.finish()
    }
}

impl<E> Strict<E>
where
E: Expression,
{
pub fn new(inner: E) -> Self {
Self { inner }
}
}

#[async_trait]
impl<E> Expression for Strict<E>
where
    E: Expression,
{
    /// Returns the return type of the inner expression.
    fn return_type(&self) -> DataType {
        self.inner.return_type()
    }

    /// Evaluates the inner expression on a chunk.
    ///
    /// An `ExprError::Multiple` failure is collapsed into its first error and the array that
    /// came with it is dropped; any other outcome is returned as is.
    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
        self.inner.eval(input).await.map_err(|error| match error {
            ExprError::Multiple(_, errors) => errors.into_first(),
            other => other,
        })
    }

    /// Same as [`eval`](Self::eval), but for the `eval_v2` entry point returning a [`ValueImpl`].
    async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
        self.inner.eval_v2(input).await.map_err(|error| match error {
            ExprError::Multiple(_, errors) => errors.into_first(),
            other => other,
        })
    }

    /// Evaluates the inner expression on a single row; the result is forwarded untouched.
    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
        let datum = self.inner.eval_row(input).await;
        datum
    }

    /// Forwards constant evaluation to the inner expression.
    fn eval_const(&self) -> Result<Datum> {
        self.inner.eval_const()
    }

    /// Forwards the input-reference index lookup to the inner expression.
    fn input_ref_index(&self) -> Option<usize> {
        self.inner.input_ref_index()
    }
}
4 changes: 2 additions & 2 deletions src/expr/impl/src/scalar/arithmetic_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ where
}

#[function("abs(decimal) -> decimal")]
pub fn decimal_abs(decimal: Decimal) -> Result<Decimal> {
Ok(Decimal::abs(&decimal))
pub fn decimal_abs(decimal: Decimal) -> Decimal {
Decimal::abs(&decimal)
}

fn err_pow_zero_negative() -> ExprError {
Expand Down
Loading

0 comments on commit ee6c1e9

Please sign in to comment.