From aa48faae5ebaf36d0efd39cc1b0267a4502ee51a Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Mon, 6 Jan 2025 10:41:53 +0000 Subject: [PATCH] feat: Implement partial "lazy" support for DuckDB (even with this PR, DuckDB support is work-in-progress!) (#1725) --- README.md | 3 +- docs/backcompat.md | 4 + docs/basics/dataframe_conversion.md | 16 +- docs/extending.md | 7 +- narwhals/_arrow/dataframe.py | 4 + narwhals/_dask/dataframe.py | 4 + narwhals/_duckdb/dataframe.py | 329 +++++--- narwhals/_duckdb/expr.py | 767 ++++++++++++++++++ narwhals/_duckdb/group_by.py | 57 ++ narwhals/_duckdb/namespace.py | 205 +++++ narwhals/_duckdb/series.py | 2 +- narwhals/_duckdb/typing.py | 16 + narwhals/_duckdb/utils.py | 213 +++++ narwhals/_pandas_like/dataframe.py | 4 + narwhals/functions.py | 3 + narwhals/translate.py | 19 +- pyproject.toml | 1 + tests/conftest.py | 23 +- tests/expr_and_series/all_horizontal_test.py | 2 + tests/expr_and_series/arithmetic_test.py | 6 +- tests/expr_and_series/cast_test.py | 18 +- tests/expr_and_series/concat_str_test.py | 8 +- .../expr_and_series/convert_time_zone_test.py | 2 + tests/expr_and_series/cum_count_test.py | 2 + tests/expr_and_series/cum_max_test.py | 2 + tests/expr_and_series/cum_min_test.py | 2 + tests/expr_and_series/cum_prod_test.py | 2 + tests/expr_and_series/cum_sum_test.py | 2 + tests/expr_and_series/diff_test.py | 2 + .../dt/datetime_attributes_test.py | 3 + .../dt/datetime_duration_test.py | 2 + tests/expr_and_series/dt/timestamp_test.py | 8 + tests/expr_and_series/dt/to_string_test.py | 20 +- tests/expr_and_series/fill_null_test.py | 12 +- tests/expr_and_series/is_duplicated_test.py | 14 +- tests/expr_and_series/is_finite_test.py | 4 +- .../expr_and_series/is_first_distinct_test.py | 8 +- .../expr_and_series/is_last_distinct_test.py | 8 +- tests/expr_and_series/is_nan_test.py | 8 +- tests/expr_and_series/is_unique_test.py | 12 +- tests/expr_and_series/lit_test.py | 7 + tests/expr_and_series/mean_horizontal_test.py | 10 +- tests/expr_and_series/median_test.py | 11 +- tests/expr_and_series/n_unique_test.py | 6 +- .../expr_and_series/name/to_uppercase_test.py | 16 +- tests/expr_and_series/nth_test.py | 2 + tests/expr_and_series/null_count_test.py | 8 +- tests/expr_and_series/over_test.py | 24 +- tests/expr_and_series/quantile_test.py | 5 +- tests/expr_and_series/reduction_test.py | 22 +- tests/expr_and_series/replace_strict_test.py | 6 + .../expr_and_series/replace_time_zone_test.py | 3 + tests/expr_and_series/shift_test.py | 5 +- tests/expr_and_series/std_test.py | 21 +- tests/expr_and_series/str/len_chars_test.py | 6 +- tests/expr_and_series/str/replace_test.py | 8 +- tests/expr_and_series/str/to_datetime_test.py | 12 +- .../str/to_uppercase_to_lowercase_test.py | 2 + tests/expr_and_series/sum_horizontal_test.py | 14 +- tests/expr_and_series/unary_test.py | 16 +- tests/expr_and_series/var_test.py | 21 +- tests/expr_and_series/when_test.py | 44 +- tests/frame/add_test.py | 6 +- tests/frame/clone_test.py | 2 + tests/frame/concat_test.py | 12 +- tests/frame/drop_nulls_test.py | 11 +- tests/frame/explode_test.py | 8 +- tests/frame/filter_test.py | 6 +- tests/frame/gather_every_test.py | 6 +- tests/frame/join_test.py | 22 +- tests/frame/select_test.py | 12 +- tests/frame/unique_test.py | 17 +- tests/frame/unpivot_test.py | 4 +- tests/frame/with_columns_test.py | 2 + tests/frame/with_row_index_test.py | 6 +- tests/group_by_test.py | 24 +- tests/selectors_test.py | 31 +- tests/stable_api_test.py | 6 +- tests/utils.py | 8 +- tpch/execute.py | 3 +- utils/import_check.py | 2 + 81 files changed, 2064 insertions(+), 217 deletions(-) create mode 100644 narwhals/_duckdb/expr.py create mode 100644 narwhals/_duckdb/group_by.py create mode 100644 narwhals/_duckdb/namespace.py create mode 100644 narwhals/_duckdb/typing.py create mode 100644 narwhals/_duckdb/utils.py diff --git a/README.md b/README.md index bb024c6c2..eee90ebd9 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,7 @@ Extremely lightweight and extensible compatibility layer between dataframe libraries! - **Full API support**: cuDF, Modin, pandas, Polars, PyArrow -- **Lazy-only support**: Dask -- **Interchange-level support**: DuckDB, Ibis, Vaex, anything which implements the DataFrame Interchange Protocol +- **Lazy-only support**: Dask. Work in progress: DuckDB, Ibis, PySpark. Seamlessly support all, without depending on any! diff --git a/docs/backcompat.md b/docs/backcompat.md index 55b927fd8..b2d312e0a 100644 --- a/docs/backcompat.md +++ b/docs/backcompat.md @@ -111,6 +111,10 @@ before making any change. ### After `stable.v1` + +- Since Narwhals 1.21, passing a `DuckDBPyRelation` to `from_native` returns a `LazyFrame`. In + `narwhals.stable.v1`, it returns a `DataFrame` with `level='interchange'`. + - Since Narwhals 1.15, `Series` is generic in the native Series, meaning that you can write: ```python diff --git a/docs/basics/dataframe_conversion.md b/docs/basics/dataframe_conversion.md index 690f5d093..a4753a033 100644 --- a/docs/basics/dataframe_conversion.md +++ b/docs/basics/dataframe_conversion.md @@ -14,6 +14,7 @@ To illustrate, we create dataframes in various formats: ```python exec="1" source="above" session="conversion" import narwhals as nw from narwhals.typing import IntoDataFrame +from typing import Any import duckdb import polars as pl @@ -45,11 +46,15 @@ print(df_to_pandas(df_polars)) ### Via PyCapsule Interface -Similarly, if your library uses Polars internally, you can convert any user-supplied dataframe to Polars format using Narwhals. +Similarly, if your library uses Polars internally, you can convert any user-supplied dataframe +which implements `__arrow_c_stream__`: ```python exec="1" source="above" session="conversion" result="python" -def df_to_polars(df: IntoDataFrame) -> pl.DataFrame: - return nw.from_arrow(nw.from_native(df), native_namespace=pl).to_native() +def df_to_polars(df_native: Any) -> pl.DataFrame: + if hasattr(df_native, "__arrow_c_stream__"): + return nw.from_arrow(df_native, native_namespace=pl).to_native() + msg = f"Expected object which implements '__arrow_c_stream__' got: {type(df)}" + raise TypeError(msg) print(df_to_polars(df_duckdb)) # You can only execute this line of code once. @@ -66,8 +71,9 @@ If you need to ingest the same dataframe multiple times, then you may want to go This may be less efficient than the PyCapsule approach above (and always requires PyArrow!), but is more forgiving: ```python exec="1" source="above" session="conversion" result="python" -def df_to_polars(df: IntoDataFrame) -> pl.DataFrame: - return pl.DataFrame(nw.from_native(df).to_arrow()) +def df_to_polars(df_native: IntoDataFrame) -> pl.DataFrame: + df = nw.from_native(df_native).lazy().collect() + return pl.DataFrame(nw.from_native(df, eager_only=True).to_arrow()) df_duckdb = duckdb.sql("SELECT * FROM df_polars") diff --git a/docs/extending.md b/docs/extending.md index 2a8953987..588e234f4 100644 --- a/docs/extending.md +++ b/docs/extending.md @@ -15,17 +15,16 @@ Currently, Narwhals has **full API** support for the following libraries: It also has **lazy-only** support for [Dask](https://github.com/dask/dask), and **interchange** support for [DuckDB](https://github.com/duckdb/duckdb) and [Ibis](https://github.com/ibis-project/ibis). +We are working towards full "lazy-only" support for DuckDB, Ibis, and PySpark. + ### Levels of support Narwhals comes with three levels of support: - **Full API support**: cuDF, Modin, pandas, Polars, PyArrow -- **Lazy-only support**: Dask +- **Lazy-only support**: Dask. Work in progress: DuckDB, Ibis, PySpark. - **Interchange-level support**: DuckDB, Ibis, Vaex, anything which implements the DataFrame Interchange Protocol -The lazy-only layer is a major item on our 2025 roadmap, and hope to be able to bring libraries currently in -the "interchange" level into that one. - Libraries for which we have full support can benefit from the whole [Narwhals API](./api-reference/index.md). diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 9e5ce0621..f4ad2912e 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -16,6 +16,7 @@ from narwhals._arrow.utils import validate_dataframe_comparand from narwhals._expression_parsing import evaluate_into_exprs from narwhals.dependencies import is_numpy_array +from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import Implementation from narwhals.utils import flatten from narwhals.utils import generate_temporary_column_name @@ -669,6 +670,9 @@ def unique( import pyarrow.compute as pc df = self._native_frame + if subset is not None and any(x not in self.columns for x in subset): + msg = f"Column(s) {subset} not found in {self.columns}" + raise ColumnNotFoundError(msg) subset = subset or self.columns if keep in {"any", "first", "last"}: diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 5e652a937..16053d69a 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -11,6 +11,7 @@ from narwhals._dask.utils import parse_exprs_and_named_exprs from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._pandas_like.utils import select_columns_by_name +from narwhals.exceptions import ColumnNotFoundError from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation from narwhals.utils import flatten @@ -197,6 +198,9 @@ def unique( *, keep: Literal["any", "none"] = "any", ) -> Self: + if subset is not None and any(x not in self.columns for x in subset): + msg = f"Column(s) {subset} not found in {self.columns}" + raise ColumnNotFoundError(msg) native_frame = self._native_frame if keep == "none": subset = subset or self.columns diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 73dd055ca..76ff68ae0 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -1,105 +1,74 @@ from __future__ import annotations -import re -from functools import lru_cache +from itertools import chain from typing import TYPE_CHECKING from typing import Any +from typing import Iterable +from typing import Literal +from typing import Sequence +from narwhals._duckdb.utils import native_to_narwhals_dtype +from narwhals._duckdb.utils import parse_exprs_and_named_exprs from narwhals.dependencies import get_duckdb +from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import Implementation -from narwhals.utils import import_dtypes_module +from narwhals.utils import Version +from narwhals.utils import flatten +from narwhals.utils import generate_temporary_column_name +from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import validate_backend_version if TYPE_CHECKING: from types import ModuleType + import duckdb import pandas as pd import pyarrow as pa from typing_extensions import Self + from narwhals._duckdb.expr import DuckDBExpr + from narwhals._duckdb.group_by import DuckDBGroupBy + from narwhals._duckdb.namespace import DuckDBNamespace from narwhals._duckdb.series import DuckDBInterchangeSeries from narwhals.dtypes import DType - from narwhals.utils import Version - - -@lru_cache(maxsize=16) -def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType: - dtypes = import_dtypes_module(version) - if duckdb_dtype == "HUGEINT": - return dtypes.Int128() - if duckdb_dtype == "BIGINT": - return dtypes.Int64() - if duckdb_dtype == "INTEGER": - return dtypes.Int32() - if duckdb_dtype == "SMALLINT": - return dtypes.Int16() - if duckdb_dtype == "TINYINT": - return dtypes.Int8() - if duckdb_dtype == "UHUGEINT": - return dtypes.UInt128() - if duckdb_dtype == "UBIGINT": - return dtypes.UInt64() - if duckdb_dtype == "UINTEGER": - return dtypes.UInt32() - if duckdb_dtype == "USMALLINT": - return dtypes.UInt16() - if duckdb_dtype == "UTINYINT": - return dtypes.UInt8() - if duckdb_dtype == "DOUBLE": - return dtypes.Float64() - if duckdb_dtype == "FLOAT": - return dtypes.Float32() - if duckdb_dtype == "VARCHAR": - return dtypes.String() - if duckdb_dtype == "DATE": - return dtypes.Date() - if duckdb_dtype == "TIMESTAMP": - return dtypes.Datetime() - if duckdb_dtype == "BOOLEAN": - return dtypes.Boolean() - if duckdb_dtype == "INTERVAL": - return dtypes.Duration() - if duckdb_dtype.startswith("STRUCT"): - matchstruc_ = re.findall(r"(\w+)\s+(\w+)", duckdb_dtype) - return dtypes.Struct( - [ - dtypes.Field( - matchstruc_[i][0], - native_to_narwhals_dtype(matchstruc_[i][1], version), - ) - for i in range(len(matchstruc_)) - ] - ) - if match_ := re.match(r"(.*)\[\]$", duckdb_dtype): - return dtypes.List(native_to_narwhals_dtype(match_.group(1), version)) - if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype): - return dtypes.Array( - native_to_narwhals_dtype(match_.group(1), version), - int(match_.group(2)), - ) - if duckdb_dtype.startswith("DECIMAL("): - return dtypes.Decimal() - return dtypes.Unknown() # pragma: no cover -class DuckDBInterchangeFrame: +class DuckDBLazyFrame: _implementation = Implementation.DUCKDB def __init__( - self, df: Any, *, backend_version: tuple[int, ...], version: Version + self, + df: duckdb.DuckDBPyRelation, + *, + backend_version: tuple[int, ...], + version: Version, ) -> None: - self._native_frame = df + self._native_frame: duckdb.DuckDBPyRelation = df self._version = version self._backend_version = backend_version validate_backend_version(self._implementation, self._backend_version) - def __narwhals_dataframe__(self) -> Any: + def __narwhals_dataframe__(self) -> Any: # pragma: no cover + # Keep around for backcompat. + if self._version is not Version.V1: + msg = "__narwhals_dataframe__ is not implemented for DuckDBLazyFrame" + raise AttributeError(msg) + return self + + def __narwhals_lazyframe__(self) -> Any: return self def __native_namespace__(self: Self) -> ModuleType: return get_duckdb() # type: ignore[no-any-return] + def __narwhals_namespace__(self) -> DuckDBNamespace: + from narwhals._duckdb.namespace import DuckDBNamespace + + return DuckDBNamespace( + backend_version=self._backend_version, version=self._version + ) + def __getitem__(self, item: str) -> DuckDBInterchangeSeries: from narwhals._duckdb.series import DuckDBInterchangeSeries @@ -107,42 +76,101 @@ def __getitem__(self, item: str) -> DuckDBInterchangeSeries: self._native_frame.select(item), version=self._version ) + def collect(self) -> Any: + try: + import pyarrow as pa # ignore-banned-import + except ModuleNotFoundError as exc: # pragma: no cover + msg = "PyArrow>=11.0.0 is required to collect `LazyFrame` backed by DuckDcollect `LazyFrame` backed by DuckDB" + raise ModuleNotFoundError(msg) from exc + + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + native_dataframe=self._native_frame.arrow(), + backend_version=parse_version(pa.__version__), + version=self._version, + ) + + def head(self, n: int) -> Self: + return self._from_native_frame(self._native_frame.limit(n)) + def select( self: Self, *exprs: Any, **named_exprs: Any, ) -> Self: - if named_exprs or not all(isinstance(x, str) for x in exprs): # pragma: no cover - msg = ( - "`select`-ing not by name is not supported for DuckDB backend.\n\n" - "If you would like to see this kind of object better supported in " - "Narwhals, please open a feature request " - "at https://github.com/narwhals-dev/narwhals/issues." + new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + if not new_columns_map: + # TODO(marco): return empty relation with 0 columns? + return self._from_native_frame(self._native_frame.limit(0)) + + if all(getattr(x, "_returns_scalar", False) for x in exprs) and all( + getattr(x, "_returns_scalar", False) for x in named_exprs.values() + ): + return self._from_native_frame( + self._native_frame.aggregate( + [val.alias(col) for col, val in new_columns_map.items()] + ) ) - raise NotImplementedError(msg) - return self._from_native_frame(self._native_frame.select(*exprs)) + return self._from_native_frame( + self._native_frame.select( + *(val.alias(col) for col, val in new_columns_map.items()) + ) + ) - def __getattr__(self, attr: str) -> Any: - if attr == "schema": - return { - column_name: native_to_narwhals_dtype(str(duckdb_dtype), self._version) - for column_name, duckdb_dtype in zip( - self._native_frame.columns, self._native_frame.types - ) - } - elif attr == "columns": - return self._native_frame.columns - - msg = ( # pragma: no cover - f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" - "If you would like to see this kind of object better supported in " - "Narwhals, please open a feature request " - "at https://github.com/narwhals-dev/narwhals/issues." + def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001 + columns_to_drop = parse_columns_to_drop( + compliant_frame=self, columns=columns, strict=strict + ) + selection = (col for col in self.columns if col not in columns_to_drop) + return self._from_native_frame(self._native_frame.select(*selection)) + + def lazy(self) -> Self: + return self + + def with_columns( + self: Self, + *exprs: Any, + **named_exprs: Any, + ) -> Self: + from duckdb import ColumnExpression + + new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + result = [] + for col in self._native_frame.columns: + if col in new_columns_map: + result.append(new_columns_map.pop(col).alias(col)) + else: + result.append(ColumnExpression(col)) + for col, value in new_columns_map.items(): + result.append(value.alias(col)) + return self._from_native_frame(self._native_frame.select(*result)) + + def filter(self, *predicates: DuckDBExpr, **constraints: Any) -> Self: + plx = self.__narwhals_namespace__() + expr = plx.all_horizontal( + *chain(predicates, (plx.col(name) == v for name, v in constraints.items())) ) - raise NotImplementedError(msg) # pragma: no cover + # `[0]` is safe as all_horizontal's expression only returns a single column + mask = expr._call(self)[0] + return self._from_native_frame(self._native_frame.filter(mask)) + + @property + def schema(self) -> dict[str, DType]: + return { + column_name: native_to_narwhals_dtype(str(duckdb_dtype), self._version) + for column_name, duckdb_dtype in zip( + self._native_frame.columns, self._native_frame.types + ) + } + + @property + def columns(self) -> list[str]: + return self._native_frame.columns # type: ignore[no-any-return] def to_pandas(self: Self) -> pd.DataFrame: + # only if version is v1, keep around for backcompat import pandas as pd # ignore-banned-import() if parse_version(pd.__version__) >= parse_version("1.0.0"): @@ -152,6 +180,7 @@ def to_pandas(self: Self) -> pd.DataFrame: raise NotImplementedError(msg) def to_arrow(self: Self) -> pa.Table: + # only if version is v1, keep around for backcompat return self._native_frame.arrow() def _change_version(self: Self, version: Version) -> Self: @@ -161,9 +190,68 @@ def _change_version(self: Self, version: Version) -> Self: def _from_native_frame(self: Self, df: Any) -> Self: return self.__class__( - df, version=self._version, backend_version=self._backend_version + df, backend_version=self._backend_version, version=self._version + ) + + def group_by(self: Self, *keys: str, drop_null_keys: bool) -> DuckDBGroupBy: + from narwhals._duckdb.group_by import DuckDBGroupBy + + if drop_null_keys: + msg = "todo" + raise NotImplementedError(msg) + + return DuckDBGroupBy( + compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys + ) + + def rename(self: Self, mapping: dict[str, str]) -> Self: + df = self._native_frame + selection = [ + f"{col} as {mapping[col]}" if col in mapping else col for col in df.columns + ] + return self._from_native_frame(df.select(", ".join(selection))) + + def join( + self: Self, + other: Self, + *, + how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", + left_on: str | list[str] | None, + right_on: str | list[str] | None, + suffix: str, + ) -> Self: + if isinstance(left_on, str): + left_on = [left_on] + if isinstance(right_on, str): + right_on = [right_on] + + if how not in ("inner", "left"): + msg = "Only inner and left join is implemented for DuckDB" + raise NotImplementedError(msg) + + # help mypy + assert left_on is not None # noqa: S101 + assert right_on is not None # noqa: S101 + + conditions = [ + f"lhs.{left} = rhs.{right}" for left, right in zip(left_on, right_on) + ] + original_alias = self._native_frame.alias + condition = " and ".join(conditions) + rel = self._native_frame.set_alias("lhs").join( + other._native_frame.set_alias("rhs"), condition=condition, how=how ) + select = [f"lhs.{x}" for x in self._native_frame.columns] + for col in other._native_frame.columns: + if col in self._native_frame.columns and col not in right_on: + select.append(f"rhs.{col} as {col}{suffix}") + elif col not in right_on: + select.append(col) + + res = rel.select(", ".join(select)).set_alias(original_alias) + return self._from_native_frame(res) + def collect_schema(self) -> dict[str, DType]: return { column_name: native_to_narwhals_dtype(str(duckdb_dtype), self._version) @@ -171,3 +259,56 @@ def collect_schema(self) -> dict[str, DType]: self._native_frame.columns, self._native_frame.types ) } + + def unique(self, subset: Sequence[str] | None, keep: str) -> Self: + if subset is not None: + import duckdb + + rel = self._native_frame + # Sanitise input + if any(x not in rel.columns for x in subset): + msg = f"Columns {set(subset).difference(rel.columns)} not found in {rel.columns}." + raise ColumnNotFoundError(msg) + idx_name = f'"{generate_temporary_column_name(8, rel.columns)}"' + count_name = ( + f'"{generate_temporary_column_name(8, [*rel.columns, idx_name])}"' + ) + if keep == "none": + keep_condition = f"where {count_name}=1" + else: + keep_condition = f"where {idx_name}=1" + query = f""" + with cte as ( + select *, + row_number() over (partition by {",".join(subset)}) as {idx_name}, + count(*) over (partition by {",".join(subset)}) as {count_name} + from rel + ) + select * exclude ({idx_name}, {count_name}) from cte {keep_condition} + """ # noqa: S608 + return self._from_native_frame(duckdb.sql(query)) + return self._from_native_frame(self._native_frame.unique(", ".join(self.columns))) + + def sort( + self: Self, + by: str | Iterable[str], + *more_by: str, + descending: bool | Sequence[bool] = False, + nulls_last: bool = False, + ) -> Self: + flat_by = flatten([*flatten([by]), *more_by]) + if isinstance(descending, bool): + descending = [descending] * len(flat_by) + descending_str = ["desc" if x else "" for x in descending] + + result = self._native_frame.order( + ",".join( + ( + f"{col} {desc} nulls last" + if nulls_last + else f"{col} {desc} nulls first" + for col, desc in zip(flat_by, descending_str) + ) + ) + ) + return self._from_native_frame(result) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py new file mode 100644 index 000000000..3956e919d --- /dev/null +++ b/narwhals/_duckdb/expr.py @@ -0,0 +1,767 @@ +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Literal +from typing import NoReturn +from typing import Sequence + +from narwhals._duckdb.utils import binary_operation_returns_scalar +from narwhals._duckdb.utils import get_column_name +from narwhals._duckdb.utils import maybe_evaluate +from narwhals._duckdb.utils import narwhals_to_native_dtype +from narwhals._expression_parsing import infer_new_root_output_names +from narwhals.typing import CompliantExpr +from narwhals.utils import Implementation + +if TYPE_CHECKING: + import duckdb + from typing_extensions import Self + + from narwhals._duckdb.dataframe import DuckDBLazyFrame + from narwhals._duckdb.namespace import DuckDBNamespace + from narwhals.dtypes import DType + from narwhals.utils import Version + + +class DuckDBExpr(CompliantExpr["duckdb.Expression"]): + _implementation = Implementation.DUCKDB + + def __init__( + self, + call: Callable[[DuckDBLazyFrame], list[duckdb.Expression]], + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + # Whether the expression is a length-1 Column resulting from + # a reduction, such as `nw.col('a').sum()` + returns_scalar: bool, + backend_version: tuple[int, ...], + version: Version, + kwargs: dict[str, Any], + ) -> None: + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = output_names + self._returns_scalar = returns_scalar + self._backend_version = backend_version + self._version = version + self._kwargs = kwargs + + def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: + return self._call(df) + + def __narwhals_expr__(self) -> None: ... + + def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover + # Unused, just for compatibility with PandasLikeExpr + from narwhals._duckdb.namespace import DuckDBNamespace + + return DuckDBNamespace( + backend_version=self._backend_version, version=self._version + ) + + @classmethod + def from_column_names( + cls: type[Self], + *column_names: str, + backend_version: tuple[int, ...], + version: Version, + ) -> Self: + def func(_: DuckDBLazyFrame) -> list[duckdb.Expression]: + from duckdb import ColumnExpression + + return [ColumnExpression(col_name) for col_name in column_names] + + return cls( + func, + depth=0, + function_name="col", + root_names=list(column_names), + output_names=list(column_names), + returns_scalar=False, + backend_version=backend_version, + version=version, + kwargs={}, + ) + + def _from_call( + self, + call: Callable[..., duckdb.Expression], + expr_name: str, + *, + returns_scalar: bool, + **kwargs: Any, + ) -> Self: + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + results = [] + inputs = self._call(df) + _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()} + for _input in inputs: + input_col_name = get_column_name( + df, _input, returns_scalar=self._returns_scalar + ) + if self._returns_scalar: + # TODO(marco): once WindowExpression is supported, then + # we may need to call it with `over(1)` here, + # depending on the context? + pass + + column_result = call(_input, **_kwargs) + column_result = column_result.alias(input_col_name) + if returns_scalar: + # TODO(marco): once WindowExpression is supported, then + # we may need to call it with `over(1)` here, + # depending on the context? + pass + results.append(column_result) + return results + + root_names, output_names = infer_new_root_output_names(self, **kwargs) + + return self.__class__( + func, + depth=self._depth + 1, + function_name=f"{self._function_name}->{expr_name}", + root_names=root_names, + output_names=output_names, + returns_scalar=returns_scalar, + backend_version=self._backend_version, + version=self._version, + kwargs=kwargs, + ) + + def __and__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input & other, + "__and__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __or__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input | other, + "__or__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __add__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input + other, + "__add__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __truediv__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input / other, + "__truediv__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __floordiv__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__floordiv__(other), + "__floordiv__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __mod__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input.__mod__(other), + "__mod__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __sub__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input - other, + "__sub__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __mul__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input * other, + "__mul__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __pow__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input**other, + "__pow__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __lt__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input < other, + "__lt__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __gt__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input > other, + "__gt__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __le__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input <= other, + "__le__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __ge__(self, other: DuckDBExpr) -> Self: + return self._from_call( + lambda _input, other: _input >= other, + "__ge__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __eq__(self, other: DuckDBExpr) -> Self: # type: ignore[override] + return self._from_call( + lambda _input, other: _input == other, + "__eq__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __ne__(self, other: DuckDBExpr) -> Self: # type: ignore[override] + return self._from_call( + lambda _input, other: _input != other, + "__ne__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), + ) + + def __invert__(self) -> Self: + return self._from_call( + lambda _input: ~_input, + "__invert__", + returns_scalar=self._returns_scalar, + ) + + def alias(self, name: str) -> Self: + def _alias(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + return [col.alias(name) for col in self._call(df)] + + # Define this one manually, so that we can + # override `output_names` and not increase depth + return self.__class__( + _alias, + depth=self._depth, + function_name=self._function_name, + root_names=self._root_names, + output_names=[name], + returns_scalar=self._returns_scalar, + backend_version=self._backend_version, + version=self._version, + kwargs={**self._kwargs, "name": name}, + ) + + def abs(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("abs", _input), + "abs", + returns_scalar=self._returns_scalar, + ) + + def mean(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("mean", _input), + "mean", + returns_scalar=True, + ) + + def skew(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("skewness", _input), + "skew", + returns_scalar=True, + ) + + def median(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("median", _input), + "median", + returns_scalar=True, + ) + + def all(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("bool_and", _input), + "all", + returns_scalar=True, + ) + + def any(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("bool_or", _input), + "any", + returns_scalar=True, + ) + + def quantile( + self, + quantile: float, + interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + ) -> Self: + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + def func(_input: duckdb.Expression) -> duckdb.Expression: + if interpolation == "linear": + return FunctionExpression( + "quantile_cont", _input, ConstantExpression(quantile) + ) + msg = "Only linear interpolation methods are supported for DuckDB quantile." + raise NotImplementedError(msg) + + return self._from_call( + func, + "quantile", + returns_scalar=True, + ) + + def clip(self, lower_bound: Any, upper_bound: Any) -> Self: + from duckdb import FunctionExpression + + def func( + _input: duckdb.Expression, lower_bound: Any, upper_bound: Any + ) -> duckdb.Expression: + return FunctionExpression( + "greatest", + FunctionExpression("least", _input, upper_bound), + lower_bound, + ) + + return self._from_call( + func, + "clip", + lower_bound=lower_bound, + upper_bound=upper_bound, + returns_scalar=self._returns_scalar, + ) + + def is_between( + self, + lower_bound: Any, + upper_bound: Any, + closed: Literal["left", "right", "none", "both"], + ) -> Self: + def func( + _input: duckdb.Expression, lower_bound: Any, upper_bound: Any + ) -> duckdb.Expression: + if closed == "left": + return (_input >= lower_bound) & (_input < upper_bound) + elif closed == "right": + return (_input > lower_bound) & (_input <= upper_bound) + elif closed == "none": + return (_input > lower_bound) & (_input < upper_bound) + return (_input >= lower_bound) & (_input <= upper_bound) + + return self._from_call( + func, + "is_between", + lower_bound=lower_bound, + upper_bound=upper_bound, + returns_scalar=self._returns_scalar, + ) + + def sum(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("sum", _input), + "sum", + returns_scalar=True, + ) + + def count(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("count", _input), + "count", + returns_scalar=True, + ) + + def len(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("count"), + "len", + returns_scalar=True, + ) + + def std(self, ddof: int) -> Self: + from duckdb import FunctionExpression + + if ddof == 1: + func = "stddev_samp" + elif ddof == 0: + func = "stddev_pop" + else: + msg = f"std with ddof {ddof} is not supported in DuckDB" + raise NotImplementedError(msg) + return self._from_call( + lambda _input: FunctionExpression(func, _input), + "std", + returns_scalar=True, + ) + + def var(self, ddof: int) -> Self: + from duckdb import FunctionExpression + + if ddof == 1: + func = "var_samp" + elif ddof == 0: + func = "var_pop" + else: + msg = f"var with ddof {ddof} is not supported in DuckDB" + raise NotImplementedError(msg) + return self._from_call( + lambda _input: FunctionExpression(func, _input), + "var", + returns_scalar=True, + ) + + def max(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("max", _input), + "max", + returns_scalar=True, + ) + + def min(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("min", _input), + "min", + returns_scalar=True, + ) + + def is_null(self) -> Self: + return self._from_call( + lambda _input: _input.isnull(), + "is_null", + returns_scalar=self._returns_scalar, + ) + + def is_in(self, other: Sequence[Any]) -> Self: + from duckdb import ConstantExpression + + return self._from_call( + lambda _input: functools.reduce( + lambda x, y: x | _input.isin(ConstantExpression(y)), + other[1:], + _input.isin(ConstantExpression(other[0])), + ), + "is_in", + returns_scalar=self._returns_scalar, + ) + + def round(self, decimals: int) -> Self: + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression( + "round", _input, ConstantExpression(decimals) + ), + "round", + returns_scalar=self._returns_scalar, + ) + + def fill_null(self, value: Any, strategy: Any, limit: int | None) -> Self: + from duckdb import CoalesceOperator + from duckdb import ConstantExpression + + if strategy is not None: + msg = "todo" + raise NotImplementedError(msg) + + return self._from_call( + lambda _input: CoalesceOperator(_input, ConstantExpression(value)), + "fill_null", + returns_scalar=self._returns_scalar, + ) + + def cast( + self: Self, + dtype: DType | type[DType], + ) -> Self: + def func(_input: Any, dtype: DType | type[DType]) -> Any: + native_dtype = narwhals_to_native_dtype(dtype, self._version) + return _input.cast(native_dtype) + + return self._from_call( + func, + "cast", + dtype=dtype, + returns_scalar=self._returns_scalar, + ) + + @property + def str(self: Self) -> DuckDBExprStringNamespace: + return DuckDBExprStringNamespace(self) + + @property + def dt(self: Self) -> DuckDBExprDateTimeNamespace: + return DuckDBExprDateTimeNamespace(self) + + +class DuckDBExprStringNamespace: + def __init__(self, expr: DuckDBExpr) -> None: + self._compliant_expr = expr + + def starts_with(self, prefix: str) -> DuckDBExpr: + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression( + "starts_with", _input, ConstantExpression(prefix) + ), + "starts_with", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def ends_with(self, suffix: str) -> DuckDBExpr: + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression( + "ends_with", _input, ConstantExpression(suffix) + ), + "ends_with", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def contains(self, pattern: str, *, literal: bool) -> DuckDBExpr: + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + def func(_input: duckdb.Expression) -> duckdb.Expression: + if literal: + return FunctionExpression("contains", _input, ConstantExpression(pattern)) + return FunctionExpression( + "regexp_matches", _input, ConstantExpression(pattern) + ) + + return self._compliant_expr._from_call( + func, + "contains", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def slice(self, offset: int, length: int) -> DuckDBExpr: + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + def func(_input: duckdb.Expression) -> duckdb.Expression: + return FunctionExpression( + "array_slice", + _input, + ConstantExpression(offset + 1) + if offset >= 0 + else FunctionExpression("length", _input) + offset + 1, + FunctionExpression("length", _input) + if length is None + else ConstantExpression(length) + offset, + ) + + return self._compliant_expr._from_call( + func, + "slice", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def to_lowercase(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("lower", _input), + "to_lowercase", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def to_uppercase(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("upper", _input), + "to_uppercase", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def strip_chars(self, characters: str | None) -> DuckDBExpr: + import string + + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression( + "trim", + _input, + ConstantExpression( + string.whitespace if characters is None else characters + ), + ), + "strip_chars", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def replace_all( + self, pattern: str, value: str, *, literal: bool = False + ) -> DuckDBExpr: + from duckdb import ConstantExpression + from duckdb import FunctionExpression + + if literal is False: + msg = "`replace_all` for DuckDB currently only supports `literal=True`." + raise NotImplementedError(msg) + return self._compliant_expr._from_call( + lambda _input: FunctionExpression( + "replace", + _input, + ConstantExpression(pattern), + ConstantExpression(value), + ), + "replace_all", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> NoReturn: + msg = "`replace` is currently not supported for DuckDB" + raise NotImplementedError(msg) + + +class DuckDBExprDateTimeNamespace: + def __init__(self, expr: DuckDBExpr) -> None: + self._compliant_expr = expr + + def year(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("year", _input), + "year", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def month(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("month", _input), + "month", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def day(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("day", _input), + "day", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def hour(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("hour", _input), + "hour", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def minute(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("minute", _input), + "minute", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def second(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("second", _input), + "second", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def millisecond(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("millisecond", _input) + - FunctionExpression("second", _input) * 1_000, + "millisecond", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def microsecond(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("microsecond", _input) + - FunctionExpression("second", _input) * 1_000_000, + "microsecond", + returns_scalar=self._compliant_expr._returns_scalar, + ) + + def nanosecond(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("nanosecond", _input) + - FunctionExpression("second", _input) * 1_000_000_000, + "nanosecond", + returns_scalar=self._compliant_expr._returns_scalar, + ) diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py new file mode 100644 index 000000000..0b312ff03 --- /dev/null +++ b/narwhals/_duckdb/group_by.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING + +from narwhals._expression_parsing import parse_into_exprs + +if TYPE_CHECKING: + from narwhals._duckdb.dataframe import DuckDBLazyFrame + from narwhals._duckdb.typing import IntoDuckDBExpr + + +class DuckDBGroupBy: + def __init__( + self, + compliant_frame: DuckDBLazyFrame, + keys: list[str], + drop_null_keys: bool, # noqa: FBT001 + ) -> None: + self._compliant_frame = compliant_frame + self._keys = keys + + def agg( + self, + *aggs: IntoDuckDBExpr, + **named_aggs: IntoDuckDBExpr, + ) -> DuckDBLazyFrame: + exprs = parse_into_exprs( + *aggs, + namespace=self._compliant_frame.__narwhals_namespace__(), + **named_aggs, + ) + output_names: list[str] = copy(self._keys) + for expr in exprs: + if expr._output_names is None: # pragma: no cover + msg = ( + "Anonymous expressions are not supported in group_by.agg.\n" + "Instead of `nw.all()`, try using a named expression, such as " + "`nw.col('a', 'b')`\n" + ) + raise ValueError(msg) + + output_names.extend(expr._output_names) + + agg_columns = [ + *self._keys, + *(x for expr in exprs for x in expr(self._compliant_frame)), + ] + try: + return self._compliant_frame._from_native_frame( + self._compliant_frame._native_frame.aggregate( + agg_columns, group_expr=",".join(self._keys) + ) + ) + except ValueError as exc: # pragma: no cover + msg = "Failed to aggregated - does your aggregation function return a scalar?" + raise RuntimeError(msg) from exc diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py new file mode 100644 index 000000000..bcd7eff6d --- /dev/null +++ b/narwhals/_duckdb/namespace.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import functools +import operator +from functools import reduce +from typing import TYPE_CHECKING +from typing import Any +from typing import Literal +from typing import Sequence + +from narwhals._duckdb.expr import DuckDBExpr +from narwhals._duckdb.utils import narwhals_to_native_dtype +from narwhals._expression_parsing import combine_root_names +from narwhals._expression_parsing import parse_into_exprs +from narwhals._expression_parsing import reduce_output_names +from narwhals.typing import CompliantNamespace + +if TYPE_CHECKING: + import duckdb + + from narwhals._duckdb.dataframe import DuckDBLazyFrame + from narwhals._duckdb.typing import IntoDuckDBExpr + from narwhals.dtypes import DType + from narwhals.utils import Version + + +def get_column_name(df: DuckDBLazyFrame, column: duckdb.Expression) -> str: + return str(df._native_frame.select(column).columns[0]) + + +class DuckDBNamespace(CompliantNamespace["duckdb.Expression"]): + def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None: + self._backend_version = backend_version + self._version = version + + def all(self) -> DuckDBExpr: + def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + from duckdb import ColumnExpression + + return [ColumnExpression(col_name) for col_name in df.columns] + + return DuckDBExpr( + call=_all, + depth=0, + function_name="all", + root_names=None, + output_names=None, + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={}, + ) + + def concat( + self, + items: Sequence[DuckDBLazyFrame], + *, + how: Literal["horizontal", "vertical", "diagonal"], + ) -> DuckDBLazyFrame: + if how == "horizontal": + msg = "horizontal concat not supported for duckdb. Please join instead" + raise TypeError(msg) + if how == "diagonal": + msg = "Not implemented yet" + raise NotImplementedError(msg) + first = items[0] + schema = first.schema + if how == "vertical" and not all(x.schema == schema for x in items[1:]): + msg = "inputs should all have the same schema" + raise TypeError(msg) + res = functools.reduce( + lambda x, y: x.union(y), (item._native_frame for item in items) + ) + return first._from_native_frame(res) + + def all_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr: + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [reduce(operator.and_, cols).alias(col_name)] + + return DuckDBExpr( + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="all_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + + def any_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr: + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [reduce(operator.or_, cols).alias(col_name)] + + return DuckDBExpr( + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="or_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + + def max_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr: + from duckdb import FunctionExpression + + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [FunctionExpression("greatest", *cols).alias(col_name)] + + return DuckDBExpr( + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="max_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + + def min_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr: + from duckdb import FunctionExpression + + parsed_exprs = parse_into_exprs(*exprs, namespace=self) + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + cols = [c for _expr in parsed_exprs for c in _expr(df)] + col_name = get_column_name(df, cols[0]) + return [FunctionExpression("least", *cols).alias(col_name)] + + return DuckDBExpr( + call=func, + depth=max(x._depth for x in parsed_exprs) + 1, + function_name="min_horizontal", + root_names=combine_root_names(parsed_exprs), + output_names=reduce_output_names(parsed_exprs), + returns_scalar=False, + backend_version=self._backend_version, + version=self._version, + kwargs={"exprs": exprs}, + ) + + def col(self, *column_names: str) -> DuckDBExpr: + return DuckDBExpr.from_column_names( + *column_names, backend_version=self._backend_version, version=self._version + ) + + def lit(self, value: Any, dtype: DType | None) -> DuckDBExpr: + from duckdb import ConstantExpression + + def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: + if dtype is not None: + return [ + ConstantExpression(value) + .cast(narwhals_to_native_dtype(dtype, version=self._version)) + .alias("literal") + ] + return [ConstantExpression(value).alias("literal")] + + return DuckDBExpr( + func, + depth=0, + function_name="lit", + root_names=None, + output_names=["literal"], + returns_scalar=True, + backend_version=self._backend_version, + version=self._version, + kwargs={}, + ) + + def len(self) -> DuckDBExpr: + def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: + from duckdb import FunctionExpression + + return [FunctionExpression("count").alias("len")] + + return DuckDBExpr( + call=func, + depth=0, + function_name="len", + root_names=None, + output_names=["len"], + returns_scalar=True, + backend_version=self._backend_version, + version=self._version, + kwargs={}, + ) diff --git a/narwhals/_duckdb/series.py b/narwhals/_duckdb/series.py index dc7485e98..bec9e0e08 100644 --- a/narwhals/_duckdb/series.py +++ b/narwhals/_duckdb/series.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from typing import Any -from narwhals._duckdb.dataframe import native_to_narwhals_dtype +from narwhals._duckdb.utils import native_to_narwhals_dtype from narwhals.dependencies import get_duckdb if TYPE_CHECKING: diff --git a/narwhals/_duckdb/typing.py b/narwhals/_duckdb/typing.py new file mode 100644 index 000000000..65d1ba3a7 --- /dev/null +++ b/narwhals/_duckdb/typing.py @@ -0,0 +1,16 @@ +from __future__ import annotations # pragma: no cover + +from typing import TYPE_CHECKING # pragma: no cover +from typing import Union # pragma: no cover + +if TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + from narwhals._duckdb.expr import DuckDBExpr + + IntoDuckDBExpr: TypeAlias = Union[DuckDBExpr, str] diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py new file mode 100644 index 000000000..abac2e158 --- /dev/null +++ b/narwhals/_duckdb/utils.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import re +from functools import lru_cache +from typing import TYPE_CHECKING +from typing import Any + +from narwhals.dtypes import DType +from narwhals.exceptions import InvalidIntoExprError +from narwhals.utils import import_dtypes_module +from narwhals.utils import isinstance_or_issubclass + +if TYPE_CHECKING: + import duckdb + + from narwhals._duckdb.dataframe import DuckDBLazyFrame + from narwhals._duckdb.expr import DuckDBExpr + from narwhals._duckdb.typing import IntoDuckDBExpr + from narwhals.utils import Version + + +def get_column_name( + df: DuckDBLazyFrame, column: duckdb.Expression, *, returns_scalar: bool +) -> str: + if returns_scalar: + return str(df._native_frame.aggregate([column]).columns[0]) + return str(df._native_frame.select(column).columns[0]) + + +def maybe_evaluate(df: DuckDBLazyFrame, obj: Any) -> Any: + import duckdb + + from narwhals._duckdb.expr import DuckDBExpr + + if isinstance(obj, DuckDBExpr): + column_results = obj._call(df) + if len(column_results) != 1: # pragma: no cover + msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context" + raise NotImplementedError(msg) + column_result = column_results[0] + if obj._returns_scalar: + msg = "Reductions are not yet supported for DuckDB, at least until they implement duckdb.WindowExpression" + raise NotImplementedError(msg) + return column_result + if isinstance_or_issubclass(obj, DType): + return obj + return duckdb.ConstantExpression(obj) + + +def parse_exprs_and_named_exprs( + df: DuckDBLazyFrame, + *exprs: IntoDuckDBExpr, + **named_exprs: IntoDuckDBExpr, +) -> dict[str, duckdb.Expression]: + result_columns: dict[str, list[duckdb.Expression]] = {} + for expr in exprs: + column_list = _columns_from_expr(df, expr) + if isinstance(expr, str): # pragma: no cover + output_names = [expr] + elif expr._output_names is None: + output_names = [ + get_column_name(df, col, returns_scalar=expr._returns_scalar) + for col in column_list + ] + else: + output_names = expr._output_names + result_columns.update(zip(output_names, column_list)) + for col_alias, expr in named_exprs.items(): + columns_list = _columns_from_expr(df, expr) + if len(columns_list) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise AssertionError(msg) + result_columns[col_alias] = columns_list[0] + return result_columns + + +def _columns_from_expr( + df: DuckDBLazyFrame, expr: IntoDuckDBExpr +) -> list[duckdb.Expression]: + if isinstance(expr, str): # pragma: no cover + from duckdb import ColumnExpression + + return [ColumnExpression(expr)] + elif hasattr(expr, "__narwhals_expr__"): + col_output_list = expr._call(df) + if expr._output_names is not None and ( + len(col_output_list) != len(expr._output_names) + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + return col_output_list + else: + raise InvalidIntoExprError.from_invalid_type(type(expr)) + + +@lru_cache(maxsize=16) +def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType: + dtypes = import_dtypes_module(version) + if duckdb_dtype == "HUGEINT": + return dtypes.Int128() + if duckdb_dtype == "BIGINT": + return dtypes.Int64() + if duckdb_dtype == "INTEGER": + return dtypes.Int32() + if duckdb_dtype == "SMALLINT": + return dtypes.Int16() + if duckdb_dtype == "TINYINT": + return dtypes.Int8() + if duckdb_dtype == "UHUGEINT": + return dtypes.UInt128() + if duckdb_dtype == "UBIGINT": + return dtypes.UInt64() + if duckdb_dtype == "UINTEGER": + return dtypes.UInt32() + if duckdb_dtype == "USMALLINT": + return dtypes.UInt16() + if duckdb_dtype == "UTINYINT": + return dtypes.UInt8() + if duckdb_dtype == "DOUBLE": + return dtypes.Float64() + if duckdb_dtype == "FLOAT": + return dtypes.Float32() + if duckdb_dtype == "VARCHAR": + return dtypes.String() + if duckdb_dtype == "DATE": + return dtypes.Date() + if duckdb_dtype == "TIMESTAMP": + return dtypes.Datetime() + if duckdb_dtype == "BOOLEAN": + return dtypes.Boolean() + if duckdb_dtype == "INTERVAL": + return dtypes.Duration() + if duckdb_dtype.startswith("STRUCT"): + matchstruc_ = re.findall(r"(\w+)\s+(\w+)", duckdb_dtype) + return dtypes.Struct( + [ + dtypes.Field( + matchstruc_[i][0], + native_to_narwhals_dtype(matchstruc_[i][1], version), + ) + for i in range(len(matchstruc_)) + ] + ) + if match_ := re.match(r"(.*)\[\]$", duckdb_dtype): + return dtypes.List(native_to_narwhals_dtype(match_.group(1), version)) + if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype): + return dtypes.Array( + native_to_narwhals_dtype(match_.group(1), version), + int(match_.group(2)), + ) + if duckdb_dtype.startswith("DECIMAL("): + return dtypes.Decimal() + return dtypes.Unknown() # pragma: no cover + + +def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> str: + dtypes = import_dtypes_module(version) + if isinstance_or_issubclass(dtype, dtypes.Float64): + return "FLOAT" + if isinstance_or_issubclass(dtype, dtypes.Float32): + return "DOUBLE" + if isinstance_or_issubclass(dtype, dtypes.Int64): + return "BIGINT" + if isinstance_or_issubclass(dtype, dtypes.Int32): + return "INT" + if isinstance_or_issubclass(dtype, dtypes.Int16): + return "SMALLINT" + if isinstance_or_issubclass(dtype, dtypes.Int8): + return "TINYINT" + if isinstance_or_issubclass(dtype, dtypes.UInt64): + return "UBIGINT" + if isinstance_or_issubclass(dtype, dtypes.UInt32): + return "UINT" + if isinstance_or_issubclass(dtype, dtypes.UInt16): # pragma: no cover + return "USMALLINT" + if isinstance_or_issubclass(dtype, dtypes.UInt8): # pragma: no cover + return "UTINYINT" + if isinstance_or_issubclass(dtype, dtypes.String): + return "VARCHAR" + if isinstance_or_issubclass(dtype, dtypes.Boolean): # pragma: no cover + return "BOOLEAN" + if isinstance_or_issubclass(dtype, dtypes.Categorical): + msg = "Categorical not supported by DuckDB" + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Datetime): + _time_unit = getattr(dtype, "time_unit", "us") + _time_zone = getattr(dtype, "time_zone", None) + msg = "todo" + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Duration): # pragma: no cover + _time_unit = getattr(dtype, "time_unit", "us") + msg = "todo" + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Date): # pragma: no cover + return "DATE" + if isinstance_or_issubclass(dtype, dtypes.List): + msg = "todo" + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover + msg = "todo" + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover + msg = "todo" + raise NotImplementedError(msg) + msg = f"Unknown dtype: {dtype}" # pragma: no cover + raise AssertionError(msg) + + +def binary_operation_returns_scalar(lhs: DuckDBExpr, rhs: DuckDBExpr | Any) -> bool: + # If `rhs` is a DuckDBExpr, we look at `_returns_scalar`. If it isn't, + # it means that it was a scalar (e.g. nw.col('a') + 1), and so we default + # to `True`. + return lhs._returns_scalar and getattr(rhs, "_returns_scalar", True) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 293f5cefe..e11c02710 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -20,6 +20,7 @@ from narwhals._pandas_like.utils import select_columns_by_name from narwhals._pandas_like.utils import validate_dataframe_comparand from narwhals.dependencies import is_numpy_array +from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import Implementation from narwhals.utils import flatten from narwhals.utils import generate_temporary_column_name @@ -694,6 +695,9 @@ def unique( # The param `maintain_order` is only here for compatibility with the Polars API # and has no effect on the output. mapped_keep = {"none": False, "any": "first"}.get(keep, keep) + if subset is not None and any(x not in self.columns for x in subset): + msg = f"Column(s) {subset} not found in {self.columns}" + raise ColumnNotFoundError(msg) return self._from_native_frame( self._native_frame.drop_duplicates(subset=subset, keep=mapped_keep) ) diff --git a/narwhals/functions.py b/narwhals/functions.py index 75cd9000e..ed167fb0d 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -1102,6 +1102,7 @@ def _scan_csv_impl( Implementation.MODIN, Implementation.CUDF, Implementation.DASK, + Implementation.DUCKDB, ): native_frame = native_namespace.read_csv(source, **kwargs) elif implementation is Implementation.PYARROW: @@ -1190,6 +1191,7 @@ def _read_parquet_impl( Implementation.PANDAS, Implementation.MODIN, Implementation.CUDF, + Implementation.DUCKDB, ): native_frame = native_namespace.read_parquet(source, **kwargs) elif implementation is Implementation.PYARROW: @@ -1273,6 +1275,7 @@ def _scan_parquet_impl( Implementation.MODIN, Implementation.CUDF, Implementation.DASK, + Implementation.DUCKDB, ): native_frame = native_namespace.read_parquet(source, **kwargs) elif implementation is Implementation.PYARROW: diff --git a/narwhals/translate.py b/narwhals/translate.py index 77c83b548..8d0805a26 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -698,13 +698,13 @@ def _from_native_impl( # noqa: PLR0915 # DuckDB elif is_duckdb_relation(native_object): - from narwhals._duckdb.dataframe import DuckDBInterchangeFrame + from narwhals._duckdb.dataframe import DuckDBLazyFrame if eager_only or series_only: # pragma: no cover if not pass_through: msg = ( "Cannot only use `series_only=True` or `eager_only=False` " - "with DuckDB Relation" + "with DuckDBPyRelation" ) else: return native_object @@ -712,11 +712,18 @@ def _from_native_impl( # noqa: PLR0915 import duckdb # ignore-banned-import backend_version = parse_version(duckdb.__version__) - return DataFrame( - DuckDBInterchangeFrame( - native_object, version=version, backend_version=backend_version + if version is Version.V1: + return DataFrame( + DuckDBLazyFrame( + native_object, backend_version=backend_version, version=version + ), + level="interchange", + ) + return LazyFrame( + DuckDBLazyFrame( + native_object, backend_version=backend_version, version=version ), - level="interchange", + level="full", ) # Ibis diff --git a/pyproject.toml b/pyproject.toml index daa21c3ee..6c33c09bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,6 +116,7 @@ lint.ignore = [ "E501", "FIX", "ISC001", + "PD003", "PD010", "PD901", # This is a auxiliary library so dataframe variables have no concrete business meaning "PLR0911", diff --git a/tests/conftest.py b/tests/conftest.py index 28fbc7610..dee762705 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,10 +13,7 @@ import pytest if TYPE_CHECKING: - from narwhals.typing import IntoDataFrame - from narwhals.typing import IntoFrame - -if TYPE_CHECKING: + import duckdb from pyspark.sql import SparkSession from narwhals.typing import IntoDataFrame @@ -109,6 +106,13 @@ def polars_lazy_constructor(obj: Any) -> pl.LazyFrame: return pl.LazyFrame(obj) +def duckdb_lazy_constructor(obj: Any) -> duckdb.DuckDBPyRelation: + import duckdb + + _df = pl.LazyFrame(obj) + return duckdb.table("_df") + + def dask_lazy_p1_constructor(obj: Any) -> IntoFrame: # pragma: no cover import dask.dataframe as dd @@ -168,6 +172,7 @@ def spark_session() -> Generator[SparkSession, None, None]: # pragma: no cover LAZY_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = { "dask": dask_lazy_p2_constructor, "polars[lazy]": polars_lazy_constructor, + "duckdb": duckdb_lazy_constructor, } GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor} @@ -207,4 +212,14 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: "constructor_eager", eager_constructors, ids=eager_constructors_ids ) elif "constructor" in metafunc.fixturenames: + if ( + any( + x in str(metafunc.module) + for x in ("list", "name", "unpivot", "from_dict", "from_numpy", "tail") + ) + and LAZY_CONSTRUCTORS["duckdb"] in constructors + ): + # TODO(unassigned): list and name namespaces still need implementing for duckdb + constructors.remove(LAZY_CONSTRUCTORS["duckdb"]) + constructors_ids.remove("duckdb") metafunc.parametrize("constructor", constructors, ids=constructors_ids) diff --git a/tests/expr_and_series/all_horizontal_test.py b/tests/expr_and_series/all_horizontal_test.py index 706c42baf..6eb98c3a3 100644 --- a/tests/expr_and_series/all_horizontal_test.py +++ b/tests/expr_and_series/all_horizontal_test.py @@ -57,6 +57,8 @@ def test_allh_nth( ) -> None: if "polars" in str(constructor) and POLARS_VERSION < (1, 0): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = { "a": [False, False, True], "b": [False, True, True], diff --git a/tests/expr_and_series/arithmetic_test.py b/tests/expr_and_series/arithmetic_test.py index cd82a945e..aec586c62 100644 --- a/tests/expr_and_series/arithmetic_test.py +++ b/tests/expr_and_series/arithmetic_test.py @@ -38,6 +38,8 @@ def test_arithmetic_expr( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor) and attr == "__floordiv__": + request.applymarker(pytest.mark.xfail) if attr == "__mod__" and any( x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"] ): @@ -244,7 +246,9 @@ def test_arithmetic_expr_left_literal( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if "dask" in str(constructor) and DASK_VERSION < (2024, 10): + if ("duckdb" in str(constructor) and attr == "__floordiv__") or ( + "dask" in str(constructor) and DASK_VERSION < (2024, 10) + ): request.applymarker(pytest.mark.xfail) if attr == "__mod__" and any( x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"] diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index e956dd455..b6ce43573 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -13,6 +13,7 @@ from tests.utils import PANDAS_VERSION from tests.utils import PYARROW_VERSION from tests.utils import Constructor +from tests.utils import ConstructorEager from tests.utils import assert_equal_data from tests.utils import is_windows @@ -59,6 +60,8 @@ def test_cast( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= ( 15, ): # pragma: no cover @@ -109,18 +112,18 @@ def test_cast( def test_cast_series( - constructor: Constructor, + constructor_eager: ConstructorEager, request: pytest.FixtureRequest, ) -> None: - if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= ( + if "pyarrow_table_constructor" in str(constructor_eager) and PYARROW_VERSION <= ( 15, ): # pragma: no cover request.applymarker(pytest.mark.xfail) - if "modin_constructor" in str(constructor): + if "modin_constructor" in str(constructor_eager): # TODO(unassigned): in modin, we end up with `' None: def test_cast_raises_for_unknown_dtype( constructor: Constructor, request: pytest.FixtureRequest ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,): # Unsupported cast from string to dictionary using function cast_dictionary request.applymarker(pytest.mark.xfail) @@ -196,6 +201,7 @@ def test_cast_datetime_tz_aware( ) -> None: if ( "dask" in str(constructor) + or "duckdb" in str(constructor) or "cudf" in str(constructor) # https://github.com/rapidsai/cudf/issues/16973 or ("pyarrow_table" in str(constructor) and is_windows()) ): @@ -222,7 +228,9 @@ def test_cast_datetime_tz_aware( def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): + if any( + backend in str(constructor) for backend in ("dask", "modin", "cudf", "duckdb") + ): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): diff --git a/tests/expr_and_series/concat_str_test.py b/tests/expr_and_series/concat_str_test.py index 26366d2f2..7c9f259ba 100644 --- a/tests/expr_and_series/concat_str_test.py +++ b/tests/expr_and_series/concat_str_test.py @@ -21,8 +21,14 @@ ], ) def test_concat_str( - constructor: Constructor, *, ignore_nulls: bool, expected: list[str] + constructor: Constructor, + *, + ignore_nulls: bool, + expected: list[str], + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = ( df.select( diff --git a/tests/expr_and_series/convert_time_zone_test.py b/tests/expr_and_series/convert_time_zone_test.py index aa4235549..6b3cf5b41 100644 --- a/tests/expr_and_series/convert_time_zone_test.py +++ b/tests/expr_and_series/convert_time_zone_test.py @@ -28,6 +28,7 @@ def test_convert_time_zone( or ("pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1)) or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1)) or ("cudf" in str(constructor)) + or ("duckdb" in str(constructor)) ): request.applymarker(pytest.mark.xfail) data = { @@ -84,6 +85,7 @@ def test_convert_time_zone_from_none( or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1)) or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,)) or ("cudf" in str(constructor)) + or ("duckdb" in str(constructor)) ): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 7): diff --git a/tests/expr_and_series/cum_count_test.py b/tests/expr_and_series/cum_count_test.py index 6ddf6c991..1a2377f34 100644 --- a/tests/expr_and_series/cum_count_test.py +++ b/tests/expr_and_series/cum_count_test.py @@ -21,6 +21,8 @@ def test_cum_count_expr( ) -> None: if "dask" in str(constructor) and reverse: request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) name = "reverse_cum_count" if reverse else "cum_count" df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/cum_max_test.py b/tests/expr_and_series/cum_max_test.py index 054537d34..22b7c73fa 100644 --- a/tests/expr_and_series/cum_max_test.py +++ b/tests/expr_and_series/cum_max_test.py @@ -23,6 +23,8 @@ def test_cum_max_expr( ) -> None: if "dask" in str(constructor) and reverse: request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/cum_min_test.py b/tests/expr_and_series/cum_min_test.py index bb92f5b9d..b34672219 100644 --- a/tests/expr_and_series/cum_min_test.py +++ b/tests/expr_and_series/cum_min_test.py @@ -23,6 +23,8 @@ def test_cum_min_expr( ) -> None: if "dask" in str(constructor) and reverse: request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/cum_prod_test.py b/tests/expr_and_series/cum_prod_test.py index 1d5816ff2..4dd5207dc 100644 --- a/tests/expr_and_series/cum_prod_test.py +++ b/tests/expr_and_series/cum_prod_test.py @@ -23,6 +23,8 @@ def test_cum_prod_expr( ) -> None: if "dask" in str(constructor) and reverse: request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/cum_sum_test.py b/tests/expr_and_series/cum_sum_test.py index 8df3396bc..5878222fb 100644 --- a/tests/expr_and_series/cum_sum_test.py +++ b/tests/expr_and_series/cum_sum_test.py @@ -18,6 +18,8 @@ def test_cum_sum_expr( request: pytest.FixtureRequest, constructor: Constructor, *, reverse: bool ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "dask" in str(constructor) and reverse: request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/diff_test.py b/tests/expr_and_series/diff_test.py index da433f7ad..f7730a2d4 100644 --- a/tests/expr_and_series/diff_test.py +++ b/tests/expr_and_series/diff_test.py @@ -22,6 +22,8 @@ def test_diff( if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION < (13,): # pc.pairwisediff is available since pyarrow 13.0.0 request.applymarker(pytest.mark.xfail) + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.with_columns(c_diff=nw.col("c").diff()).filter(nw.col("i") > 0) expected = { diff --git a/tests/expr_and_series/dt/datetime_attributes_test.py b/tests/expr_and_series/dt/datetime_attributes_test.py index ad5f8dc3f..e1af276e4 100644 --- a/tests/expr_and_series/dt/datetime_attributes_test.py +++ b/tests/expr_and_series/dt/datetime_attributes_test.py @@ -49,6 +49,8 @@ def test_datetime_attributes( request.applymarker(pytest.mark.xfail) if attribute == "date" and "cudf" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor) and attribute in ("date", "weekday", "ordinal_day"): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(getattr(nw.col("a").dt, attribute)()) @@ -118,6 +120,7 @@ def test_to_date(request: pytest.FixtureRequest, constructor: Constructor) -> No "pandas_nullable_constructor", "cudf", "modin_constructor", + "duckdb", ) ): request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/dt/datetime_duration_test.py b/tests/expr_and_series/dt/datetime_duration_test.py index 09f227c79..bda3e4703 100644 --- a/tests/expr_and_series/dt/datetime_duration_test.py +++ b/tests/expr_and_series/dt/datetime_duration_test.py @@ -46,6 +46,8 @@ def test_duration_attributes( ) -> None: if PANDAS_VERSION < (2, 2) and "pandas_pyarrow" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/dt/timestamp_test.py b/tests/expr_and_series/dt/timestamp_test.py index e205d8179..b7e20519f 100644 --- a/tests/expr_and_series/dt/timestamp_test.py +++ b/tests/expr_and_series/dt/timestamp_test.py @@ -50,6 +50,8 @@ def test_timestamp_datetimes( time_unit: Literal["ns", "us", "ms"], expected: list[int | None], ) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) if original_time_unit == "s" and "polars" in str(constructor): request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < ( @@ -90,6 +92,8 @@ def test_timestamp_datetimes_tz_aware( time_unit: Literal["ns", "us", "ms"], expected: list[int | None], ) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) if ( (any(x in str(constructor) for x in ("pyarrow",)) and is_windows()) or ("pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2,)) @@ -136,6 +140,8 @@ def test_timestamp_dates( time_unit: Literal["ns", "us", "ms"], expected: list[int | None], ) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) if any( x in str(constructor) for x in ( @@ -161,6 +167,8 @@ def test_timestamp_dates( def test_timestamp_invalid_date( request: pytest.FixtureRequest, constructor: Constructor ) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) if "polars" in str(constructor): request.applymarker(pytest.mark.xfail) data_str = {"a": ["x", "y", None]} diff --git a/tests/expr_and_series/dt/to_string_test.py b/tests/expr_and_series/dt/to_string_test.py index 629b39806..6fa500024 100644 --- a/tests/expr_and_series/dt/to_string_test.py +++ b/tests/expr_and_series/dt/to_string_test.py @@ -59,7 +59,11 @@ def test_dt_to_string_series(constructor_eager: ConstructorEager, fmt: str) -> N ], ) @pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows") -def test_dt_to_string_expr(constructor: Constructor, fmt: str) -> None: +def test_dt_to_string_expr( + constructor: Constructor, fmt: str, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) input_frame = nw.from_native(constructor(data)) expected_col = [datetime.strftime(d, fmt) for d in data["a"]] @@ -132,8 +136,13 @@ def test_dt_to_string_iso_local_datetime_series( ) @pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows") def test_dt_to_string_iso_local_datetime_expr( - constructor: Constructor, data: datetime, expected: str + constructor: Constructor, + data: datetime, + expected: str, + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = constructor({"a": [data]}) result = nw.from_native(df).with_columns( @@ -166,8 +175,13 @@ def test_dt_to_string_iso_local_date_series( ) @pytest.mark.skipif(is_windows(), reason="pyarrow breaking on windows") def test_dt_to_string_iso_local_date_expr( - constructor: Constructor, data: datetime, expected: str + constructor: Constructor, + data: datetime, + expected: str, + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = constructor({"a": [data]}) result = nw.from_native(df).with_columns( nw.col("a").dt.to_string("%Y-%m-%d").alias("b") diff --git a/tests/expr_and_series/fill_null_test.py b/tests/expr_and_series/fill_null_test.py index 57f767d4d..58ef5c890 100644 --- a/tests/expr_and_series/fill_null_test.py +++ b/tests/expr_and_series/fill_null_test.py @@ -47,7 +47,11 @@ def test_fill_null_exceptions(constructor: Constructor) -> None: df.with_columns(nw.col("a").fill_null(strategy="invalid")) # type: ignore # noqa: PGH003 -def test_fill_null_strategies_with_limit_as_none(constructor: Constructor) -> None: +def test_fill_null_strategies_with_limit_as_none( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data_limits = { "a": [1, None, None, None, 5, 6, None, None, None, 10], "b": ["a", None, None, None, "b", "c", None, None, None, "d"], @@ -113,7 +117,11 @@ def test_fill_null_strategies_with_limit_as_none(constructor: Constructor) -> No assert_equal_data(result_backward, expected_backward) -def test_fill_null_limits(constructor: Constructor) -> None: +def test_fill_null_limits( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) context: Any = ( pytest.raises(NotImplementedError, match="The limit keyword is not supported") if "cudf" in str(constructor) diff --git a/tests/expr_and_series/is_duplicated_test.py b/tests/expr_and_series/is_duplicated_test.py index d4ce3461f..fe8b45bf1 100644 --- a/tests/expr_and_series/is_duplicated_test.py +++ b/tests/expr_and_series/is_duplicated_test.py @@ -1,12 +1,18 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data -def test_is_duplicated_expr(constructor: Constructor) -> None: +def test_is_duplicated_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 1, 2], "b": [1, 2, 3], "index": [0, 1, 2]} df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b").is_duplicated(), "index").sort("index") @@ -14,7 +20,11 @@ def test_is_duplicated_expr(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_is_duplicated_w_nulls_expr(constructor: Constructor) -> None: +def test_is_duplicated_w_nulls_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 1, None], "b": [1, None, None], "index": [0, 1, 2]} df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b").is_duplicated(), "index").sort("index") diff --git a/tests/expr_and_series/is_finite_test.py b/tests/expr_and_series/is_finite_test.py index 270ba7d52..7718ed1a7 100644 --- a/tests/expr_and_series/is_finite_test.py +++ b/tests/expr_and_series/is_finite_test.py @@ -11,7 +11,9 @@ @pytest.mark.filterwarnings("ignore:invalid value encountered in cast") -def test_is_finite_expr(constructor: Constructor) -> None: +def test_is_finite_expr(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) or "pyarrow_table" in str(constructor): expected = {"a": [False, False, True, None]} elif ( diff --git a/tests/expr_and_series/is_first_distinct_test.py b/tests/expr_and_series/is_first_distinct_test.py index 7084fb3fb..786f2ade7 100644 --- a/tests/expr_and_series/is_first_distinct_test.py +++ b/tests/expr_and_series/is_first_distinct_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -11,7 +13,11 @@ } -def test_is_first_distinct_expr(constructor: Constructor) -> None: +def test_is_first_distinct_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.all().is_first_distinct()) expected = { diff --git a/tests/expr_and_series/is_last_distinct_test.py b/tests/expr_and_series/is_last_distinct_test.py index b91c171d3..c5d73c8d7 100644 --- a/tests/expr_and_series/is_last_distinct_test.py +++ b/tests/expr_and_series/is_last_distinct_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -11,7 +13,11 @@ } -def test_is_last_distinct_expr(constructor: Constructor) -> None: +def test_is_last_distinct_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.all().is_last_distinct()) expected = { diff --git a/tests/expr_and_series/is_nan_test.py b/tests/expr_and_series/is_nan_test.py index 806dc7535..7bae35a52 100644 --- a/tests/expr_and_series/is_nan_test.py +++ b/tests/expr_and_series/is_nan_test.py @@ -24,7 +24,9 @@ ] -def test_nan(constructor: Constructor) -> None: +def test_nan(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data_na = {"int": [0, 1, None]} df = nw.from_native(constructor(data_na)).with_columns( float=nw.col("int").cast(nw.Float64), float_na=nw.col("int") / nw.col("int") @@ -93,7 +95,9 @@ def test_nan_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_nan_non_float(constructor: Constructor) -> None: +def test_nan_non_float(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) from polars.exceptions import InvalidOperationError as PlInvalidOperationError from pyarrow.lib import ArrowNotImplementedError diff --git a/tests/expr_and_series/is_unique_test.py b/tests/expr_and_series/is_unique_test.py index b44878886..3e9259c03 100644 --- a/tests/expr_and_series/is_unique_test.py +++ b/tests/expr_and_series/is_unique_test.py @@ -1,12 +1,16 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data -def test_is_unique_expr(constructor: Constructor) -> None: +def test_is_unique_expr(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = { "a": [1, 1, 2], "b": [1, 2, 3], @@ -22,7 +26,11 @@ def test_is_unique_expr(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_is_unique_w_nulls_expr(constructor: Constructor) -> None: +def test_is_unique_w_nulls_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = { "a": [None, 1, 2], "b": [None, 2, None], diff --git a/tests/expr_and_series/lit_test.py b/tests/expr_and_series/lit_test.py index 501bfc4bd..505d99bf8 100644 --- a/tests/expr_and_series/lit_test.py +++ b/tests/expr_and_series/lit_test.py @@ -87,6 +87,13 @@ def test_lit_operation( expected_result: list[int], request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor) and col_name in ( + "left_scalar_with_agg", + "left_lit_with_agg", + "right_lit", + "right_lit_with_agg", + ): + request.applymarker(pytest.mark.xfail) if ( "dask" in str(constructor) and col_name in ("left_lit", "left_scalar") diff --git a/tests/expr_and_series/mean_horizontal_test.py b/tests/expr_and_series/mean_horizontal_test.py index 485bf1750..c1652c837 100644 --- a/tests/expr_and_series/mean_horizontal_test.py +++ b/tests/expr_and_series/mean_horizontal_test.py @@ -10,7 +10,11 @@ @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) -def test_meanh(constructor: Constructor, col_expr: Any) -> None: +def test_meanh( + constructor: Constructor, col_expr: Any, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, None, None], "b": [4, None, 6, None]} df = nw.from_native(constructor(data)) result = df.select(horizontal_mean=nw.mean_horizontal(col_expr, nw.col("b"))) @@ -18,7 +22,9 @@ def test_meanh(constructor: Constructor, col_expr: Any) -> None: assert_equal_data(result, expected) -def test_meanh_all(constructor: Constructor) -> None: +def test_meanh_all(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [2, 4, 6], "b": [10, 20, 30]} df = nw.from_native(constructor(data)) result = df.select(nw.mean_horizontal(nw.all())) diff --git a/tests/expr_and_series/median_test.py b/tests/expr_and_series/median_test.py index 7c50988dc..b0b6edcba 100644 --- a/tests/expr_and_series/median_test.py +++ b/tests/expr_and_series/median_test.py @@ -41,16 +41,17 @@ def test_median_series( @pytest.mark.parametrize("expr", [nw.col("s").median(), nw.median("s")]) def test_median_expr_raises_on_str( - constructor: Constructor, - expr: nw.Expr, + constructor: Constructor, expr: nw.Expr, request: pytest.FixtureRequest ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) from polars.exceptions import InvalidOperationError as PlInvalidOperationError df = nw.from_native(constructor(data)) - if "polars_lazy" in str(constructor): + if isinstance(df, nw.LazyFrame): with pytest.raises( - PlInvalidOperationError, - match="`median` operation not supported for dtype `str`", + (InvalidOperationError, PlInvalidOperationError), + match="`median` operation not supported", ): df.select(expr).lazy().collect() else: diff --git a/tests/expr_and_series/n_unique_test.py b/tests/expr_and_series/n_unique_test.py index 90bffb04b..d8e4d9b77 100644 --- a/tests/expr_and_series/n_unique_test.py +++ b/tests/expr_and_series/n_unique_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -11,7 +13,9 @@ } -def test_n_unique(constructor: Constructor) -> None: +def test_n_unique(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.all().n_unique()) expected = {"a": [3], "b": [4]} diff --git a/tests/expr_and_series/name/to_uppercase_test.py b/tests/expr_and_series/name/to_uppercase_test.py index 785da4957..e6703212d 100644 --- a/tests/expr_and_series/name/to_uppercase_test.py +++ b/tests/expr_and_series/name/to_uppercase_test.py @@ -12,21 +12,31 @@ data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]} -def test_to_uppercase(constructor: Constructor) -> None: +def test_to_uppercase(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.to_uppercase()) expected = {k.upper(): [e * 2 for e in v] for k, v in data.items()} assert_equal_data(result, expected) -def test_to_uppercase_after_alias(constructor: Constructor) -> None: +def test_to_uppercase_after_alias( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select((nw.col("foo")).alias("alias_for_foo").name.to_uppercase()) expected = {"FOO": data["foo"]} assert_equal_data(result, expected) -def test_to_uppercase_raise_anonymous(constructor: Constructor) -> None: +def test_to_uppercase_raise_anonymous( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) df_raw = constructor(data) df = nw.from_native(df_raw) diff --git a/tests/expr_and_series/nth_test.py b/tests/expr_and_series/nth_test.py index 8179fb261..4dd453528 100644 --- a/tests/expr_and_series/nth_test.py +++ b/tests/expr_and_series/nth_test.py @@ -25,6 +25,8 @@ def test_nth( expected: dict[str, list[int]], request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (1, 0, 0): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/null_count_test.py b/tests/expr_and_series/null_count_test.py index 0f2250713..d10258901 100644 --- a/tests/expr_and_series/null_count_test.py +++ b/tests/expr_and_series/null_count_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -11,7 +13,11 @@ } -def test_null_count_expr(constructor: Constructor) -> None: +def test_null_count_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.all().null_count()) expected = { diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index a67c7973b..f42bdca54 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -24,6 +24,8 @@ def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "dask_lazy_p2" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) expected = { @@ -40,6 +42,8 @@ def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) - def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "dask_lazy_p2" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) expected = { @@ -56,6 +60,8 @@ def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) def test_over_invalid(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "polars" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) with pytest.raises(ValueError, match="Anonymous expressions"): @@ -67,6 +73,8 @@ def test_over_cumsum(request: pytest.FixtureRequest, constructor: Constructor) - request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data_cum)) expected = { @@ -84,6 +92,8 @@ def test_over_cumsum(request: pytest.FixtureRequest, constructor: Constructor) - def test_over_cumcount(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "pyarrow_table" in str(constructor) or "dask_lazy_p2" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data_cum)) expected = { @@ -101,10 +111,12 @@ def test_over_cumcount(request: pytest.FixtureRequest, constructor: Constructor) def test_over_cummax(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if "pyarrow_table" in str(constructor) or "dask_lazy_p2" in str(constructor): + if any(x in str(constructor) for x in ("pyarrow_table", "dask_lazy_p2", "duckdb")): request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data_cum)) expected = { "a": ["a", "a", "b", "b", "b"], @@ -120,9 +132,10 @@ def test_over_cummax(request: pytest.FixtureRequest, constructor: Constructor) - def test_over_cummin(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "pyarrow_table" in str(constructor) or "dask_lazy_p2" in str(constructor): request.applymarker(pytest.mark.xfail) - if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data_cum)) expected = { @@ -138,11 +151,12 @@ def test_over_cummin(request: pytest.FixtureRequest, constructor: Constructor) - def test_over_cumprod(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "dask_lazy_p2")): + if any(x in str(constructor) for x in ("pyarrow_table", "dask_lazy_p2", "duckdb")): request.applymarker(pytest.mark.xfail) - if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data_cum)) expected = { @@ -170,6 +184,8 @@ def test_over_shift(request: pytest.FixtureRequest, constructor: Constructor) -> constructor ) or "dask_lazy_p2_constructor" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) expected = { diff --git a/tests/expr_and_series/quantile_test.py b/tests/expr_and_series/quantile_test.py index ae707e739..d52fae16c 100644 --- a/tests/expr_and_series/quantile_test.py +++ b/tests/expr_and_series/quantile_test.py @@ -28,7 +28,10 @@ def test_quantile_expr( expected: dict[str, list[float]], request: pytest.FixtureRequest, ) -> None: - if "dask" in str(constructor) and interpolation != "linear": + if ( + any(x in str(constructor) for x in ("dask", "duckdb")) + and interpolation != "linear" + ): request.applymarker(pytest.mark.xfail) q = 0.3 diff --git a/tests/expr_and_series/reduction_test.py b/tests/expr_and_series/reduction_test.py index 3b579d9f3..4f2faa0ce 100644 --- a/tests/expr_and_series/reduction_test.py +++ b/tests/expr_and_series/reduction_test.py @@ -30,6 +30,9 @@ def test_scalar_reduction_select( constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]] ) -> None: + if "duckdb" in str(constructor): + # First one passes, the others fail. + return data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) result = df.select(*expr) @@ -54,15 +57,24 @@ def test_scalar_reduction_select( ids=range(5), ) def test_scalar_reduction_with_columns( - constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]] + constructor: Constructor, + expr: list[Any], + expected: dict[str, list[Any]], + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) result = df.with_columns(*expr).select(*expected.keys()) assert_equal_data(result, expected) -def test_empty_scalar_reduction_select(constructor: Constructor) -> None: +def test_empty_scalar_reduction_select( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = { "str": [*"abcde"], "int": [0, 1, 2, 3, 4], @@ -91,7 +103,11 @@ def test_empty_scalar_reduction_select(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_empty_scalar_reduction_with_columns(constructor: Constructor) -> None: +def test_empty_scalar_reduction_with_columns( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) from itertools import chain data = { diff --git a/tests/expr_and_series/replace_strict_test.py b/tests/expr_and_series/replace_strict_test.py index b1449af24..07e349bc6 100644 --- a/tests/expr_and_series/replace_strict_test.py +++ b/tests/expr_and_series/replace_strict_test.py @@ -23,6 +23,8 @@ def test_replace_strict( ) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3]})) result = df.select( nw.col("a").replace_strict( @@ -58,6 +60,8 @@ def test_replace_non_full( if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3]})) if isinstance(df, nw.LazyFrame): with pytest.raises((ValueError, PolarsError)): @@ -77,6 +81,8 @@ def test_replace_strict_mapping( ) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3]})) result = df.select( diff --git a/tests/expr_and_series/replace_time_zone_test.py b/tests/expr_and_series/replace_time_zone_test.py index 94367d1e1..eed90feb1 100644 --- a/tests/expr_and_series/replace_time_zone_test.py +++ b/tests/expr_and_series/replace_time_zone_test.py @@ -26,6 +26,7 @@ def test_replace_time_zone( or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2,)) or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,)) or ("cudf" in str(constructor)) + or ("duckdb" in str(constructor)) ): request.applymarker(pytest.mark.xfail) data = { @@ -52,6 +53,8 @@ def test_replace_time_zone_none( or ("pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2,)) or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2,)) or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,)) + or ("cudf" in str(constructor)) + or ("duckdb" in str(constructor)) ): request.applymarker(pytest.mark.xfail) data = { diff --git a/tests/expr_and_series/shift_test.py b/tests/expr_and_series/shift_test.py index 379f40986..07f5d2b58 100644 --- a/tests/expr_and_series/shift_test.py +++ b/tests/expr_and_series/shift_test.py @@ -1,6 +1,7 @@ from __future__ import annotations import pyarrow as pa +import pytest import narwhals.stable.v1 as nw from tests.utils import Constructor @@ -15,7 +16,9 @@ } -def test_shift(constructor: Constructor) -> None: +def test_shift(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.with_columns(nw.col("a", "b", "c").shift(2)).filter(nw.col("i") > 1) expected = { diff --git a/tests/expr_and_series/std_test.py b/tests/expr_and_series/std_test.py index b83100801..f2eabf4f2 100644 --- a/tests/expr_and_series/std_test.py +++ b/tests/expr_and_series/std_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +from contextlib import nullcontext as does_not_raise + import pytest import narwhals.stable.v1 as nw @@ -24,10 +26,27 @@ def test_std(constructor: Constructor, input_data: dict[str, list[float | None]] result = df.select( nw.col("a").std(ddof=1).alias("a_ddof_1"), nw.col("a").std(ddof=0).alias("a_ddof_0"), - nw.col("b").std(ddof=2).alias("b_ddof_2"), nw.col("z").std(ddof=0).alias("z_ddof_0"), ) + expected_results = { + "a_ddof_1": [1.0], + "a_ddof_0": [0.816497], + "z_ddof_0": [0.816497], + } assert_equal_data(result, expected_results) + context = ( + pytest.raises(NotImplementedError) + if "duckdb" in str(constructor) + else does_not_raise() + ) + with context: + result = df.select( + nw.col("b").std(ddof=2).alias("b_ddof_2"), + ) + expected_results = { + "b_ddof_2": [1.632993], + } + assert_equal_data(result, expected_results) @pytest.mark.parametrize("input_data", [data, data_with_nulls]) diff --git a/tests/expr_and_series/str/len_chars_test.py b/tests/expr_and_series/str/len_chars_test.py index f9c63e01c..1a318801a 100644 --- a/tests/expr_and_series/str/len_chars_test.py +++ b/tests/expr_and_series/str/len_chars_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -8,7 +10,9 @@ data = {"a": ["foo", "foobar", "Café", "345", "東京"]} -def test_str_len_chars(constructor: Constructor) -> None: +def test_str_len_chars(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.col("a").str.len_chars()) expected = { diff --git a/tests/expr_and_series/str/replace_test.py b/tests/expr_and_series/str/replace_test.py index ffd8fce2e..7d57eeb7d 100644 --- a/tests/expr_and_series/str/replace_test.py +++ b/tests/expr_and_series/str/replace_test.py @@ -93,6 +93,7 @@ def test_str_replace_all_series( ) def test_str_replace_expr( constructor: Constructor, + request: pytest.FixtureRequest, data: dict[str, list[str]], pattern: str, value: str, @@ -100,8 +101,9 @@ def test_str_replace_expr( literal: bool, # noqa: FBT001 expected: dict[str, list[str]], ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result_df = df.select( nw.col("a").str.replace(pattern=pattern, value=value, n=n, literal=literal) ) @@ -114,14 +116,16 @@ def test_str_replace_expr( ) def test_str_replace_all_expr( constructor: Constructor, + request: pytest.FixtureRequest, data: dict[str, list[str]], pattern: str, value: str, literal: bool, # noqa: FBT001 expected: dict[str, list[str]], ) -> None: + if "duckdb" in str(constructor) and literal is False: + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select( nw.col("a").str.replace_all(pattern=pattern, value=value, literal=literal) ) diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index 388ef23db..3f8df65a7 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -17,7 +17,9 @@ data = {"a": ["2020-01-01T12:34:56"]} -def test_to_datetime(constructor: Constructor) -> None: +def test_to_datetime(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor): expected = "2020-01-01T12:34:56.000000000" else: @@ -78,6 +80,8 @@ def test_to_datetime_infer_fmt( request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor): expected = expected_cudf + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) result = ( nw.from_native(constructor(data)) .lazy() @@ -126,7 +130,11 @@ def test_to_datetime_series_infer_fmt( assert str(result) == expected -def test_to_datetime_infer_fmt_from_date(constructor: Constructor) -> None: +def test_to_datetime_infer_fmt_from_date( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"z": ["2020-01-01", "2020-01-02", None]} expected = [datetime(2020, 1, 1), datetime(2020, 1, 2), None] result = ( diff --git a/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py b/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py index 1d0eb8834..1057b33de 100644 --- a/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py +++ b/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py @@ -39,6 +39,7 @@ def test_str_to_uppercase( "pandas_pyarrow_constructor", "pyarrow_table_constructor", "modin_pyarrow_constructor", + "duckdb_lazy_constructor", ) or ("dask" in str(constructor) and PYARROW_VERSION >= (12,)) ): @@ -80,6 +81,7 @@ def test_str_to_uppercase_series( "pandas_nullable_constructor", "polars_eager_constructor", "cudf_constructor", + "duckdb_lazy_constructor", "modin_constructor", ) ): diff --git a/tests/expr_and_series/sum_horizontal_test.py b/tests/expr_and_series/sum_horizontal_test.py index 21bd138c2..decb65c02 100644 --- a/tests/expr_and_series/sum_horizontal_test.py +++ b/tests/expr_and_series/sum_horizontal_test.py @@ -10,7 +10,11 @@ @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) -def test_sumh(constructor: Constructor, col_expr: Any) -> None: +def test_sumh( + constructor: Constructor, col_expr: Any, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.with_columns(horizontal_sum=nw.sum_horizontal(col_expr, nw.col("b"))) @@ -23,7 +27,9 @@ def test_sumh(constructor: Constructor, col_expr: Any) -> None: assert_equal_data(result, expected) -def test_sumh_nullable(constructor: Constructor) -> None: +def test_sumh_nullable(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 8, 3], "b": [4, 5, None]} expected = {"hsum": [5, 13, 3]} @@ -32,7 +38,9 @@ def test_sumh_nullable(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_sumh_all(constructor: Constructor) -> None: +def test_sumh_all(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 2, 3], "b": [10, 20, 30]} df = nw.from_native(constructor(data)) result = df.select(nw.sum_horizontal(nw.all())) diff --git a/tests/expr_and_series/unary_test.py b/tests/expr_and_series/unary_test.py index f2f9c33ff..9ee38a230 100644 --- a/tests/expr_and_series/unary_test.py +++ b/tests/expr_and_series/unary_test.py @@ -10,7 +10,9 @@ from tests.utils import assert_equal_data -def test_unary(constructor: Constructor) -> None: +def test_unary(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = { "a": [1, 3, 2], "b": [4, 4, 6], @@ -77,7 +79,11 @@ def test_unary_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_unary_two_elements(constructor: Constructor) -> None: +def test_unary_two_elements( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 2], "b": [2, 10], "c": [2.0, None]} result = nw.from_native(constructor(data)).select( a_nunique=nw.col("a").n_unique(), @@ -120,7 +126,11 @@ def test_unary_two_elements_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_unary_one_element(constructor: Constructor) -> None: +def test_unary_one_element( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1], "b": [2], "c": [None]} # Dask runs into a divide by zero RuntimeWarning for 1 element skew. context = ( diff --git a/tests/expr_and_series/var_test.py b/tests/expr_and_series/var_test.py index bab97d383..2053dfe69 100644 --- a/tests/expr_and_series/var_test.py +++ b/tests/expr_and_series/var_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +from contextlib import nullcontext as does_not_raise + import pytest import narwhals.stable.v1 as nw @@ -24,10 +26,27 @@ def test_var(constructor: Constructor, input_data: dict[str, list[float | None]] result = df.select( nw.col("a").var(ddof=1).alias("a_ddof_1"), nw.col("a").var(ddof=0).alias("a_ddof_0"), - nw.col("b").var(ddof=2).alias("b_ddof_2"), nw.col("z").var(ddof=0).alias("z_ddof_0"), ) + expected_results = { + "a_ddof_1": [1.0], + "a_ddof_0": [0.6666666666666666], + "z_ddof_0": [0.6666666666666666], + } assert_equal_data(result, expected_results) + context = ( + pytest.raises(NotImplementedError) + if "duckdb" in str(constructor) + else does_not_raise() + ) + with context: + result = df.select( + nw.col("b").var(ddof=2).alias("b_ddof_2"), + ) + expected_results = { + "b_ddof_2": [2.666666666666667], + } + assert_equal_data(result, expected_results) @pytest.mark.parametrize("input_data", [data, data_with_nulls]) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 5c60febb4..b59dda488 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -17,7 +17,9 @@ } -def test_when(constructor: Constructor) -> None: +def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { @@ -26,7 +28,9 @@ def test_when(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_when_otherwise(constructor: Constructor) -> None: +def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) expected = { @@ -35,7 +39,11 @@ def test_when_otherwise(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_multiple_conditions(constructor: Constructor) -> None: +def test_multiple_conditions( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select( nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") @@ -46,7 +54,11 @@ def test_multiple_conditions(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_no_arg_when_fail(constructor: Constructor) -> None: +def test_no_arg_when_fail( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) with pytest.raises((TypeError, ValueError)): df.select(nw.when().then(value=3).alias("a_when")) @@ -77,7 +89,11 @@ def test_value_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_value_expression(constructor: Constructor) -> None: +def test_value_expression( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when")) expected = { @@ -110,7 +126,11 @@ def test_otherwise_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_otherwise_expression(constructor: Constructor) -> None: +def test_otherwise_expression( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select( nw.when(nw.col("a") == 1).then(-1).otherwise(nw.col("a") + 7).alias("a_when") @@ -121,14 +141,22 @@ def test_otherwise_expression(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_when_then_otherwise_into_expr(constructor: Constructor) -> None: +def test_when_then_otherwise_into_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e")) expected = {"c": [7, 5, 6]} assert_equal_data(result, expected) -def test_when_then_otherwise_lit_str(constructor: Constructor) -> None: +def test_when_then_otherwise_lit_str( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z"))) expected = {"b": ["z", "b", "c"]} diff --git a/tests/frame/add_test.py b/tests/frame/add_test.py index 27a332ed0..e04561895 100644 --- a/tests/frame/add_test.py +++ b/tests/frame/add_test.py @@ -1,11 +1,15 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import assert_equal_data -def test_add(constructor: Constructor) -> None: +def test_add(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) result = df.with_columns( diff --git a/tests/frame/clone_test.py b/tests/frame/clone_test.py index 1a02910c8..e142ed0a7 100644 --- a/tests/frame/clone_test.py +++ b/tests/frame/clone_test.py @@ -10,6 +10,8 @@ def test_clone(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 26bbd2e62..4d5f3ebc9 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -7,7 +7,11 @@ from tests.utils import assert_equal_data -def test_concat_horizontal(constructor: Constructor) -> None: +def test_concat_horizontal( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = nw.from_native(constructor(data)).lazy() @@ -56,7 +60,11 @@ def test_concat_vertical(constructor: Constructor) -> None: nw.concat([df_left, df_left.select("d")], how="vertical").collect() -def test_concat_diagonal(constructor: Constructor) -> None: +def test_concat_diagonal( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data_1 = {"a": [1, 3], "b": [4, 6]} data_2 = {"a": [100, 200], "z": ["x", "y"]} expected = { diff --git a/tests/frame/drop_nulls_test.py b/tests/frame/drop_nulls_test.py index bb55439eb..368ad6ba0 100644 --- a/tests/frame/drop_nulls_test.py +++ b/tests/frame/drop_nulls_test.py @@ -12,7 +12,9 @@ } -def test_drop_nulls(constructor: Constructor) -> None: +def test_drop_nulls(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) result = nw.from_native(constructor(data)).drop_nulls() expected = { "a": [2.0, 4.0], @@ -30,7 +32,12 @@ def test_drop_nulls(constructor: Constructor) -> None: ], ) def test_drop_nulls_subset( - constructor: Constructor, subset: str | list[str], expected: dict[str, float] + constructor: Constructor, + subset: str | list[str], + expected: dict[str, float], + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) result = nw.from_native(constructor(data)).drop_nulls(subset=subset) assert_equal_data(result, expected) diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py index 631da0255..b79215a18 100644 --- a/tests/frame/explode_test.py +++ b/tests/frame/explode_test.py @@ -40,7 +40,7 @@ def test_explode_single_col( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") ): request.applymarker(pytest.mark.xfail) @@ -89,7 +89,7 @@ def test_explode_multiple_cols( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") ): request.applymarker(pytest.mark.xfail) @@ -110,7 +110,7 @@ def test_explode_shape_error( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") ): request.applymarker(pytest.mark.xfail) @@ -133,7 +133,7 @@ def test_explode_shape_error( def test_explode_invalid_operation_error( request: pytest.FixtureRequest, constructor: Constructor ) -> None: - if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb")): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6): diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index b55ab7767..759d175ca 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -17,7 +17,11 @@ def test_filter(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_filter_with_boolean_list(constructor: Constructor) -> None: +def test_filter_with_boolean_list( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) context = ( diff --git a/tests/frame/gather_every_test.py b/tests/frame/gather_every_test.py index 671737ad1..40e9291de 100644 --- a/tests/frame/gather_every_test.py +++ b/tests/frame/gather_every_test.py @@ -11,7 +11,11 @@ @pytest.mark.parametrize("n", [1, 2, 3]) @pytest.mark.parametrize("offset", [1, 2, 3]) -def test_gather_every(constructor: Constructor, n: int, offset: int) -> None: +def test_gather_every( + constructor: Constructor, n: int, offset: int, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.gather_every(n=n, offset=offset) expected = {"a": data["a"][offset::n]} diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index faeac5b2f..4aa68e571 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -74,7 +74,9 @@ def test_inner_join_single_key(constructor: Constructor) -> None: assert_equal_data(result_on, expected) -def test_cross_join(constructor: Constructor) -> None: +def test_cross_join(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2]} df = nw.from_native(constructor(data)) result = df.join(df, how="cross").sort("antananarivo", "antananarivo_right") # type: ignore[arg-type] @@ -112,7 +114,11 @@ def test_suffix(constructor: Constructor, how: str, suffix: str) -> None: @pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) -def test_cross_join_suffix(constructor: Constructor, suffix: str) -> None: +def test_cross_join_suffix( + constructor: Constructor, suffix: str, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2]} df = nw.from_native(constructor(data)) result = df.join(df, how="cross", suffix=suffix).sort( # type: ignore[arg-type] @@ -159,7 +165,10 @@ def test_anti_join( join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]], + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) other = df.filter(filter_expr) @@ -197,7 +206,10 @@ def test_semi_join( join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]], + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) other = df.filter(filter_expr) @@ -355,7 +367,7 @@ def test_joinasof_numeric( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if "pyarrow_table" in str(constructor) or "cudf" in str(constructor): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb")): request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ( ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) @@ -414,7 +426,7 @@ def test_joinasof_time( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if "pyarrow_table" in str(constructor) or "cudf" in str(constructor): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb")): request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ("pandas_pyarrow" in str(constructor)): request.applymarker(pytest.mark.xfail) @@ -495,7 +507,7 @@ def test_joinasof_by( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if "pyarrow_table" in str(constructor) or "cudf" in str(constructor): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb")): request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ( ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index d85697249..9d601e468 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -27,7 +27,9 @@ def test_select(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_empty_select(constructor: Constructor) -> None: +def test_empty_select(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) result = nw.from_native(constructor({"a": [1, 2, 3]})).lazy().select() assert result.collect().shape == (0, 0) @@ -75,7 +77,11 @@ def test_comparison_with_list_error_message() -> None: nw.from_native(pd.Series([[1, 2, 3]]), series_only=True) == [1, 2, 3] # noqa: B015 -def test_missing_columns(constructor: Constructor) -> None: +def test_missing_columns( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) selected_columns = ["a", "e", "f"] @@ -120,6 +126,8 @@ def test_left_to_right_broadcasting( ) -> None: if "dask" in str(constructor) and DASK_VERSION < (2024, 10): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]})) result = df.select(nw.col("a") + nw.col("b").sum()) expected = {"a": [16, 16, 17]} diff --git a/tests/frame/unique_test.py b/tests/frame/unique_test.py index 96d5a8c2d..ca34d29b4 100644 --- a/tests/frame/unique_test.py +++ b/tests/frame/unique_test.py @@ -5,7 +5,10 @@ import pytest -import narwhals.stable.v1 as nw +# We use nw instead of nw.stable.v1 to ensure that DuckDBPyRelation +# becomes LazyFrame instead of DataFrame +import narwhals as nw +from narwhals.exceptions import ColumnNotFoundError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -31,7 +34,10 @@ def test_unique( ) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) - if isinstance(df, nw.LazyFrame) and keep in {"first", "last"}: + if isinstance(df, nw.LazyFrame) and keep in { + "first", + "last", + }: context: Any = pytest.raises(ValueError, match="row order") elif keep == "foo": context = pytest.raises(ValueError, match=": foo") @@ -43,6 +49,13 @@ def test_unique( assert_equal_data(result, expected) +def test_unique_invalid_subset(constructor: Constructor) -> None: + df_raw = constructor(data) + df = nw.from_native(df_raw) + with pytest.raises(ColumnNotFoundError): + df.lazy().unique(["fdssfad"]).collect() + + @pytest.mark.filterwarnings("ignore:.*backwards-compatibility:UserWarning") def test_unique_none(constructor: Constructor) -> None: df_raw = constructor(data) diff --git a/tests/frame/unpivot_test.py b/tests/frame/unpivot_test.py index ad7eefe5b..2867720a7 100644 --- a/tests/frame/unpivot_test.py +++ b/tests/frame/unpivot_test.py @@ -37,9 +37,7 @@ [("b", expected_b_only), (["b", "c"], expected_b_c), (None, expected_b_c)], ) def test_unpivot_on( - constructor: Constructor, - on: str | list[str] | None, - expected: dict[str, list[float]], + constructor: Constructor, on: str | list[str] | None, expected: dict[str, list[float]] ) -> None: df = nw.from_native(constructor(data)) result = df.unpivot(on=on, index=["a"]).sort("variable", "a") diff --git a/tests/frame/with_columns_test.py b/tests/frame/with_columns_test.py index c05a41646..335c53896 100644 --- a/tests/frame/with_columns_test.py +++ b/tests/frame/with_columns_test.py @@ -52,6 +52,8 @@ def test_with_columns_dtypes_single_row( ) -> None: if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,): request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": ["foo"]} df = nw.from_native(constructor(data)).with_columns(nw.col("a").cast(nw.Categorical)) result = df.with_columns(nw.col("a")) diff --git a/tests/frame/with_row_index_test.py b/tests/frame/with_row_index_test.py index e19d3c994..bc514fa70 100644 --- a/tests/frame/with_row_index_test.py +++ b/tests/frame/with_row_index_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import assert_equal_data @@ -10,7 +12,9 @@ } -def test_with_row_index(constructor: Constructor) -> None: +def test_with_row_index(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) result = nw.from_native(constructor(data)).with_row_index() expected = {"index": [0, 1], "a": ["foo", "bars"], "ab": ["foo", "bars"]} assert_equal_data(result, expected) diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 22c3b6f19..0dd6d8a10 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -115,6 +115,8 @@ def test_group_by_depth_1_agg( expected: dict[str, list[int | float]], request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor) and attr == "n_unique": + request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1): # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations" request.applymarker(pytest.mark.xfail) @@ -134,10 +136,10 @@ def test_group_by_depth_1_agg( ], ) def test_group_by_depth_1_std_var( - constructor: Constructor, - attr: str, - ddof: int, + constructor: Constructor, attr: str, ddof: int, request: pytest.FixtureRequest ) -> None: + if "duckdb" in str(constructor) and ddof == 2: + request.applymarker(pytest.mark.xfail) data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} _pow = 0.5 if attr == "std" else 1 expected = { @@ -164,7 +166,11 @@ def test_group_by_median(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_group_by_n_unique_w_missing(constructor: Constructor) -> None: +def test_group_by_n_unique_w_missing( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} result = ( nw.from_native(constructor(data)) @@ -288,8 +294,10 @@ def test_key_with_nulls( def test_key_with_nulls_ignored( - constructor: Constructor, + constructor: Constructor, request: pytest.FixtureRequest ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"b": [4, 5, None], "a": [1, 2, 3]} result = ( nw.from_native(constructor(data)) @@ -341,6 +349,8 @@ def test_group_by_categorical( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor) and PYARROW_VERSION < ( 15, 0, @@ -366,6 +376,8 @@ def test_group_by_categorical( def test_group_by_shift_raises( constructor: Constructor, request: pytest.FixtureRequest ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "polars" in str(constructor): # Polars supports all kinds of crazy group-by aggregations, so # we don't check that it errors here. @@ -406,6 +418,8 @@ def test_all_kind_of_aggs( # and modin lol https://github.com/modin-project/modin/issues/7414 # and cudf https://github.com/rapidsai/cudf/issues/17649 request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4): # Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']] request.applymarker(pytest.mark.xfail) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 86bdbac53..103ea666d 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -23,28 +23,36 @@ } -def test_selectors(constructor: Constructor) -> None: +def test_selectors(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(by_dtype([nw.Int64, nw.Float64]) + 1) expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]} assert_equal_data(result, expected) -def test_numeric(constructor: Constructor) -> None: +def test_numeric(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(numeric() + 1) expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]} assert_equal_data(result, expected) -def test_boolean(constructor: Constructor) -> None: +def test_boolean(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(boolean()) expected = {"d": [True, False, True]} assert_equal_data(result, expected) -def test_string(constructor: Constructor) -> None: +def test_string(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(string()) expected = {"b": ["a", "b", "c"]} @@ -59,6 +67,8 @@ def test_categorical( 15, ): # pragma: no cover request.applymarker(pytest.mark.xfail) + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) expected = {"b": ["a", "b", "c"]} df = nw.from_native(constructor(data)).with_columns(nw.col("b").cast(nw.Categorical)) @@ -81,15 +91,24 @@ def test_categorical( ], ) def test_set_ops( - constructor: Constructor, selector: nw.selectors.Selector, expected: list[str] + constructor: Constructor, + selector: nw.selectors.Selector, + expected: list[str], + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(selector).collect_schema().names() assert sorted(result) == expected @pytest.mark.parametrize("invalid_constructor", [pd.DataFrame, pa.table]) -def test_set_ops_invalid(invalid_constructor: Constructor) -> None: +def test_set_ops_invalid( + invalid_constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(invalid_constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(invalid_constructor(data)) with pytest.raises((NotImplementedError, ValueError)): df.select(1 - numeric()) diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index fd08f575c..c3d028563 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -13,7 +13,11 @@ from tests.utils import assert_equal_data -def test_renamed_taxicab_norm(constructor: Constructor) -> None: +def test_renamed_taxicab_norm( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "duckdb" in str(constructor): + request.applymarker(pytest.mark.xfail) # Suppose we need to rename `_l1_norm` to `_taxicab_norm`. # We need `narwhals.stable.v1` to stay stable. So, we # make the change in `narwhals`, and then add the new method diff --git a/tests/utils.py b/tests/utils.py index 34f1bfa1e..005b4eee2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,6 +11,7 @@ import pandas as pd +from narwhals.translate import from_native from narwhals.typing import IntoDataFrame from narwhals.typing import IntoFrame from narwhals.utils import Implementation @@ -72,7 +73,12 @@ def assert_equal_data(result: Any, expected: dict[str, Any]) -> None: hasattr(result, "_compliant_frame") and result.implementation is Implementation.PYSPARK ) - + is_duckdb = ( + hasattr(result, "_compliant_frame") + and result._compliant_frame._implementation is Implementation.DUCKDB + ) + if is_duckdb: + result = from_native(result.to_native().arrow()) if hasattr(result, "collect"): if result.implementation is Implementation.POLARS and os.environ.get( "NARWHALS_POLARS_GPU", False diff --git a/tpch/execute.py b/tpch/execute.py index fb5982c10..e19b51dfb 100644 --- a/tpch/execute.py +++ b/tpch/execute.py @@ -13,6 +13,7 @@ pd.options.mode.copy_on_write = True pd.options.future.infer_string = True +pl.Config.set_fmt_float("full") DATA_DIR = Path("data") LINEITEM_PATH = DATA_DIR / "lineitem.parquet" @@ -92,7 +93,7 @@ def execute_query(query_id: str) -> None: print(f"\nRunning {query_id} with {backend=}") # noqa: T201 result = query_module.query( *( - nw.scan_parquet(path, native_namespace=native_namespace, **kwargs) + nw.scan_parquet(str(path), native_namespace=native_namespace, **kwargs) for path in data_paths ) ) diff --git a/utils/import_check.py b/utils/import_check.py index eee35dfc4..bac54aff7 100644 --- a/utils/import_check.py +++ b/utils/import_check.py @@ -23,6 +23,7 @@ "_arrow": {"pyarrow", "pyarrow.compute", "pyarrow.parquet"}, "_dask": {"dask.dataframe", "pandas", "dask_expr"}, "_polars": {"polars"}, + "_duckdb": {"duckdb"}, } @@ -63,6 +64,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802 if ( node.module in BANNED_IMPORTS and "# ignore-banned-import" not in self.lines[node.lineno - 1] + and node.module not in self.allowed_imports ): print( # noqa: T201 f"{self.file_name}:{node.lineno}:{node.col_offset}: found {node.module} import"