Skip to content

Commit

Permalink
Add scalar check for shift
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Feb 19, 2025
1 parent 4e286d8 commit 6628d1c
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 19 deletions.
14 changes: 2 additions & 12 deletions crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,14 @@ fn shift_and_fill_with_mask(s: &Column, n: i64, fill_value: &Column) -> PolarsRe

pub(super) fn shift_and_fill(args: &[Column]) -> PolarsResult<Column> {
let s = &args[0];
let n_s = &args[1];

polars_ensure!(
n_s.len() == 1,
ComputeError: "n must be a single value."
);
let n_s = n_s.cast(&DataType::Int64)?;
let n_s = &args[1].cast(&DataType::Int64)?;
let n = n_s.i64()?;

if let Some(n) = n.get(0) {
let logical = s.dtype();
let physical = s.to_physical_repr();
let fill_value_s = &args[2];
let fill_value = fill_value_s.get(0)?;
let fill_value = fill_value_s.get(0).unwrap();

use DataType::*;
match logical {
Expand Down Expand Up @@ -116,10 +110,6 @@ pub(super) fn shift_and_fill(args: &[Column]) -> PolarsResult<Column> {
pub fn shift(args: &[Column]) -> PolarsResult<Column> {
let s = &args[0];
let n_s = &args[1];
polars_ensure!(
n_s.len() == 1,
ComputeError: "n must be a single value."
);

let n_s = n_s.cast(&DataType::Int64)?;
let n = n_s.i64()?;
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-plan/src/plans/conversion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ pub(super) fn convert_functions(

let e = to_expr_irs(input, arena)?;

// Validate inputs.
match function {
FunctionExpr::Shift => {
polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value");
},
FunctionExpr::ShiftAndFill => {
polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value");
polars_ensure!(&e[2].is_scalar(arena), ComputeError: "'fill_value' must be scalar value");
},
_ => {},
}

if state.output_name.is_none() {
// Handles special case functions like `struct.field`.
if let Some(name) = function.output_name() {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub(crate) fn has_leaf_literal(e: &Expr) -> bool {
pub(crate) fn all_return_scalar(e: &Expr) -> bool {
match e {
Expr::Literal(lv) => lv.is_scalar(),
Expr::Cast { expr, .. } => all_return_scalar(expr),
Expr::Function { options: opt, .. } => opt.flags.contains(FunctionFlags::RETURNS_SCALAR),
Expr::Agg(_) => true,
Expr::Column(_) | Expr::Wildcard => false,
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9302,8 +9302,8 @@ def shift(self, n: int = 1, *, fill_value: IntoExpr | None = None) -> DataFrame:
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
fill_value
Fill the resulting null values with this value. Accepts expression input.
Non-expression inputs are parsed as literals.
Fill the resulting null values with this value. Accepts scalar expression
input. Non-expression inputs are parsed as literals.
Notes
-----
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2638,7 +2638,7 @@ def shift(
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
fill_value
Fill the resulting null values with this value.
Fill the resulting null values with this scalar value.
Notes
-----
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5622,8 +5622,8 @@ def shift(
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
fill_value
Fill the resulting null values with this value. Accepts expression input.
Non-expression inputs are parsed as literals.
Fill the resulting null values with this value. Accepts scalar expression
input. Non-expression inputs are parsed as literals.
Notes
-----
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5436,8 +5436,8 @@ def shift(self, n: int = 1, *, fill_value: IntoExpr | None = None) -> Series:
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
fill_value
Fill the resulting null values with this value. Accepts expression input.
Non-expression inputs are parsed as literals.
Fill the resulting null values with this value. Accepts scalar expression
input. Non-expression inputs are parsed as literals.
Notes
-----
Expand Down
52 changes: 52 additions & 0 deletions py-polars/tests/unit/operations/test_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -119,3 +120,54 @@ def test_shift_fill_value_group_logicals() -> None:
result = df.select(pl.col("d").shift(fill_value=pl.col("d").max(), n=-1).over("s"))

assert result.dtypes == [pl.Date]


def test_shift_n_null() -> None:
    """Shifting by a null `n` is expected to give an all-null column of the same dtype."""
    frame = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)})
    all_null = pl.DataFrame({"a": pl.Series([None, None, None], dtype=pl.Int32)})

    # Frame-level shift: the fill value must not change the expected output.
    assert_frame_equal(frame.shift(None), all_null)
    assert_frame_equal(frame.shift(None, fill_value=1), all_null)

    # Expression-level shift: same expectation, with and without a fill value.
    assert_frame_equal(frame.select(pl.col("a").shift(None)), all_null)
    assert_frame_equal(frame.select(pl.col("a").shift(None, fill_value=1)), all_null)


def test_shift_n_nonscalar() -> None:
    """A column expression passed as `n` must be rejected with a ComputeError."""
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    expected_message = "'n' must be scalar value"

    # Plain shift with a non-scalar `n`.
    with pytest.raises(ComputeError, match=expected_message):
        frame.select(pl.col("a").shift(pl.col("b")))

    # Shift-and-fill with a non-scalar `n`; the scalar fill value is valid.
    with pytest.raises(ComputeError, match=expected_message):
        frame.select(pl.col("a").shift(pl.col("b"), fill_value=1))


def test_shift_fill_value_nonscalar() -> None:
    """A column expression passed as `fill_value` must be rejected with a ComputeError."""
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    expected_message = "'fill_value' must be scalar value"

    # Frame-level API.
    with pytest.raises(ComputeError, match=expected_message):
        frame.shift(1, fill_value=pl.col("b"))

    # Expression-level API.
    with pytest.raises(ComputeError, match=expected_message):
        frame.select(pl.col("a").shift(1, fill_value=pl.col("b")))

0 comments on commit 6628d1c

Please sign in to comment.