Skip to content

Commit

Permalink
feat: add LazyFrame.unpivot for spark and duckdb (#1890)
Browse files Browse the repository at this point in the history
* feat: add LazyFrame.unpivot for spark and duckdb

* keep 'on' parsing in compliant

* parse names only
  • Loading branch information
FBruzzesi authored Feb 2, 2025
1 parent 8ca9422 commit d48b8a3
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 44 deletions.
9 changes: 3 additions & 6 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,12 +766,11 @@ def unpivot(
self: Self,
on: str | list[str] | None,
index: str | list[str] | None,
variable_name: str | None,
value_name: str | None,
variable_name: str,
value_name: str,
) -> Self:
native_frame = self._native_frame
variable_name = variable_name if variable_name is not None else "variable"
value_name = value_name if value_name is not None else "value"
n_rows = len(self)

index_: list[str] = (
[] if index is None else [index] if isinstance(index, str) else index
Expand All @@ -784,8 +783,6 @@ def unpivot(
else on
)

n_rows = len(self)

promote_kwargs = (
{"promote_options": "permissive"}
if self._backend_version >= (14, 0, 0)
Expand Down
8 changes: 4 additions & 4 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,14 @@ def unpivot(
self: Self,
on: str | list[str] | None,
index: str | list[str] | None,
variable_name: str | None,
value_name: str | None,
variable_name: str,
value_name: str,
) -> Self:
return self._from_native_frame(
self._native_frame.melt(
id_vars=index,
value_vars=on,
var_name=variable_name if variable_name is not None else "variable",
value_name=value_name if value_name is not None else "value",
var_name=variable_name,
value_name=value_name,
)
)
49 changes: 45 additions & 4 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,6 @@ def collect_schema(self: Self) -> dict[str, DType]:

def unique(self: Self, subset: Sequence[str] | None, keep: str) -> Self:
if subset is not None:
import duckdb

rel = self._native_frame
# Sanitise input
if any(x not in rel.columns for x in subset):
Expand Down Expand Up @@ -423,10 +421,53 @@ def sort(
return self._from_native_frame(result)

def drop_nulls(self: Self, subset: list[str] | None) -> Self:
    """Drop every row that has a null in any of the inspected columns.

    Arguments:
        subset: Column names to inspect. If `None`, all columns of the
            native relation are inspected.

    Returns:
        A new frame wrapping the filtered DuckDB query result.
    """
    import duckdb

    # NOTE: the SQL text below refers to the relation by the name `rel`,
    # so this local variable must keep exactly that name.
    rel = self._native_frame
    columns_to_check = rel.columns if subset is None else subset
    not_null_checks = [f'"{col}" is not null' for col in columns_to_check]
    keep_condition = " and ".join(not_null_checks)
    query = f"select * from rel where {keep_condition}"  # noqa: S608
    filtered = duckdb.sql(query)
    return self._from_native_frame(filtered)

def unpivot(
    self: Self,
    on: str | list[str] | None,
    index: str | list[str] | None,
    variable_name: str,
    value_name: str,
) -> Self:
    """Unpivot the frame from wide to long format with DuckDB's `UNPIVOT`.

    Arguments:
        on: Column name(s) to unpivot. If `None`, every column that is not
            listed in `index` is unpivoted.
        index: Column name(s) kept as identifier columns. If `None`, no
            identifier columns are kept.
        variable_name: Name of the output column holding the names of the
            unpivoted columns.
        value_name: Name of the output column holding the unpivoted values.

    Returns:
        A new frame whose columns are the `index` columns followed by
        `variable_name` and `value_name`.

    Raises:
        NotImplementedError: If `variable_name` or `value_name` is an
            empty string.
    """
    # Fail fast: an empty identifier cannot be expressed in the query.
    if variable_name == "":
        msg = "`variable_name` cannot be empty string for duckdb backend."
        raise NotImplementedError(msg)

    if value_name == "":
        msg = "`value_name` cannot be empty string for duckdb backend."
        raise NotImplementedError(msg)

    def quote(name: str) -> str:
        # Double any embedded quote so the name stays a single identifier.
        return '"' + name.replace('"', '""') + '"'

    index_: list[str] = (
        [] if index is None else [index] if isinstance(index, str) else index
    )
    on_: list[str] = (
        [c for c in self.columns if c not in index_]
        if on is None
        else [on]
        if isinstance(on, str)
        else on
    )

    cols_to_select = ", ".join(
        quote(col) for col in [*index_, variable_name, value_name]
    )
    unpivot_on = ", ".join(quote(col) for col in on_)

    # `rel` looks unused to linters, but the SQL text below refers to the
    # relation by this exact local name.
    rel = self._native_frame  # noqa: F841
    # The output names are quoted in the `into` clause exactly as they are in
    # the outer `select`, so names with spaces, reserved words or mixed case
    # resolve to the same identifier in both places.
    query = f"""
        with unpivot_cte as (
            unpivot rel
            on {unpivot_on}
            into
                name {quote(variable_name)}
                value {quote(value_name)}
        )
        select {cols_to_select}
        from unpivot_cte;
        """  # noqa: S608
    return self._from_native_frame(duckdb.sql(query))
8 changes: 4 additions & 4 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,15 +1059,15 @@ def unpivot(
self: Self,
on: str | list[str] | None,
index: str | list[str] | None,
variable_name: str | None,
value_name: str | None,
variable_name: str,
value_name: str,
) -> Self:
return self._from_native_frame(
self._native_frame.melt(
id_vars=index,
value_vars=on,
var_name=variable_name if variable_name is not None else "variable",
value_name=value_name if value_name is not None else "value",
var_name=variable_name,
value_name=value_name,
)
)

Expand Down
8 changes: 4 additions & 4 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ def unpivot(
self: Self,
on: str | list[str] | None,
index: str | list[str] | None,
variable_name: str | None,
value_name: str | None,
variable_name: str,
value_name: str,
) -> Self:
if self._backend_version < (1, 0, 0):
return self._from_native_frame(
Expand Down Expand Up @@ -508,8 +508,8 @@ def unpivot(
self: Self,
on: str | list[str] | None,
index: str | list[str] | None,
variable_name: str | None,
value_name: str | None,
variable_name: str,
value_name: str,
) -> Self:
if self._backend_version < (1, 0, 0):
return self._from_native_frame(
Expand Down
16 changes: 16 additions & 0 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,19 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel
]
)
)

def unpivot(
    self: Self,
    on: str | list[str] | None,
    index: str | list[str] | None,
    variable_name: str,
    value_name: str,
) -> Self:
    """Unpivot the frame from wide to long format via the native `unpivot`.

    Arguments:
        on: Column name(s) to unpivot, forwarded unchanged as `values`.
        index: Column name(s) kept as identifier columns. `None` means no
            identifier columns.
        variable_name: Name of the output column holding the names of the
            unpivoted columns.
        value_name: Name of the output column holding the unpivoted values.

    Returns:
        A new frame wrapping the unpivoted native frame.
    """
    # The native `unpivot` needs a concrete `ids` argument (PySpark does not
    # accept `None` there), so "no index" is passed as an empty list and a
    # single name is wrapped in a list.
    ids: list[str] = (
        [] if index is None else [index] if isinstance(index, str) else index
    )
    return self._from_native_frame(
        self._native_frame.unpivot(
            ids=ids,
            values=on,
            variableColumnName=variable_name,
            valueColumnName=value_name,
        )
    )
3 changes: 3 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ def unpivot(
variable_name: str | None,
value_name: str | None,
) -> Self:
variable_name = variable_name if variable_name is not None else "variable"
value_name = value_name if value_name is not None else "value"

return self._from_compliant_dataframe(
self._compliant_frame.unpivot(
on=on,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
)
elif "constructor" in metafunc.fixturenames:
if (
any(x in str(metafunc.module) for x in ("unpivot", "from_dict", "from_numpy"))
any(x in str(metafunc.module) for x in ("from_dict", "from_numpy"))
and LAZY_CONSTRUCTORS["duckdb"] in constructors
):
constructors.remove(LAZY_CONSTRUCTORS["duckdb"])
Expand Down
35 changes: 14 additions & 21 deletions tests/frame/unpivot_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import TYPE_CHECKING
from typing import Any

Expand Down Expand Up @@ -37,14 +38,10 @@
[("b", expected_b_only), (["b", "c"], expected_b_c), (None, expected_b_c)],
)
def test_unpivot_on(
request: pytest.FixtureRequest,
constructor: Constructor,
on: str | list[str] | None,
expected: dict[str, list[float]],
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.unpivot(on=on, index=["a"]).sort("variable", "a")
assert_equal_data(result, expected)
Expand All @@ -59,28 +56,26 @@ def test_unpivot_on(
],
)
def test_unpivot_var_value_names(
request: pytest.FixtureRequest,
constructor: Constructor,
variable_name: str | None,
value_name: str | None,
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.unpivot(
on=["b", "c"], index=["a"], variable_name=variable_name, value_name=value_name
context = (
pytest.raises(NotImplementedError)
if ("duckdb" in str(constructor) and any([variable_name == "", value_name == ""]))
else does_not_raise()
)

assert result.collect_schema().names()[-2:] == [variable_name, value_name]
with context:
df = nw.from_native(constructor(data))
result = df.unpivot(
on=["b", "c"], index=["a"], variable_name=variable_name, value_name=value_name
)

assert result.collect_schema().names()[-2:] == [variable_name, value_name]

def test_unpivot_default_var_value_names(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_unpivot_default_var_value_names(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.unpivot(on=["b", "c"], index=["a"])

Expand All @@ -102,10 +97,8 @@ def test_unpivot_mixed_types(
data: dict[str, Any],
expected_dtypes: list[DType],
) -> None:
if (
"cudf" in str(constructor)
or "pyspark" in str(constructor)
or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (14, 0, 0))
if "cudf" in str(constructor) or (
"pyarrow_table" in str(constructor) and PYARROW_VERSION < (14, 0, 0)
):
request.applymarker(pytest.mark.xfail)

Expand Down

0 comments on commit d48b8a3

Please sign in to comment.