From 939f6e1cd4fadc8d43cd9c5239d36c05cd135f85 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 8 Jan 2025 16:56:38 +0000 Subject: [PATCH 1/2] feat: semi join in duckdb --- narwhals/_duckdb/dataframe.py | 17 ++++++++++------- tests/frame/join_test.py | 3 --- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 76ff68ae0..50ea79bdf 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -225,7 +225,7 @@ def join( if isinstance(right_on, str): right_on = [right_on] - if how not in ("inner", "left"): + if how not in ("inner", "left", "semi"): msg = "Only inner and left join is implemented for DuckDB" raise NotImplementedError(msg) @@ -242,12 +242,15 @@ def join( other._native_frame.set_alias("rhs"), condition=condition, how=how ) - select = [f"lhs.{x}" for x in self._native_frame.columns] - for col in other._native_frame.columns: - if col in self._native_frame.columns and col not in right_on: - select.append(f"rhs.{col} as {col}{suffix}") - elif col not in right_on: - select.append(col) + if how in ("inner", "left"): + select = [f"lhs.{x}" for x in self._native_frame.columns] + for col in other._native_frame.columns: + if col in self._native_frame.columns and col not in right_on: + select.append(f"rhs.{col} as {col}{suffix}") + elif col not in right_on: + select.append(col) + elif how == "semi": + select = [f"lhs.{x}" for x in self._native_frame.columns] res = rel.select(", ".join(select)).set_alias(original_alias) return self._from_native_frame(res) diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index 7332cb254..242696394 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -206,10 +206,7 @@ def test_semi_join( join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) other = df.filter(filter_expr) From 9d6d370793a58a6f4a0c6c4338ecef848a06ab55 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 8 Jan 2025 16:59:06 +0000 Subject: [PATCH 2/2] fixup --- tpch/execute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpch/execute.py b/tpch/execute.py index 1f3823ced..427c8cf5b 100644 --- a/tpch/execute.py +++ b/tpch/execute.py @@ -40,7 +40,7 @@ "dask": lambda x: x.compute(), } -DUCKDB_XFAILS = ["q11", "q14", "q15", "q16", "q18", "q22"] +DUCKDB_XFAILS = ["q11", "q14", "q15", "q16", "q22"] QUERY_DATA_PATH_MAP = { "q1": (LINEITEM_PATH,),