Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add scalar checks for n and fill_value parameters in shift #21292

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

mcrumiller
Copy link
Contributor

@mcrumiller mcrumiller commented Feb 16, 2025

Closes #21280.

New behavior:

import polars as pl

df = pl.DataFrame({
    "x": [1, 2, 3],
    "y": [4, 5, 6]
})

# This used to succeed but silently use y.get(0).
df.shift(2, fill_value=pl.col("y"))
# polars.exceptions.ComputeError: 'fill_value' must be scalar value

df.shift(2, fill_value=pl.col("y").first())  # Ok
# shape: (3, 2)
# ┌─────┬─────┐
# │ x   ┆ y   │
# │ --- ┆ --- │
# │ i64 ┆ i64 │
# ╞═════╪═════╡
# │ 4   ┆ 4   │
# │ 4   ┆ 4   │
# │ 1   ┆ 4   │
# └─────┴─────┘

# Expressions are not in the function signature, but they work. We can still verify.
df.shift(pl.col("y"), fill_value=1)
# polars.exceptions.ComputeError: 'n' must be scalar value

@github-actions github-actions bot added fix Bug fix python Related to Python Polars rust Related to Rust Polars labels Feb 16, 2025
Copy link

codecov bot commented Feb 16, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 79.92%. Comparing base (4e286d8) to head (b408e6a).
Report is 3 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main   #21292   +/-   ##
=======================================
  Coverage   79.92%   79.92%           
=======================================
  Files        1596     1596           
  Lines      228745   228745           
  Branches     2616     2616           
=======================================
+ Hits       182825   182832    +7     
+ Misses      45321    45314    -7     
  Partials      599      599           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -55,7 +55,11 @@ pub(super) fn shift_and_fill(args: &[Column]) -> PolarsResult<Column> {
let logical = s.dtype();
let physical = s.to_physical_repr();
let fill_value_s = &args[2];
let fill_value = fill_value_s.get(0)?;
polars_ensure!(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree about this, but I think we should catch this at the IR. I believe we have more is_scalar checks in place there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ritchie46, was just looking into this actually. Will follow up shortly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Ritchie not sure this is possible actually. For reference, I put the check in crates/polars-plan/src/plans/conversion/functions.rs, and used all_returns_scalar(expr) on the argument in question--is this the way to do it?

The problem is that when fill_value is a datetime with a time zone, it gets replaced by a replace_time_zone() call, which doesn't return scalar. I can't see any way to ensure before evaluation that we have a scalar when this path is taken.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried updating the replace_time_zone to set its returns_scalar value to be the input's return_scalar value, but I'm getting a 'returns_scalar' and 'elementwise' are mutually exclusive. I can see why this was thought to be true but perhaps they shouldn't be mutually exclusive. For example, if we can show that x is a scalar, and f doesn't change the length, then f(x) should also return scalar even if f is elementwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may open a separate PR for this to see what you think.

Copy link
Collaborator

@nameexhaustion nameexhaustion Feb 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to remove the mutual exclusion

For this PR the function you likely need to call is ExprIR::is_scalar (or is_scalar_ae) -

pub fn is_scalar(&self, expr_arena: &Arena<AExpr>) -> bool {

You should be able to call it to check the input expr at some point during the conversion.

Copy link
Contributor Author

@mcrumiller mcrumiller Feb 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nameexhaustion If you drill down into is_scalar you'll find that the scalar properties come directly from the DSL's Expr.is_scalar. In my example of using dt.round, we have:

ExprIR.is_scalar
is_scalar_ae
AExpr::Function.options.flags.contains(FunctionFlags::RETURNS_SCALAR)

The FunctionFlags::RETURNS_SCALAR in that last call there are set directly in map_many_private. In the particular case of dt.round, this flag is explicitly set to false in round:

pub fn round(self, every: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::TemporalExpr(TemporalFunction::Round),
&[every],
false,
None,
)
}

Instead of explicitly setting to false, it should be set to true if and only if all of its inputs are scalar, and this should be easily be checked via recursion. However, it can't, because the function is also elementwise.

In general, the issue is that certain function expressions return False for returns_scalar even if we theoretically can say for certain that they do indeed return a scalar (i.e. if a function doesn't change the length and its input is scalar).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, the issue is that certain function expressions return False for returns_scalar even if we theoretically can say for certain that they do indeed return a scalar

There is already a branch in is_scalar_ae that should correctly identify this case -

} else if options.is_elementwise()
|| !options.flags.contains(FunctionFlags::CHANGES_LENGTH)
{
input.iter().all(|e| e.is_scalar(expr_arena))
} else {

I expect it should work if you call it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nameexhaustion I completely missed that branch, you're right, my apologies. Let me look into why this didn't work for dates...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nameexhaustion I think I got it working, thanks a bunch for your help.

@mcrumiller mcrumiller marked this pull request as draft February 17, 2025 12:57
@mcrumiller mcrumiller force-pushed the shift-scalar branch 2 times, most recently from d78645e to 6628d1c Compare February 19, 2025 14:38
@mcrumiller mcrumiller changed the title fix: Add scalar check for fill_value in shift fix: Add scalar checks for n and fill_value parameters in shift Feb 19, 2025
@mcrumiller mcrumiller force-pushed the shift-scalar branch 2 times, most recently from 70fe4fb to 4e0dab7 Compare February 19, 2025 15:16
polars_ensure!(
n_s.len() == 1,
ComputeError: "n must be a single value."
);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this check here since we now check in the IR.

@mcrumiller mcrumiller marked this pull request as ready for review February 19, 2025 15:44
@mcrumiller
Copy link
Contributor Author

@ritchie46 I've moved the check. Let me know if you think this is a good spot for it. If we have other functions that require argument checks, this section can be switched to a match.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fix Bug fix python Related to Python Polars rust Related to Rust Polars
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add a scalar check for fill_value in shift()?
3 participants