Skip to content

Commit

Permalink
Merge branch 'main' into perf/codspeed-ci
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Aug 19, 2024
2 parents eeb7347 + 676731c commit 38007fd
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 42 deletions.
47 changes: 46 additions & 1 deletion narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,9 @@ def fill_null(self, value: Any) -> DaskExpr:
)

def clip(
self: Self, lower_bound: Any | None = None, upper_bound: Any | None = None
self: Self,
lower_bound: Any | None = None,
upper_bound: Any | None = None,
) -> Self:
return self._from_call(
lambda _input, _lower, _upper: _input.clip(lower=_lower, upper=_upper),
Expand Down Expand Up @@ -798,6 +800,49 @@ def ordinal_day(self) -> DaskExpr:
returns_scalar=False,
)

def to_string(self, format: str) -> DaskExpr: # noqa: A002
return self._expr._from_call(
lambda _input, _format: _input.dt.strftime(_format),
"strftime",
format.replace("%.f", ".%f"),
returns_scalar=False,
)

def total_minutes(self) -> DaskExpr:
return self._expr._from_call(
lambda _input: _input.dt.total_seconds() // 60,
"total_minutes",
returns_scalar=False,
)

def total_seconds(self) -> DaskExpr:
return self._expr._from_call(
lambda _input: _input.dt.total_seconds() // 1,
"total_seconds",
returns_scalar=False,
)

def total_milliseconds(self) -> DaskExpr:
return self._expr._from_call(
lambda _input: _input.dt.total_seconds() * 1000 // 1,
"total_milliseconds",
returns_scalar=False,
)

def total_microseconds(self) -> DaskExpr:
return self._expr._from_call(
lambda _input: _input.dt.total_seconds() * 1_000_000 // 1,
"total_microseconds",
returns_scalar=False,
)

def total_nanoseconds(self) -> DaskExpr:
return self._expr._from_call(
lambda _input: _input.dt.total_seconds() * 1_000_000_000 // 1,
"total_nanoseconds",
returns_scalar=False,
)


class DaskExprNameNamespace:
def __init__(self: Self, expr: DaskExpr) -> None:
Expand Down
4 changes: 1 addition & 3 deletions tests/expr_and_series/dt/datetime_duration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def test_duration_attributes(
expected_b: list[int],
expected_c: list[int],
) -> None:
if "dask_lazy" in str(constructor) or (
parse_version(pd.__version__) < (2, 2) and "pandas_pyarrow" in str(constructor)
):
if parse_version(pd.__version__) < (2, 2) and "pandas_pyarrow" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand Down
140 changes: 102 additions & 38 deletions tests/expr_and_series/dt/to_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts
from tests.utils import is_windows

