Skip to content

Commit

Permalink
fix: Fix all-null list aggregations returning Null dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley committed Jan 30, 2025
1 parent 0bc6cb8 commit 7f04089
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 27 deletions.
50 changes: 32 additions & 18 deletions crates/polars-ops/src/chunked_array/list/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,22 @@ pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult<Series> {
unsafe { out.into_series().from_physical_unchecked(dt) }
})
},
_ => Ok(ca
.try_apply_amortized(|s| {
let s = s.as_ref();
let sc = s.min_reduce()?;
Ok(sc.into_series(s.name().clone()))
})?
.explode()
.unwrap()
.into_series()),
dt => {
let out = ca
.try_apply_amortized(|s| {
let s = s.as_ref();
let sc = s.min_reduce()?;
Ok(sc.into_series(s.name().clone()))
})?
.explode()
.unwrap()
.into_series();

match out.dtype() {
DataType::Null => out.cast(dt),
_ => Ok(out),
}
},
}
}

Expand Down Expand Up @@ -199,15 +206,22 @@ pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult<Series> {
unsafe { out.into_series().from_physical_unchecked(dt) }
})
},
_ => Ok(ca
.try_apply_amortized(|s| {
let s = s.as_ref();
let sc = s.max_reduce()?;
Ok(sc.into_series(s.name().clone()))
})?
.explode()
.unwrap()
.into_series()),
dt => {
let out = ca
.try_apply_amortized(|s| {
let s = s.as_ref();
let sc = s.max_reduce()?;
Ok(sc.into_series(s.name().clone()))
})?
.explode()
.unwrap()
.into_series();

match out.dtype() {
DataType::Null => out.cast(dt),
_ => Ok(out),
}
},
}
}

Expand Down
25 changes: 16 additions & 9 deletions crates/polars-ops/src/chunked_array/list/sum_mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,22 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars
out.into_series()
},
// slowest sum_as_series path
_ => ca
.try_apply_amortized(|s| {
s.as_ref()
.sum_reduce()
.map(|sc| sc.into_series(PlSmallStr::EMPTY))
})?
.explode()
.unwrap()
.into_series(),
dt => {
let s = ca
.try_apply_amortized(|s| {
s.as_ref()
.sum_reduce()
.map(|sc| sc.into_series(PlSmallStr::EMPTY))
})?
.explode()
.unwrap()
.into_series();

match s.dtype() {
DataType::Null => s.cast(dt)?,
_ => s,
}
},
};
out.rename(ca.name().clone());
Ok(out)
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
from datetime import date, datetime
from typing import TYPE_CHECKING

import numpy as np
import pytest
Expand All @@ -14,6 +15,9 @@
)
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars._typing import PolarsDataType


def test_list_arr_get() -> None:
a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
Expand Down Expand Up @@ -979,3 +983,13 @@ def test_list_eval_element_schema_19345() -> None:
),
pl.DataFrame({"a": [[1]]}),
)


@pytest.mark.parametrize(
    "inner_dtype",
    [pl.Int8, pl.Float64, pl.String, pl.Duration],
)
def test_list_agg_all_null(inner_dtype: PolarsDataType) -> None:
    """Check that list min/max on an all-null list column keep the inner dtype.

    The Series holds two null entries typed as ``List(inner_dtype)``; the
    aggregated result must report ``inner_dtype`` as its dtype.
    """
    all_null_lists = pl.Series([None, None], dtype=pl.List(inner_dtype))

    # Evaluate and check each aggregation in turn: min first, then max.
    min_dtype = all_null_lists.list.min().dtype
    assert min_dtype == inner_dtype

    max_dtype = all_null_lists.list.max().dtype
    assert max_dtype == inner_dtype

0 comments on commit 7f04089

Please sign in to comment.