diff --git a/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs b/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs index 3230e96a98d1..ca44fb69c292 100644 --- a/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs +++ b/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs @@ -42,20 +42,14 @@ fn shift_and_fill_with_mask(s: &Column, n: i64, fill_value: &Column) -> PolarsRe pub(super) fn shift_and_fill(args: &[Column]) -> PolarsResult { let s = &args[0]; - let n_s = &args[1]; - - polars_ensure!( - n_s.len() == 1, - ComputeError: "n must be a single value." - ); - let n_s = n_s.cast(&DataType::Int64)?; + let n_s = &args[1].cast(&DataType::Int64)?; let n = n_s.i64()?; if let Some(n) = n.get(0) { let logical = s.dtype(); let physical = s.to_physical_repr(); let fill_value_s = &args[2]; - let fill_value = fill_value_s.get(0)?; + let fill_value = fill_value_s.get(0).unwrap(); use DataType::*; match logical { @@ -116,10 +110,6 @@ pub(super) fn shift_and_fill(args: &[Column]) -> PolarsResult { pub fn shift(args: &[Column]) -> PolarsResult { let s = &args[0]; let n_s = &args[1]; - polars_ensure!( - n_s.len() == 1, - ComputeError: "n must be a single value." - ); let n_s = n_s.cast(&DataType::Int64)?; let n = n_s.i64()?; diff --git a/crates/polars-plan/src/plans/conversion/functions.rs b/crates/polars-plan/src/plans/conversion/functions.rs index a2ec9e453dcc..4e0da3434dc1 100644 --- a/crates/polars-plan/src/plans/conversion/functions.rs +++ b/crates/polars-plan/src/plans/conversion/functions.rs @@ -49,6 +49,18 @@ pub(super) fn convert_functions( let e = to_expr_irs(input, arena)?; + // Validate inputs. + match function { + FunctionExpr::Shift => { + polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value"); + }, + FunctionExpr::ShiftAndFill => { + polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value"); + polars_ensure!(&e[2].is_scalar(arena), ComputeError: "'fill_value' must be scalar value"); + }, + _ => {}, + } + if state.output_name.is_none() { // Handles special case functions like `struct.field`. if let Some(name) = function.output_name() { diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index 54350cfaa559..312b2516376b 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -88,6 +88,7 @@ pub(crate) fn has_leaf_literal(e: &Expr) -> bool { pub(crate) fn all_return_scalar(e: &Expr) -> bool { match e { Expr::Literal(lv) => lv.is_scalar(), + Expr::Cast { expr, .. } => all_return_scalar(expr), Expr::Function { options: opt, .. } => opt.flags.contains(FunctionFlags::RETURNS_SCALAR), Expr::Agg(_) => true, Expr::Column(_) | Expr::Wildcard => false, diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index e899d822d1e5..707c12745056 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -9302,8 +9302,8 @@ def shift(self, n: int = 1, *, fill_value: IntoExpr | None = None) -> DataFrame: Number of indices to shift forward. If a negative value is passed, values are shifted in the opposite direction instead. fill_value - Fill the resulting null values with this value. Accepts expression input. - Non-expression inputs are parsed as literals. + Fill the resulting null values with this value. Accepts scalar expression + input. Non-expression inputs are parsed as literals. Notes ----- diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index c142f7eb67f8..5c4d294a45c7 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2638,7 +2638,7 @@ def shift( Number of indices to shift forward. If a negative value is passed, values are shifted in the opposite direction instead. fill_value - Fill the resulting null values with this value. + Fill the resulting null values with this scalar value. Notes ----- diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index a69be6a72588..9beda1961ac1 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -5622,8 +5622,8 @@ def shift( Number of indices to shift forward. If a negative value is passed, values are shifted in the opposite direction instead. fill_value - Fill the resulting null values with this value. Accepts expression input. - Non-expression inputs are parsed as literals. + Fill the resulting null values with this value. Accepts scalar expression + input. Non-expression inputs are parsed as literals. Notes ----- diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 6d4baf5c0ead..e63ec9116bba 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -5436,8 +5436,8 @@ def shift(self, n: int = 1, *, fill_value: IntoExpr | None = None) -> Series: Number of indices to shift forward. If a negative value is passed, values are shifted in the opposite direction instead. fill_value - Fill the resulting null values with this value. Accepts expression input. - Non-expression inputs are parsed as literals. + Fill the resulting null values with this value. Accepts scalar expression + input. Non-expression inputs are parsed as literals. Notes ----- diff --git a/py-polars/tests/unit/operations/test_shift.py b/py-polars/tests/unit/operations/test_shift.py index 62a4bd313563..8073ef886b4f 100644 --- a/py-polars/tests/unit/operations/test_shift.py +++ b/py-polars/tests/unit/operations/test_shift.py @@ -5,6 +5,7 @@ import pytest import polars as pl +from polars.exceptions import ComputeError from polars.testing import assert_frame_equal, assert_series_equal @@ -119,3 +120,54 @@ def test_shift_fill_value_group_logicals() -> None: result = df.select(pl.col("d").shift(fill_value=pl.col("d").max(), n=-1).over("s")) assert result.dtypes == [pl.Date] + + +def test_shift_n_null() -> None: + df = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)}) + out = df.shift(None) + expected = pl.DataFrame({"a": pl.Series([None, None, None], dtype=pl.Int32)}) + assert_frame_equal(out, expected) + + out = df.shift(None, fill_value=1) + assert_frame_equal(out, expected) + + out = df.select(pl.col("a").shift(None)) + assert_frame_equal(out, expected) + + out = df.select(pl.col("a").shift(None, fill_value=1)) + assert_frame_equal(out, expected) + + +def test_shift_n_nonscalar() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + } + ) + + with pytest.raises(ComputeError, match="'n' must be scalar value"): + df.select(pl.col("a").shift(pl.col("b"))) + + with pytest.raises(ComputeError, match="'n' must be scalar value"): + df.select(pl.col("a").shift(pl.col("b"), fill_value=1)) + + +def test_shift_fill_value_nonscalar() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + } + ) + with pytest.raises( + ComputeError, + match="'fill_value' must be scalar value", + ): + df.shift(1, fill_value=pl.col("b")) + + with pytest.raises( + ComputeError, + match="'fill_value' must be scalar value", + ): + df.select(pl.col("a").shift(1, fill_value=pl.col("b")))