diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 6e3d73c69..176a79259 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -383,11 +383,10 @@ def join_asof( *, left_on: str | None, right_on: str | None, - on: str | None, - by_left: str | list[str] | None, - by_right: str | list[str] | None, - by: str | list[str] | None, + by_left: list[str] | None, + by_right: list[str] | None, strategy: Literal["backward", "forward", "nearest"], + suffix: str, ) -> Self: msg = "join_asof is not yet supported on PyArrow tables" # pragma: no cover raise NotImplementedError(msg) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 165014d8d..3213a0eaa 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -334,13 +334,12 @@ def join_asof( self: Self, other: Self, *, - left_on: str | None = None, - right_on: str | None = None, - on: str | None = None, - by_left: str | list[str] | None = None, - by_right: str | list[str] | None = None, - by: str | list[str] | None = None, - strategy: Literal["backward", "forward", "nearest"] = "backward", + left_on: str | None, + right_on: str | None, + by_left: list[str] | None, + by_right: list[str] | None, + strategy: Literal["backward", "forward", "nearest"], + suffix: str, ) -> Self: plx = self.__native_namespace__() return self._from_native_frame( @@ -349,12 +348,10 @@ def join_asof( other._native_frame, left_on=left_on, right_on=right_on, - on=on, left_by=by_left, right_by=by_right, - by=by, direction=strategy, - suffixes=("", "_right"), + suffixes=("", suffix), ), ) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 00325b280..7ee11a12a 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -6,6 +6,7 @@ from typing import Literal from typing import Sequence +import duckdb from duckdb import ColumnExpression from narwhals._duckdb.utils import native_to_narwhals_dtype @@ -22,7 +23,6 @@ if TYPE_CHECKING: from types import ModuleType - import duckdb import pandas as pd import pyarrow as pa from typing_extensions import Self @@ -260,6 +260,51 @@ def join( res = rel.select(", ".join(select)).set_alias(original_alias) return self._from_native_frame(res) + def join_asof( + self: Self, + other: Self, + *, + left_on: str | None, + right_on: str | None, + by_left: list[str] | None, + by_right: list[str] | None, + strategy: Literal["backward", "forward", "nearest"], + suffix: str, + ) -> Self: + lhs = self._native_frame + rhs = other._native_frame + conditions = [] + if by_left is not None and by_right is not None: + conditions += [ + f'lhs."{left}" = rhs."{right}"' for left, right in zip(by_left, by_right) + ] + else: + by_left = by_right = [] + if strategy == "backward": + conditions += [f'lhs."{left_on}" >= rhs."{right_on}"'] + elif strategy == "forward": + conditions += [f'lhs."{left_on}" <= rhs."{right_on}"'] + else: + msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB" + raise NotImplementedError(msg) + condition = " and ".join(conditions) + select = ["lhs.*"] + for col in rhs.columns: + if col in lhs.columns and ( + right_on is None or col not in [right_on, *by_right] + ): + select.append(f'rhs."{col}" as "{col}{suffix}"') + elif right_on is None or col not in [right_on, *by_right]: + select.append(col) + query = f""" + SELECT {",".join(select)} + FROM lhs + ASOF LEFT JOIN rhs + ON {condition} + """ # noqa: S608 + res = duckdb.sql(query) + return self._from_native_frame(res) + def collect_schema(self: Self) -> dict[str, DType]: return { column_name: native_to_narwhals_dtype(str(duckdb_dtype), self._version) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 0aec90bd7..809b3de05 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -660,13 +660,12 @@ def join_asof( self: Self, other: Self, *, - left_on: str | None = None, - right_on: str | None = None, - on: str | None = None, - by_left: str | list[str] | None = None, - by_right: str | list[str] | None = None, - by: str | list[str] | None = None, - strategy: Literal["backward", "forward", "nearest"] = "backward", + left_on: str | None, + right_on: str | None, + by_left: list[str] | None, + by_right: list[str] | None, + strategy: Literal["backward", "forward", "nearest"], + suffix: str, ) -> Self: plx = self.__native_namespace__() return self._from_native_frame( @@ -675,12 +674,10 @@ def join_asof( other._native_frame, left_on=left_on, right_on=right_on, - on=on, left_by=by_left, right_by=by_right, - by=by, direction=strategy, - suffixes=("", "_right"), + suffixes=("", suffix), ), ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 786fb5c71..d02d589c2 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -277,6 +277,7 @@ def join_asof( by_right: str | list[str] | None = None, by: str | list[str] | None = None, strategy: Literal["backward", "forward", "nearest"] = "backward", + suffix: str = "_right", ) -> Self: _supported_strategies = ("backward", "forward", "nearest") @@ -302,16 +303,13 @@ def join_asof( msg = "If `by` is specified, `by_left` and `by_right` should be None." raise ValueError(msg) if on is not None: - return self._from_compliant_dataframe( - self._compliant_frame.join_asof( - self._extract_compliant(other), - on=on, - by_left=by_left, - by_right=by_right, - by=by, - strategy=strategy, - ) - ) + left_on = right_on = on + if by is not None: + by_left = by_right = by + if isinstance(by_left, str): + by_left = [by_left] + if isinstance(by_right, str): + by_right = [by_right] return self._from_compliant_dataframe( self._compliant_frame.join_asof( self._extract_compliant(other), @@ -319,8 +317,8 @@ def join_asof( right_on=right_on, by_left=by_left, by_right=by_right, - by=by, strategy=strategy, + suffix=suffix, ) ) @@ -2748,6 +2746,7 @@ def join_asof( by_right: str | list[str] | None = None, by: str | list[str] | None = None, strategy: Literal["backward", "forward", "nearest"] = "backward", + suffix: str = "_right", ) -> Self: """Perform an asof join. @@ -2764,6 +2763,7 @@ def join_asof( by_right: join on these columns before doing asof join. by: join on these columns before doing asof join. strategy: Join strategy. The default is "backward". + suffix: Suffix to append to columns with a duplicate name. * *backward*: selects the last row in the right DataFrame whose "on" key is less than or equal to the left's key. * *forward*: selects the first row in the right DataFrame whose "on" key is greater than or equal to the left's key. @@ -2924,6 +2924,7 @@ def join_asof( by_right=by_right, by=by, strategy=strategy, + suffix=suffix, ) # --- descriptive --- @@ -5030,6 +5031,7 @@ def join_asof( by_right: str | list[str] | None = None, by: str | list[str] | None = None, strategy: Literal["backward", "forward", "nearest"] = "backward", + suffix: str = "_right", ) -> Self: """Perform an asof join. @@ -5058,6 +5060,8 @@ def join_asof( * *forward*: selects the first row in the right DataFrame whose "on" key is greater than or equal to the left's key. * *nearest*: search selects the last row in the right DataFrame whose value is nearest to the left's key. + suffix: Suffix to append to columns with a duplicate name. + Returns: A new joined LazyFrame. @@ -5224,6 +5228,7 @@ def join_asof( by_right=by_right, by=by, strategy=strategy, + suffix=suffix, ) def clone(self: Self) -> Self: diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index d0e276606..88b5ab678 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -356,11 +356,44 @@ def test_join_keys_exceptions(constructor: Constructor, how: str) -> None: df.join(df, how=how, on="antananarivo", right_on="antananarivo") # type: ignore[arg-type] +@pytest.mark.parametrize( + ("strategy", "expected"), + [ + ( + "backward", + { + "antananarivo": [1, 5, 10], + "val": ["a", "b", "c"], + "val_right": [1, 3, 7], + }, + ), + ( + "forward", + { + "antananarivo": [1, 5, 10], + "val": ["a", "b", "c"], + "val_right": [1, 6, None], + }, + ), + ( + "nearest", + { + "antananarivo": [1, 5, 10], + "val": ["a", "b", "c"], + "val_right": [1, 6, 7], + }, + ), + ], +) def test_joinasof_numeric( constructor: Constructor, request: pytest.FixtureRequest, + strategy: Literal["backward", "forward", "nearest"], + expected: dict[str, list[Any]], ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb", "pyspark")): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "pyspark")): + request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor) and strategy == "nearest": request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ( ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) @@ -372,54 +405,67 @@ def test_joinasof_numeric( df_right = nw.from_native( constructor({"antananarivo": [1, 2, 3, 6, 7], "val": [1, 2, 3, 6, 7]}) ).sort("antananarivo") - result_backward = df.join_asof( + result = df.join_asof( df_right, # type: ignore[arg-type] left_on="antananarivo", right_on="antananarivo", + strategy=strategy, ) - result_forward = df.join_asof( - df_right, # type: ignore[arg-type] - left_on="antananarivo", - right_on="antananarivo", - strategy="forward", - ) - result_nearest = df.join_asof( - df_right, # type: ignore[arg-type] - left_on="antananarivo", - right_on="antananarivo", - strategy="nearest", - ) - result_backward_on = df.join_asof(df_right, on="antananarivo") # type: ignore[arg-type] - result_forward_on = df.join_asof(df_right, on="antananarivo", strategy="forward") # type: ignore[arg-type] - result_nearest_on = df.join_asof(df_right, on="antananarivo", strategy="nearest") # type: ignore[arg-type] - expected_backward = { - "antananarivo": [1, 5, 10], - "val": ["a", "b", "c"], - "val_right": [1, 3, 7], - } - expected_forward = { - "antananarivo": [1, 5, 10], - "val": ["a", "b", "c"], - "val_right": [1, 6, None], - } - expected_nearest = { - "antananarivo": [1, 5, 10], - "val": ["a", "b", "c"], - "val_right": [1, 6, 7], - } - assert_equal_data(result_backward, expected_backward) - assert_equal_data(result_forward, expected_forward) - assert_equal_data(result_nearest, expected_nearest) - assert_equal_data(result_backward_on, expected_backward) - assert_equal_data(result_forward_on, expected_forward) - assert_equal_data(result_nearest_on, expected_nearest) + result_on = df.join_asof(df_right, on="antananarivo", strategy=strategy) # type: ignore[arg-type] + assert_equal_data(result.sort(by="antananarivo"), expected) + assert_equal_data(result_on.sort(by="antananarivo"), expected) +@pytest.mark.parametrize( + ("strategy", "expected"), + [ + ( + "backward", + { + "datetime": [ + datetime(2016, 3, 1), + datetime(2018, 8, 1), + datetime(2019, 1, 1), + ], + "population": [82.19, 82.66, 83.12], + "gdp": [4164, 4566, 4696], + }, + ), + ( + "forward", + { + "datetime": [ + datetime(2016, 3, 1), + datetime(2018, 8, 1), + datetime(2019, 1, 1), + ], + "population": [82.19, 82.66, 83.12], + "gdp": [4411, 4696, 4696], + }, + ), + ( + "nearest", + { + "datetime": [ + datetime(2016, 3, 1), + datetime(2018, 8, 1), + datetime(2019, 1, 1), + ], + "population": [82.19, 82.66, 83.12], + "gdp": [4164, 4696, 4696], + }, + ), + ], +) def test_joinasof_time( constructor: Constructor, request: pytest.FixtureRequest, + strategy: Literal["backward", "forward", "nearest"], + expected: dict[str, list[Any]], ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb", "pyspark")): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "pyspark")): + request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor) and strategy == "nearest": request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ("pandas_pyarrow" in str(constructor)): request.applymarker(pytest.mark.xfail) @@ -449,58 +495,22 @@ def test_joinasof_time( } ) ).sort("datetime") - result_backward = df.join_asof(df_right, left_on="datetime", right_on="datetime") # type: ignore[arg-type] - result_forward = df.join_asof( - df_right, # type: ignore[arg-type] - left_on="datetime", - right_on="datetime", - strategy="forward", - ) - result_nearest = df.join_asof( + result = df.join_asof( df_right, # type: ignore[arg-type] left_on="datetime", right_on="datetime", - strategy="nearest", - ) - result_backward_on = df.join_asof(df_right, on="datetime") # type: ignore[arg-type] - result_forward_on = df.join_asof( - df_right, # type: ignore[arg-type] - on="datetime", - strategy="forward", - ) - result_nearest_on = df.join_asof( - df_right, # type: ignore[arg-type] - on="datetime", - strategy="nearest", + strategy=strategy, ) - expected_backward = { - "datetime": [datetime(2016, 3, 1), datetime(2018, 8, 1), datetime(2019, 1, 1)], - "population": [82.19, 82.66, 83.12], - "gdp": [4164, 4566, 4696], - } - expected_forward = { - "datetime": [datetime(2016, 3, 1), datetime(2018, 8, 1), datetime(2019, 1, 1)], - "population": [82.19, 82.66, 83.12], - "gdp": [4411, 4696, 4696], - } - expected_nearest = { - "datetime": [datetime(2016, 3, 1), datetime(2018, 8, 1), datetime(2019, 1, 1)], - "population": [82.19, 82.66, 83.12], - "gdp": [4164, 4696, 4696], - } - assert_equal_data(result_backward, expected_backward) - assert_equal_data(result_forward, expected_forward) - assert_equal_data(result_nearest, expected_nearest) - assert_equal_data(result_backward_on, expected_backward) - assert_equal_data(result_forward_on, expected_forward) - assert_equal_data(result_nearest_on, expected_nearest) + result_on = df.join_asof(df_right, on="datetime", strategy=strategy) # type: ignore[arg-type] + assert_equal_data(result.sort(by="datetime"), expected) + assert_equal_data(result_on.sort(by="datetime"), expected) def test_joinasof_by( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb", "pyspark")): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "pyspark")): request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ( ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) @@ -528,8 +538,38 @@ def test_joinasof_by( "c": [9, 2, 1, 1], "d": [1, 3, None, 4], } - assert_equal_data(result, expected) - assert_equal_data(result_by, expected) + assert_equal_data(result.sort(by="antananarivo"), expected) + assert_equal_data(result_by.sort(by="antananarivo"), expected) + + +def test_joinasof_suffix( + constructor: Constructor, + request: pytest.FixtureRequest, +) -> None: + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "pyspark")): + request.applymarker(pytest.mark.xfail) + if PANDAS_VERSION < (2, 1) and ( + ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) + ): + request.applymarker(pytest.mark.xfail) + df = nw.from_native( + constructor({"antananarivo": [1, 5, 10], "val": ["a", "b", "c"]}) + ).sort("antananarivo") + df_right = nw.from_native( + constructor({"antananarivo": [1, 2, 3, 6, 7], "val": [1, 2, 3, 6, 7]}) + ).sort("antananarivo") + result = df.join_asof( + df_right, # type: ignore[arg-type] + left_on="antananarivo", + right_on="antananarivo", + suffix="_y", + ) + expected = { + "antananarivo": [1, 5, 10], + "val": ["a", "b", "c"], + "val_y": [1, 3, 7], + } + assert_equal_data(result.sort(by="antananarivo"), expected) @pytest.mark.parametrize("strategy", ["back", "furthest"])