Skip to content

Commit

Permalink
feat: allow any ddof value for duckdb var and std (#1858)
Browse files: browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Jan 24, 2025
1 parent 0f38b77 commit 67b1ae4
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 51 deletions.
37 changes: 21 additions & 16 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,27 +460,32 @@ def len(self: Self) -> Self:
)

def std(self: Self, ddof: int) -> Self:
    """Return the standard deviation of the expression for any ``ddof``.

    The result is derived from the population standard deviation by
    rescaling it, using the identity
    ``std(ddof) = stddev_pop * sqrt(n) / sqrt(n - ddof)``, so no
    ``ddof`` value is special-cased or rejected.

    Args:
        ddof: Delta degrees of freedom; the divisor used is ``n - ddof``,
            where ``n`` is ``count`` applied to the input expression.

    Returns:
        A new expression built through ``self._from_call`` with
        ``returns_scalar=True``.
    """

    def _std(_input: duckdb.Expression, ddof: int) -> duckdb.Expression:
        # `count(expr)` is expected to count non-NULL values only, matching
        # the rows `stddev_pop` aggregates over -- confirm against DuckDB docs.
        n_samples = FunctionExpression("count", _input)

        # NOTE(review): when ddof >= n the denominator is sqrt of zero or of
        # a negative number; the outcome is whatever DuckDB does in that
        # case -- confirm it matches the other backends.
        return (
            FunctionExpression("stddev_pop", _input)
            * FunctionExpression("sqrt", n_samples)
            / FunctionExpression("sqrt", n_samples - ddof)
        )

    # `ddof` is handed over as a keyword argument; presumably `_from_call`
    # forwards it to `_std` when the expression is evaluated -- verify there.
    return self._from_call(
        _std,
        "std",
        ddof=ddof,
        returns_scalar=True,
    )

def var(self: Self, ddof: int) -> Self:
    """Return the variance of the expression for any ``ddof``.

    The result is derived from the population variance by rescaling it,
    using the identity ``var(ddof) = var_pop * n / (n - ddof)``, so no
    ``ddof`` value is special-cased or rejected.

    Args:
        ddof: Delta degrees of freedom; the divisor used is ``n - ddof``,
            where ``n`` is ``count`` applied to the input expression.

    Returns:
        A new expression built through ``self._from_call`` with
        ``returns_scalar=True``.
    """

    def _var(_input: duckdb.Expression, ddof: int) -> duckdb.Expression:
        # `count(expr)` is expected to count non-NULL values only, matching
        # the rows `var_pop` aggregates over -- confirm against DuckDB docs.
        n_samples = FunctionExpression("count", _input)

        # NOTE(review): when ddof == n this divides by zero; the outcome is
        # whatever DuckDB does in that case -- confirm it matches the other
        # backends.
        return FunctionExpression("var_pop", _input) * n_samples / (n_samples - ddof)

    # `ddof` is handed over as a keyword argument; presumably `_from_call`
    # forwards it to `_var` when the expression is evaluated -- verify there.
    return self._from_call(
        _var,
        "var",
        ddof=ddof,
        returns_scalar=True,
    )

def max(self: Self) -> Self:
Expand Down
21 changes: 7 additions & 14 deletions tests/expr_and_series/std_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise

import pytest

import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -34,19 +32,14 @@ def test_std(constructor: Constructor, input_data: dict[str, list[float | None]]
"z_ddof_0": [0.816497],
}
assert_equal_data(result, expected_results)
context = (
pytest.raises(NotImplementedError)
if "duckdb" in str(constructor)
else does_not_raise()

result = df.select(
nw.col("b").std(ddof=2).alias("b_ddof_2"),
)
with context:
result = df.select(
nw.col("b").std(ddof=2).alias("b_ddof_2"),
)
expected_results = {
"b_ddof_2": [1.632993],
}
assert_equal_data(result, expected_results)
expected_results = {
"b_ddof_2": [1.632993],
}
assert_equal_data(result, expected_results)


@pytest.mark.parametrize("input_data", [data, data_with_nulls])
Expand Down
21 changes: 7 additions & 14 deletions tests/expr_and_series/var_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise

import pytest

import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -34,19 +32,14 @@ def test_var(constructor: Constructor, input_data: dict[str, list[float | None]]
"z_ddof_0": [0.6666666666666666],
}
assert_equal_data(result, expected_results)
context = (
pytest.raises(NotImplementedError)
if "duckdb" in str(constructor)
else does_not_raise()

result = df.select(
nw.col("b").var(ddof=2).alias("b_ddof_2"),
)
with context:
result = df.select(
nw.col("b").var(ddof=2).alias("b_ddof_2"),
)
expected_results = {
"b_ddof_2": [2.666666666666667],
}
assert_equal_data(result, expected_results)
expected_results = {
"b_ddof_2": [2.666666666666667],
}
assert_equal_data(result, expected_results)


@pytest.mark.parametrize("input_data", [data, data_with_nulls])
Expand Down
8 changes: 1 addition & 7 deletions tests/group_by_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,7 @@ def test_group_by_depth_1_agg(
("var", 2),
],
)
def test_group_by_depth_1_std_var(
constructor: Constructor, attr: str, ddof: int, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor) and ddof == 2:
request.applymarker(pytest.mark.xfail)
def test_group_by_depth_1_std_var(constructor: Constructor, attr: str, ddof: int) -> None:
data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}
_pow = 0.5 if attr == "std" else 1
expected = {
Expand Down Expand Up @@ -398,8 +394,6 @@ def test_all_kind_of_aggs(
# and modin lol https://github.com/modin-project/modin/issues/7414
# and cudf https://github.com/rapidsai/cudf/issues/17649
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4):
# Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']]
request.applymarker(pytest.mark.xfail)
Expand Down

0 comments on commit 67b1ae4

Please sign in to comment.