data = {
Expand All @@ -17,31 +18,71 @@


@pytest.mark.parametrize(
"fmt", ["%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S", "%G-W%V-%u", "%G-W%V"]
"fmt",
[
"%Y-%m-%d",
"%Y-%m-%d %H:%M:%S",
"%Y/%m/%d %H:%M:%S",
"%G-W%V-%u",
"%G-W%V",
],
)
@pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows")
def test_dt_to_string(constructor_eager: Any, fmt: str) -> None:
def test_dt_to_string_series(constructor_eager: Any, fmt: str) -> None:
input_frame = nw.from_native(constructor_eager(data), eager_only=True)
input_series = input_frame["a"]

expected_col = [datetime.strftime(d, fmt) for d in data["a"]]

result = input_series.dt.to_string(fmt).to_list()
result = {"a": input_series.dt.to_string(fmt)}

if any(
x in str(constructor_eager) for x in ["pandas_pyarrow", "pyarrow_table", "modin"]
):
# PyArrow differs from other libraries, in that %S also shows
# the fraction of a second.
result = [x[: x.find(".")] if "." in x else x for x in result]
assert result == expected_col
result = input_frame.select(nw.col("a").dt.to_string(fmt))["a"].to_list()
if any(
x in str(constructor_eager) for x in ["pandas_pyarrow", "pyarrow_table", "modin"]
):
result = {"a": input_series.dt.to_string(fmt).str.replace(r"\.\d+$", "")}

compare_dicts(result, {"a": expected_col})


@pytest.mark.parametrize(
"fmt",
[
"%Y-%m-%d",
"%Y-%m-%d %H:%M:%S",
"%Y/%m/%d %H:%M:%S",
"%G-W%V-%u",
"%G-W%V",
],
)
@pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows")
def test_dt_to_string_expr(constructor: Any, fmt: str) -> None:
input_frame = nw.from_native(constructor(data))

expected_col = [datetime.strftime(d, fmt) for d in data["a"]]

result = input_frame.select(nw.col("a").dt.to_string(fmt).alias("b"))
if any(x in str(constructor) for x in ["pandas_pyarrow", "pyarrow_table", "modin"]):
# PyArrow differs from other libraries, in that %S also shows
# the fraction of a second.
result = [x[: x.find(".")] if "." in x else x for x in result]
assert result == expected_col
result = input_frame.select(
nw.col("a").dt.to_string(fmt).str.replace(r"\.\d+$", "").alias("b")
)
compare_dicts(result, {"b": expected_col})


def _clean_string(result: str) -> str:
# rstrip '0' to remove trailing zeros, as different libraries handle this differently
# if there's then a trailing `.`, remove that too.
if "." in result:
result = result.rstrip("0").rstrip(".")
return result


def _clean_string_expr(e: Any) -> Any:
# Same as `_clean_string` but for Expr
return e.str.replace_all(r"0+$", "").str.replace_all(r"\.$", "")


@pytest.mark.parametrize(
Expand All @@ -50,20 +91,16 @@ def test_dt_to_string(constructor_eager: Any, fmt: str) -> None:
(datetime(2020, 1, 9), "2020-01-09T00:00:00.000000"),
(datetime(2020, 1, 9, 12, 34, 56), "2020-01-09T12:34:56.000000"),
(datetime(2020, 1, 9, 12, 34, 56, 123), "2020-01-09T12:34:56.000123"),
(datetime(2020, 1, 9, 12, 34, 56, 123456), "2020-01-09T12:34:56.123456"),
(
datetime(2020, 1, 9, 12, 34, 56, 123456),
"2020-01-09T12:34:56.123456",
),
],
)
@pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows")
def test_dt_to_string_iso_local_datetime(
def test_dt_to_string_iso_local_datetime_series(
constructor_eager: Any, data: datetime, expected: str
) -> None:
def _clean_string(result: str) -> str:
# rstrip '0' to remove trailing zeros, as different libraries handle this differently
# if there's then a trailing `.`, remove that too.
if "." in result:
result = result.rstrip("0").rstrip(".")
return result

df = constructor_eager({"a": [data]})
result = (
nw.from_native(df, eager_only=True)["a"]
Expand All @@ -72,34 +109,50 @@ def _clean_string(result: str) -> str:
)
assert _clean_string(result) == _clean_string(expected)

result = (
nw.from_native(df, eager_only=True)
.select(nw.col("a").dt.to_string("%Y-%m-%dT%H:%M:%S.%f"))["a"]
.to_list()[0]
)
assert _clean_string(result) == _clean_string(expected)

result = (
nw.from_native(df, eager_only=True)["a"]
.dt.to_string("%Y-%m-%dT%H:%M:%S%.f")
.to_list()[0]
)
assert _clean_string(result) == _clean_string(expected)

result = (
nw.from_native(df, eager_only=True)
.select(nw.col("a").dt.to_string("%Y-%m-%dT%H:%M:%S%.f"))["a"]
.to_list()[0]

@pytest.mark.parametrize(
("data", "expected"),
[
(datetime(2020, 1, 9, 12, 34, 56), "2020-01-09T12:34:56.000000"),
(datetime(2020, 1, 9, 12, 34, 56, 123), "2020-01-09T12:34:56.000123"),
(
datetime(2020, 1, 9, 12, 34, 56, 123456),
"2020-01-09T12:34:56.123456",
),
],
)
@pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows")
def test_dt_to_string_iso_local_datetime_expr(
request: Any, constructor: Any, data: datetime, expected: str
) -> None:
if "modin" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = constructor({"a": [data]})

result = nw.from_native(df).with_columns(
_clean_string_expr(nw.col("a").dt.to_string("%Y-%m-%dT%H:%M:%S.%f")).alias("b")
)
assert _clean_string(result) == _clean_string(expected)
compare_dicts(result, {"a": [data], "b": [_clean_string(expected)]})

result = nw.from_native(df).with_columns(
_clean_string_expr(nw.col("a").dt.to_string("%Y-%m-%dT%H:%M:%S%.f")).alias("b")
)
compare_dicts(result, {"a": [data], "b": [_clean_string(expected)]})


@pytest.mark.parametrize(
("data", "expected"),
[(datetime(2020, 1, 9), "2020-01-09")],
)
@pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows")
def test_dt_to_string_iso_local_date(
def test_dt_to_string_iso_local_date_series(
constructor_eager: Any, data: datetime, expected: str
) -> None:
df = constructor_eager({"a": [data]})
Expand All @@ -108,9 +161,20 @@ def test_dt_to_string_iso_local_date(
)
assert result == expected

result = (
nw.from_native(df, eager_only=True)
.select(b=nw.col("a").dt.to_string("%Y-%m-%d"))["b"]
.to_list()[0]

@pytest.mark.parametrize(
("data", "expected"),
[(datetime(2020, 1, 9), "2020-01-09")],
)
@pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows")
def test_dt_to_string_iso_local_date_expr(
request: Any, constructor: Any, data: datetime, expected: str
) -> None:
if "modin" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = constructor({"a": [data]})
result = nw.from_native(df).with_columns(
nw.col("a").dt.to_string("%Y-%m-%d").alias("b")
)
assert result == expected
compare_dicts(result, {"a": [data], "b": [expected]})

0 comments on commit 38007fd

Please sign in to comment.