From 66b17e939e3629b6f36471d10e18675891d53b82 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 8 Jan 2025 13:55:20 +0000 Subject: [PATCH] fixup --- narwhals/_duckdb/expr.py | 24 ++++++++++++++++++++++++ tests/expr_and_series/n_unique_test.py | 6 +----- tests/expr_and_series/unary_test.py | 6 +----- tests/group_by_test.py | 8 +------- tpch/execute.py | 2 +- 5 files changed, 28 insertions(+), 18 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 4515cbba1..297d9e64a 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -408,6 +408,30 @@ def sum(self) -> Self: lambda _input: FunctionExpression("sum", _input), "sum", returns_scalar=True ) + def n_unique(self) -> Self: + from duckdb import CaseExpression + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + def func(_input: duckdb.Expression) -> duckdb.Expression: + return ( + FunctionExpression( + "array_unique", FunctionExpression("array_agg", _input) + ) + + FunctionExpression( + "max", + CaseExpression( + condition=_input.isnotnull(), value=ConstantExpression(0) + ).otherwise(ConstantExpression(1)), + ) + ).alias("result") + + return self._from_call( + func, + "n_unique", + returns_scalar=True, + ) + def count(self) -> Self: from duckdb import FunctionExpression diff --git a/tests/expr_and_series/n_unique_test.py b/tests/expr_and_series/n_unique_test.py index d8e4d9b77..90bffb04b 100644 --- a/tests/expr_and_series/n_unique_test.py +++ b/tests/expr_and_series/n_unique_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -13,9 +11,7 @@ } -def test_n_unique(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_n_unique(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.all().n_unique()) expected = {"a": [3], "b": [4]} diff --git a/tests/expr_and_series/unary_test.py b/tests/expr_and_series/unary_test.py index 9ee38a230..f3e01d80f 100644 --- a/tests/expr_and_series/unary_test.py +++ b/tests/expr_and_series/unary_test.py @@ -126,11 +126,7 @@ def test_unary_two_elements_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_unary_one_element( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_unary_one_element(constructor: Constructor) -> None: data = {"a": [1], "b": [2], "c": [None]} # Dask runs into a divide by zero RuntimeWarning for 1 element skew. context = ( diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 0dd6d8a10..c854da453 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -115,8 +115,6 @@ def test_group_by_depth_1_agg( expected: dict[str, list[int | float]], request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor) and attr == "n_unique": - request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1): # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations" request.applymarker(pytest.mark.xfail) @@ -166,11 +164,7 @@ def test_group_by_median(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_group_by_n_unique_w_missing( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_group_by_n_unique_w_missing(constructor: Constructor) -> None: data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} result = ( nw.from_native(constructor(data)) diff --git a/tpch/execute.py b/tpch/execute.py index 1f3823ced..f2f3041df 100644 --- a/tpch/execute.py +++ b/tpch/execute.py @@ -40,7 +40,7 @@ "dask": lambda x: x.compute(), } -DUCKDB_XFAILS = ["q11", "q14", "q15", "q16", "q18", "q22"] +DUCKDB_XFAILS = ["q11", "q14", "q15", "q18", "q22"] QUERY_DATA_PATH_MAP = { "q1": (LINEITEM_PATH,),