From 7f040897605d9c2992b961eb24592ddae3e255d9 Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Wed, 29 Jan 2025 20:42:32 -0500 Subject: [PATCH] fix: Fix all-null list aggregations returning Null dtype --- .../src/chunked_array/list/min_max.rs | 50 ++++++++++++------- .../src/chunked_array/list/sum_mean.rs | 25 ++++++---- .../operations/namespaces/list/test_list.py | 14 ++++++ 3 files changed, 62 insertions(+), 27 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/min_max.rs b/crates/polars-ops/src/chunked_array/list/min_max.rs index ee76b74e7579..f626781f1147 100644 --- a/crates/polars-ops/src/chunked_array/list/min_max.rs +++ b/crates/polars-ops/src/chunked_array/list/min_max.rs @@ -89,15 +89,22 @@ pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult { unsafe { out.into_series().from_physical_unchecked(dt) } }) }, - _ => Ok(ca - .try_apply_amortized(|s| { - let s = s.as_ref(); - let sc = s.min_reduce()?; - Ok(sc.into_series(s.name().clone())) - })? - .explode() - .unwrap() - .into_series()), + dt => { + let out = ca + .try_apply_amortized(|s| { + let s = s.as_ref(); + let sc = s.min_reduce()?; + Ok(sc.into_series(s.name().clone())) + })? + .explode() + .unwrap() + .into_series(); + + match out.dtype() { + DataType::Null => out.cast(dt), + _ => Ok(out), + } + }, } } @@ -199,15 +206,22 @@ pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult { unsafe { out.into_series().from_physical_unchecked(dt) } }) }, - _ => Ok(ca - .try_apply_amortized(|s| { - let s = s.as_ref(); - let sc = s.max_reduce()?; - Ok(sc.into_series(s.name().clone())) - })? - .explode() - .unwrap() - .into_series()), + dt => { + let out = ca + .try_apply_amortized(|s| { + let s = s.as_ref(); + let sc = s.max_reduce()?; + Ok(sc.into_series(s.name().clone())) + })? + .explode() + .unwrap() + .into_series(); + + match out.dtype() { + DataType::Null => out.cast(dt), + _ => Ok(out), + } + }, } } diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index fb26cb17872e..054fe647598c 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -106,15 +106,22 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars out.into_series() }, // slowest sum_as_series path - _ => ca - .try_apply_amortized(|s| { - s.as_ref() - .sum_reduce() - .map(|sc| sc.into_series(PlSmallStr::EMPTY)) - })? - .explode() - .unwrap() - .into_series(), + dt => { + let s = ca + .try_apply_amortized(|s| { + s.as_ref() + .sum_reduce() + .map(|sc| sc.into_series(PlSmallStr::EMPTY)) + })? + .explode() + .unwrap() + .into_series(); + + match s.dtype() { + DataType::Null => s.cast(dt)?, + _ => s, + } + }, }; out.rename(ca.name().clone()); Ok(out) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index bed606e7c7ef..2d45f40b931c 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -2,6 +2,7 @@ import re from datetime import date, datetime +from typing import TYPE_CHECKING import numpy as np import pytest @@ -14,6 +15,9 @@ ) from polars.testing import assert_frame_equal, assert_series_equal +if TYPE_CHECKING: + from polars._typing import PolarsDataType + def test_list_arr_get() -> None: a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) @@ -979,3 +983,13 @@ def test_list_eval_element_schema_19345() -> None: ), pl.DataFrame({"a": [[1]]}), ) + + +@pytest.mark.parametrize( + "inner_dtype", + [pl.Int8, pl.Float64, pl.String, pl.Duration], +) +def test_list_agg_all_null(inner_dtype: PolarsDataType) -> None: + s = pl.Series([None, None], dtype=pl.List(inner_dtype)) + assert s.list.min().dtype == inner_dtype + assert s.list.max().dtype == inner_dtype