Skip to content

Commit 706b7c8

Browse files
aidoskanapyanovFBruzzesipre-commit-ci[bot]
authored
feat: dask expr cast (#821)
* feat: dask expr cast * replace temporary hack with `.cast` * test cast raises for unknown dtype * remove unused filterwarnings * use walrus for simplicity Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove duplicated test * add missing function parameter types * skip coverage for some paths * skip coverage for unknown dtype exception Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> * remove redundant modin xfail * empty commit to trigger CI --------- Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: FBruzzesi <francesco.bruzzesi.93@gmail.com>
1 parent 1b7bd73 commit 706b7c8

File tree

6 files changed

+95
-29
lines changed

6 files changed

+95
-29
lines changed

narwhals/_dask/expr.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from narwhals._dask.utils import add_row_index
1111
from narwhals._dask.utils import maybe_evaluate
12+
from narwhals._dask.utils import reverse_translate_dtype
1213
from narwhals.dependencies import get_dask
1314
from narwhals.utils import generate_unique_token
1415

@@ -17,6 +18,7 @@
1718

1819
from narwhals._dask.dataframe import DaskLazyFrame
1920
from narwhals._dask.namespace import DaskNamespace
21+
from narwhals.dtypes import DType
2022

2123

2224
class DaskExpr:
@@ -654,6 +656,21 @@ def dt(self: Self) -> DaskExprDateTimeNamespace:
654656
def name(self: Self) -> DaskExprNameNamespace:
655657
return DaskExprNameNamespace(self)
656658

659+
def cast(
    self: Self,
    dtype: DType | type[DType],
) -> Self:
    """Return a new expression whose values are cast to `dtype`.

    The narwhals dtype is converted with `reverse_translate_dtype` and the
    result is handed to the input's `astype`. The work is registered through
    `self._from_call` under the name "cast" with `returns_scalar=False`.
    """

    def func(_input: Any, dtype: DType | type[DType]) -> Any:
        # Translate the narwhals dtype, then delegate to `astype`.
        return _input.astype(reverse_translate_dtype(dtype))

    return self._from_call(func, "cast", dtype, returns_scalar=False)
673+
657674

658675
class DaskExprStringNamespace:
659676
def __init__(self, expr: DaskExpr) -> None:

narwhals/_dask/utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44
from typing import Any
55

66
from narwhals.dependencies import get_dask_expr
7+
from narwhals.dependencies import get_pandas
8+
from narwhals.dependencies import get_pyarrow
9+
from narwhals.utils import isinstance_or_issubclass
10+
from narwhals.utils import parse_version
711

812
if TYPE_CHECKING:
913
from narwhals._dask.dataframe import DaskLazyFrame
14+
from narwhals.dtypes import DType
1015

1116

1217
def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
@@ -73,3 +78,47 @@ def parse_exprs_and_named_exprs(
7378
def add_row_index(frame: Any, name: str) -> Any:
    """Return `frame` with a column `name` holding a zero-based running index.

    The column is first filled with the constant 1, then replaced by its
    cumulative sum (computed with ``method="blelloch"``) minus one.
    """
    seeded = frame.assign(**{name: 1})
    running_index = seeded[name].cumsum(method="blelloch") - 1
    return seeded.assign(**{name: running_index})
81+
82+
83+
def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
    """Translate a narwhals dtype (class or instance) into a dtype string.

    The returned string is what `DaskExpr.cast` passes to `astype`.

    For `String` the result depends on the environment:
    "string[pyarrow]" when pandas >= 2.0.0 and pyarrow are both available,
    "string[python]" when only pandas >= 2.0.0 is available, and "object"
    otherwise.

    Raises:
        AssertionError: if `dtype` is none of the handled narwhals dtypes.
    """
    from narwhals import dtypes

    # Checked in this order, before the String special case.
    numeric_names = (
        (dtypes.Float64, "float64"),
        (dtypes.Float32, "float32"),
        (dtypes.Int64, "int64"),
        (dtypes.Int32, "int32"),
        (dtypes.Int16, "int16"),
        (dtypes.Int8, "int8"),
        (dtypes.UInt64, "uint64"),
        (dtypes.UInt32, "uint32"),
        (dtypes.UInt16, "uint16"),
        (dtypes.UInt8, "uint8"),
    )
    for narwhals_dtype, native_name in numeric_names:
        if isinstance_or_issubclass(dtype, narwhals_dtype):
            return native_name

    if isinstance_or_issubclass(dtype, dtypes.String):
        pd = get_pandas()
        has_pandas_2 = pd is not None and parse_version(
            pd.__version__
        ) >= parse_version("2.0.0")
        if not has_pandas_2:
            return "object"  # pragma: no cover
        if get_pyarrow() is not None:
            return "string[pyarrow]"
        return "string[python]"  # pragma: no cover

    # Checked in this order, after the String special case.
    remaining_names = (
        (dtypes.Boolean, "bool"),
        (dtypes.Categorical, "category"),
        (dtypes.Datetime, "datetime64[us]"),
        (dtypes.Duration, "timedelta64[ns]"),
    )
    for narwhals_dtype, native_name in remaining_names:
        if isinstance_or_issubclass(dtype, narwhals_dtype):
            return native_name

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)

tests/expr_and_series/binary_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
from typing import Any
22

3-
import pytest
4-
53
import narwhals.stable.v1 as nw
64
from tests.utils import compare_dicts
75

86

9-
def test_expr_binary(constructor: Any, request: Any) -> None:
7+
def test_expr_binary(constructor: Any) -> None:
108
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
11-
if "dask" in str(constructor):
12-
request.applymarker(pytest.mark.xfail)
139
df_raw = constructor(data)
1410
result = nw.from_native(df_raw).with_columns(
1511
a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")),

tests/expr_and_series/cast_test.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@
4747

4848
@pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
4949
def test_cast(constructor: Any, request: Any) -> None:
50-
if "dask" in str(constructor):
51-
request.applymarker(pytest.mark.xfail)
5250
if "pyarrow_table_constructor" in str(constructor) and parse_version(
5351
pa.__version__
5452
) <= (15,): # pragma: no cover
@@ -98,17 +96,21 @@ def test_cast(constructor: Any, request: Any) -> None:
9896
assert dict(result.collect_schema()) == expected
9997

10098

101-
def test_cast_series(constructor_eager: Any, request: Any) -> None:
102-
if "pyarrow_table_constructor" in str(constructor_eager) and parse_version(
99+
def test_cast_series(constructor: Any, request: Any) -> None:
100+
if "pyarrow_table_constructor" in str(constructor) and parse_version(
103101
pa.__version__
104102
) <= (15,): # pragma: no cover
105103
request.applymarker(pytest.mark.xfail)
106-
if "modin" in str(constructor_eager):
104+
if "modin" in str(constructor):
107105
# TODO(unassigned): in modin, we end up with `'<U0'` dtype
108106
request.applymarker(pytest.mark.xfail)
109-
df = nw.from_native(constructor_eager(data), eager_only=True).select(
110-
nw.col(key).cast(value) for key, value in schema.items()
107+
df = (
108+
nw.from_native(constructor(data))
109+
.select(nw.col(key).cast(value) for key, value in schema.items())
110+
.lazy()
111+
.collect()
111112
)
113+
112114
expected = {
113115
"a": nw.Int32,
114116
"b": nw.Int16,
@@ -158,3 +160,22 @@ def test_cast_string() -> None:
158160
s = s.cast(nw.String)
159161
result = nw.to_native(s)
160162
assert str(result.dtype) in ("string", "object", "dtype('O')")
163+
164+
165+
def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None:
    """Casting to a class that is not a narwhals dtype raises "Unknown dtype".

    Marked xfail for the pyarrow table constructor on pyarrow <= 15 and for
    polars constructors.
    """
    constructor_name = str(constructor)
    is_old_pyarrow_table = (
        "pyarrow_table_constructor" in constructor_name
        and parse_version(pa.__version__) <= (15,)
    )
    if is_old_pyarrow_table:  # pragma: no cover
        request.applymarker(pytest.mark.xfail)
    if "polars" in constructor_name:
        request.applymarker(pytest.mark.xfail)

    casts = (nw.col(key).cast(value) for key, value in schema.items())
    df = nw.from_native(constructor(data)).select(casts)

    class Banana:
        pass

    with pytest.raises(AssertionError, match=r"Unknown dtype"):
        df.select(nw.col("a").cast(Banana))

tests/selectors_test.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pytest
88

99
import narwhals.stable.v1 as nw
10-
from narwhals.dependencies import get_dask_dataframe
1110
from narwhals.selectors import all
1211
from narwhals.selectors import boolean
1312
from narwhals.selectors import by_dtype
@@ -57,8 +56,6 @@ def test_string(constructor: Any, request: Any) -> None:
5756

5857

5958
def test_categorical(request: Any, constructor: Any) -> None:
60-
if "dask" in str(constructor):
61-
request.applymarker(pytest.mark.xfail)
6259
if "pyarrow_table_constructor" in str(constructor) and parse_version(
6360
pa.__version__
6461
) <= (15,): # pragma: no cover
@@ -70,17 +67,6 @@ def test_categorical(request: Any, constructor: Any) -> None:
7067
compare_dicts(result, expected)
7168

7269

73-
@pytest.mark.skipif((get_dask_dataframe() is None), reason="too old for dask")
74-
def test_dask_categorical() -> None:
75-
import dask.dataframe as dd
76-
77-
expected = {"b": ["a", "b", "c"]}
78-
df_raw = dd.from_dict(expected, npartitions=1).astype({"b": "category"})
79-
df = nw.from_native(df_raw)
80-
result = df.select(categorical())
81-
compare_dicts(result, expected)
82-
83-
8470
@pytest.mark.parametrize(
8571
("selector", "expected"),
8672
[

tests/test_group_by.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,6 @@ def test_group_by_multiple_keys(constructor: Any) -> None:
173173

174174

175175
def test_key_with_nulls(constructor: Any, request: Any) -> None:
176-
if "dask" in str(constructor):
177-
request.applymarker(pytest.mark.xfail)
178-
179176
if "modin" in str(constructor):
180177
# TODO(unassigned): Modin flaky here?
181178
request.applymarker(pytest.mark.skip)

0 commit comments

Comments
 (0)