Skip to content

Commit bb01ca6

Browse files
committed
feat: allow any ddof value for duckdb var and std
1 parent 0f38b77 commit bb01ca6

File tree

4 files changed

+36
-51
lines changed

4 files changed

+36
-51
lines changed

narwhals/_duckdb/expr.py

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -460,27 +460,32 @@ def len(self: Self) -> Self:
460460
)
461461

462462
def std(self: Self, ddof: int) -> Self:
463-
if ddof == 1:
464-
func = "stddev_samp"
465-
elif ddof == 0:
466-
func = "stddev_pop"
467-
else:
468-
msg = f"std with ddof {ddof} is not supported in DuckDB"
469-
raise NotImplementedError(msg)
463+
def _std(_input: duckdb.Expression, ddof: int) -> duckdb.Expression:
464+
n_samples = FunctionExpression("count", _input)
465+
466+
return (
467+
FunctionExpression("stddev_pop", _input)
468+
* FunctionExpression("sqrt", n_samples)
469+
/ (FunctionExpression("sqrt", (n_samples - ddof)))
470+
)
471+
470472
return self._from_call(
471-
lambda _input: FunctionExpression(func, _input), "std", returns_scalar=True
473+
_std,
474+
"std",
475+
ddof=ddof,
476+
returns_scalar=True,
472477
)
473478

474479
def var(self: Self, ddof: int) -> Self:
475-
if ddof == 1:
476-
func = "var_samp"
477-
elif ddof == 0:
478-
func = "var_pop"
479-
else:
480-
msg = f"var with ddof {ddof} is not supported in DuckDB"
481-
raise NotImplementedError(msg)
480+
def _var(_input: duckdb.Expression, ddof: int) -> duckdb.Expression:
481+
n_samples = FunctionExpression("count", _input)
482+
return FunctionExpression("var_pop", _input) * n_samples / (n_samples - ddof)
483+
482484
return self._from_call(
483-
lambda _input: FunctionExpression(func, _input), "var", returns_scalar=True
485+
_var,
486+
"var",
487+
ddof=ddof,
488+
returns_scalar=True,
484489
)
485490

486491
def max(self: Self) -> Self:

tests/expr_and_series/std_test.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
from contextlib import nullcontext as does_not_raise
4-
53
import pytest
64

75
import narwhals.stable.v1 as nw
@@ -34,19 +32,14 @@ def test_std(constructor: Constructor, input_data: dict[str, list[float | None]]
3432
"z_ddof_0": [0.816497],
3533
}
3634
assert_equal_data(result, expected_results)
37-
context = (
38-
pytest.raises(NotImplementedError)
39-
if "duckdb" in str(constructor)
40-
else does_not_raise()
35+
36+
result = df.select(
37+
nw.col("b").std(ddof=2).alias("b_ddof_2"),
4138
)
42-
with context:
43-
result = df.select(
44-
nw.col("b").std(ddof=2).alias("b_ddof_2"),
45-
)
46-
expected_results = {
47-
"b_ddof_2": [1.632993],
48-
}
49-
assert_equal_data(result, expected_results)
39+
expected_results = {
40+
"b_ddof_2": [1.632993],
41+
}
42+
assert_equal_data(result, expected_results)
5043

5144

5245
@pytest.mark.parametrize("input_data", [data, data_with_nulls])

tests/expr_and_series/var_test.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
from contextlib import nullcontext as does_not_raise
4-
53
import pytest
64

75
import narwhals.stable.v1 as nw
@@ -34,19 +32,14 @@ def test_var(constructor: Constructor, input_data: dict[str, list[float | None]]
3432
"z_ddof_0": [0.6666666666666666],
3533
}
3634
assert_equal_data(result, expected_results)
37-
context = (
38-
pytest.raises(NotImplementedError)
39-
if "duckdb" in str(constructor)
40-
else does_not_raise()
35+
36+
result = df.select(
37+
nw.col("b").var(ddof=2).alias("b_ddof_2"),
4138
)
42-
with context:
43-
result = df.select(
44-
nw.col("b").var(ddof=2).alias("b_ddof_2"),
45-
)
46-
expected_results = {
47-
"b_ddof_2": [2.666666666666667],
48-
}
49-
assert_equal_data(result, expected_results)
39+
expected_results = {
40+
"b_ddof_2": [2.666666666666667],
41+
}
42+
assert_equal_data(result, expected_results)
5043

5144

5245
@pytest.mark.parametrize("input_data", [data, data_with_nulls])

tests/group_by_test.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -137,11 +137,7 @@ def test_group_by_depth_1_agg(
137137
("var", 2),
138138
],
139139
)
140-
def test_group_by_depth_1_std_var(
141-
constructor: Constructor, attr: str, ddof: int, request: pytest.FixtureRequest
142-
) -> None:
143-
if "duckdb" in str(constructor) and ddof == 2:
144-
request.applymarker(pytest.mark.xfail)
140+
def test_group_by_depth_1_std_var(constructor: Constructor, attr: str, ddof: int) -> None:
145141
data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}
146142
_pow = 0.5 if attr == "std" else 1
147143
expected = {
@@ -398,8 +394,6 @@ def test_all_kind_of_aggs(
398394
# and modin lol https://github.com/modin-project/modin/issues/7414
399395
# and cudf https://github.com/rapidsai/cudf/issues/17649
400396
request.applymarker(pytest.mark.xfail)
401-
if "duckdb" in str(constructor):
402-
request.applymarker(pytest.mark.xfail)
403397
if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4):
404398
# Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']]
405399
request.applymarker(pytest.mark.xfail)

0 commit comments

Comments
 (0)