From 8ca9422f942a60d119aae733a0ce7e9cb714f8c2 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Sun, 2 Feb 2025 14:29:31 +0100 Subject: [PATCH] feat: `LazyFrame.collect` with backend and **kwargs (#1734) --- narwhals/_arrow/dataframe.py | 46 +++++- narwhals/_dask/dataframe.py | 53 +++++-- narwhals/_duckdb/dataframe.py | 50 ++++-- narwhals/_pandas_like/dataframe.py | 55 ++++++- narwhals/_polars/dataframe.py | 49 +++++- narwhals/_spark_like/dataframe.py | 75 +++++++-- narwhals/_spark_like/expr_str.py | 19 ++- narwhals/dataframe.py | 145 +++++++++++++++--- narwhals/stable/v1/__init__.py | 34 +++- narwhals/utils.py | 45 ++++++ pyproject.toml | 3 +- tests/expr_and_series/str/to_datetime_test.py | 30 +++- tests/frame/collect_test.py | 105 +++++++++++++ tests/utils.py | 12 +- 14 files changed, 629 insertions(+), 92 deletions(-) create mode 100644 tests/frame/collect_test.py diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index eb72fa3f0..7b0cac0a2 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -24,6 +24,7 @@ from narwhals.utils import generate_temporary_column_name from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop +from narwhals.utils import parse_version from narwhals.utils import scale_bytes from narwhals.utils import validate_backend_version @@ -559,12 +560,45 @@ def lazy(self: Self, *, backend: Implementation | None = None) -> CompliantLazyF ) raise AssertionError # pragma: no cover - def collect(self: Self) -> ArrowDataFrame: - return ArrowDataFrame( - self._native_frame, - backend_version=self._backend_version, - version=self._version, - ) + def collect( + self: Self, + backend: Implementation | None, + **kwargs: Any, + ) -> CompliantDataFrame: + if backend is Implementation.PYARROW or backend is None: + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + 
native_dataframe=self._native_frame, + backend_version=self._backend_version, + version=self._version, + ) + + if backend is Implementation.PANDAS: + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return PandasLikeDataFrame( + native_dataframe=self._native_frame.to_pandas(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd.__version__), + version=self._version, + ) + + if backend is Implementation.POLARS: + import polars as pl # ignore-banned-import + + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame( + df=pl.from_arrow(self._native_frame), # type: ignore[arg-type] + backend_version=parse_version(pl.__version__), + version=self._version, + ) + + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise AssertionError(msg) # pragma: no cover def clone(self: Self) -> Self: msg = "clone is not yet supported on PyArrow tables" diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 3213a0eaa..66206e766 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -13,6 +13,7 @@ from narwhals._dask.utils import parse_exprs_and_named_exprs from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._pandas_like.utils import select_columns_by_name +from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation from narwhals.utils import check_column_exists @@ -29,7 +30,6 @@ from narwhals._dask.expr import DaskExpr from narwhals._dask.group_by import DaskLazyGroupBy from narwhals._dask.namespace import DaskNamespace - from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals.dtypes import DType from narwhals.utils import Version @@ -79,16 +79,49 @@ def with_columns(self: Self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: df = df.assign(**new_series) return 
self._from_native_frame(df) - def collect(self: Self) -> PandasLikeDataFrame: - from narwhals._pandas_like.dataframe import PandasLikeDataFrame + def collect( + self: Self, + backend: Implementation | None, + **kwargs: Any, + ) -> CompliantDataFrame: + import pandas as pd - result = self._native_frame.compute() - return PandasLikeDataFrame( - result, - implementation=Implementation.PANDAS, - backend_version=parse_version(pd.__version__), - version=self._version, - ) + result = self._native_frame.compute(**kwargs) + + if backend is None or backend is Implementation.PANDAS: + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return PandasLikeDataFrame( + result, + implementation=Implementation.PANDAS, + backend_version=parse_version(pd.__version__), + version=self._version, + ) + + if backend is Implementation.POLARS: + import polars as pl # ignore-banned-import + + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame( + pl.from_pandas(result), + backend_version=parse_version(pl.__version__), + version=self._version, + ) + + if backend is Implementation.PYARROW: + import pyarrow as pa # ignore-banned-import + + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + pa.Table.from_pandas(result), + backend_version=parse_version(pa.__version__), + version=self._version, + ) + + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover @property def columns(self: Self) -> list[str]: diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 6e3699b06..59bcb757f 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -14,6 +14,7 @@ from narwhals._duckdb.utils import parse_exprs_and_named_exprs from narwhals.dependencies import get_duckdb from narwhals.exceptions import ColumnNotFoundError +from narwhals.typing import CompliantDataFrame from narwhals.utils import Implementation from narwhals.utils 
import Version from narwhals.utils import generate_temporary_column_name @@ -79,20 +80,47 @@ def __getitem__(self: Self, item: str) -> DuckDBInterchangeSeries: self._native_frame.select(item), version=self._version ) - def collect(self: Self) -> pa.Table: - try: + def collect( + self: Self, + backend: ModuleType | Implementation | str | None, + **kwargs: Any, + ) -> CompliantDataFrame: + if backend is None or backend is Implementation.PYARROW: import pyarrow as pa # ignore-banned-import - except ModuleNotFoundError as exc: # pragma: no cover - msg = "PyArrow>=11.0.0 is required to collect `LazyFrame` backed by DuckDcollect `LazyFrame` backed by DuckDB" - raise ModuleNotFoundError(msg) from exc - from narwhals._arrow.dataframe import ArrowDataFrame + from narwhals._arrow.dataframe import ArrowDataFrame - return ArrowDataFrame( - native_dataframe=self._native_frame.arrow(), - backend_version=parse_version(pa.__version__), - version=self._version, - ) + return ArrowDataFrame( + native_dataframe=self._native_frame.arrow(), + backend_version=parse_version(pa.__version__), + version=self._version, + ) + + if backend is Implementation.PANDAS: + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return PandasLikeDataFrame( + native_dataframe=self._native_frame.df(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd.__version__), + version=self._version, + ) + + if backend is Implementation.POLARS: + import polars as pl # ignore-banned-import + + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame( + df=self._native_frame.pl(), + backend_version=parse_version(pl.__version__), + version=self._version, + ) + + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover def head(self: Self, n: int) -> Self: return self._from_native_frame(self._native_frame.limit(n)) diff --git 
a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 8972deb70..760df69f3 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -25,6 +25,7 @@ from narwhals.utils import import_dtypes_module from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop +from narwhals.utils import parse_version from narwhals.utils import scale_bytes from narwhals.utils import validate_backend_version @@ -501,13 +502,53 @@ def sort( ) # --- convert --- - def collect(self: Self) -> PandasLikeDataFrame: - return PandasLikeDataFrame( - self._native_frame, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) + def collect( + self: Self, + backend: Implementation | None, + **kwargs: Any, + ) -> CompliantDataFrame: + if backend is None: + return PandasLikeDataFrame( + self._native_frame, + implementation=self._implementation, + backend_version=self._backend_version, + version=self._version, + ) + + if backend is Implementation.PANDAS: + import pandas as pd # ignore-banned-import + + return PandasLikeDataFrame( + self.to_pandas(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd.__version__), + version=self._version, + ) + + if backend is Implementation.PYARROW: + import pyarrow as pa # ignore-banned-import + + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + native_dataframe=self.to_arrow(), + backend_version=parse_version(pa.__version__), + version=self._version, + ) + + if backend is Implementation.POLARS: + import polars as pl # ignore-banned-import + + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame( + df=self.to_polars(), + backend_version=parse_version(pl.__version__), + version=self._version, + ) + + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover # --- actions --- def 
group_by(self: Self, *keys: str, drop_null_keys: bool) -> PandasLikeGroupBy: diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index b01894b39..7a8cbfdce 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -16,6 +16,7 @@ from narwhals.utils import Implementation from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop +from narwhals.utils import parse_version from narwhals.utils import validate_backend_version if TYPE_CHECKING: @@ -29,6 +30,7 @@ from narwhals._polars.group_by import PolarsLazyGroupBy from narwhals._polars.series import PolarsSeries from narwhals.dtypes import DType + from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.utils import Version @@ -440,19 +442,52 @@ def collect_schema(self: Self) -> dict[str, DType]: for name, dtype in self._native_frame.collect_schema().items() } - def collect(self: Self) -> PolarsDataFrame: + def collect( + self: Self, + backend: Implementation | None, + **kwargs: Any, + ) -> CompliantDataFrame: import polars as pl try: - result = self._native_frame.collect() + result = self._native_frame.collect(**kwargs) except pl.exceptions.ColumnNotFoundError as e: raise ColumnNotFoundError(str(e)) from e - return PolarsDataFrame( - result, - backend_version=self._backend_version, - version=self._version, - ) + if backend is None or backend is Implementation.POLARS: + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame( + result, + backend_version=self._backend_version, + version=self._version, + ) + + if backend is Implementation.PANDAS: + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return PandasLikeDataFrame( + result.to_pandas(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd.__version__), + version=self._version, + ) + + if backend is 
Implementation.PYARROW: + import pyarrow as pa # ignore-banned-import + + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + result.to_arrow(), + backend_version=parse_version(pa.__version__), + version=self._version, + ) + + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsLazyGroupBy: from narwhals._polars.group_by import PolarsLazyGroupBy diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 0661f3cfb..d4c834b86 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -10,6 +10,7 @@ from narwhals._spark_like.utils import native_to_narwhals_dtype from narwhals._spark_like.utils import parse_exprs_and_named_exprs from narwhals.exceptions import InvalidOperationError +from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation from narwhals.utils import check_column_exists @@ -24,7 +25,6 @@ from pyspark.sql import DataFrame from typing_extensions import Self - from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._spark_like.expr import SparkLikeExpr from narwhals._spark_like.group_by import SparkLikeLazyGroupBy from narwhals._spark_like.namespace import SparkLikeNamespace @@ -112,17 +112,72 @@ def _from_native_frame(self: Self, df: DataFrame) -> Self: def columns(self: Self) -> list[str]: return self._native_frame.columns # type: ignore[no-any-return] - def collect(self: Self) -> PandasLikeDataFrame: - import pandas as pd # ignore-banned-import() + def collect( + self: Self, + backend: ModuleType | Implementation | str | None, + **kwargs: Any, + ) -> CompliantDataFrame: + if backend is Implementation.PANDAS: + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + + return 
    def collect(
        self: Self,
        backend: ModuleType | Implementation | str | None,
        **kwargs: Any,
    ) -> CompliantDataFrame:
        """Collect the PySpark DataFrame into an eager backend.

        Arguments:
            backend: Eager backend to collect to. `None` (the default)
                collects to a PyArrow table.
            kwargs: Accepted for signature symmetry with the other lazy
                backends; not used here.

        Returns:
            A compliant eager DataFrame wrapper (pandas, PyArrow or Polars).

        Raises:
            ValueError: If `backend` is not one of the supported values.
        """
        if backend is Implementation.PANDAS:
            import pandas as pd  # ignore-banned-import

            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                native_dataframe=self._native_frame.toPandas(),
                implementation=Implementation.PANDAS,
                backend_version=parse_version(pd.__version__),
                version=self._version,
            )

        elif backend is None or backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            # NOTE(review): `_collect_as_arrow` is a private PySpark API --
            # presumably used here to avoid a pandas round-trip; confirm it is
            # stable across supported PySpark versions.
            try:
                native_pyarrow_frame = pa.Table.from_batches(
                    self._native_frame._collect_as_arrow()
                )
            except ValueError as exc:
                if "at least one RecordBatch" in str(exc):
                    # Empty dataframe: `Table.from_batches` needs at least one
                    # batch, so build an empty table carrying the schema
                    # translated from the narwhals dtypes instead.
                    from narwhals._arrow.utils import narwhals_to_native_dtype

                    data: dict[str, list[Any]] = {}
                    schema = []
                    current_schema = self.collect_schema()
                    for key, value in current_schema.items():
                        data[key] = []
                        schema.append(
                            (key, narwhals_to_native_dtype(value, self._version))
                        )
                    native_pyarrow_frame = pa.Table.from_pydict(
                        data, schema=pa.schema(schema)
                    )
                else:  # pragma: no cover
                    raise
            return ArrowDataFrame(
                native_pyarrow_frame,
                backend_version=parse_version(pa.__version__),
                version=self._version,
            )

        elif backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import
            import pyarrow as pa  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            # Route through Arrow batches rather than pandas for the Polars
            # conversion as well.
            return PolarsDataFrame(
                df=pl.from_arrow(  # type: ignore[arg-type]
                    pa.Table.from_batches(self._native_frame._collect_as_arrow())
                ),
                backend_version=parse_version(pl.__version__),
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover
    def to_datetime(self: Self, format: str | None) -> SparkLikeExpr:  # noqa: A002
        """Parse a string column into timestamps, mirroring Polars semantics.

        Arguments:
            format: Optional strptime-style format. `None` lets Spark infer.

        Returns:
            A SparkLikeExpr producing a timestamp column.
        """
        # Treat the parse as timezone-naive when an explicit format is given
        # and it contains no timezone/epoch markers ("%s", "%z", literal "Z").
        is_naive = (
            format is not None
            and "%s" not in format
            and "%z" not in format
            and "Z" not in format
        )
        # Naive parses use `to_timestamp_ntz` (no-timezone variant); otherwise
        # fall back to `to_timestamp`, which yields session-timezone-aware
        # timestamps.
        function = (
            self._compliant_expr._F.to_timestamp_ntz
            if is_naive
            else self._compliant_expr._F.to_timestamp
        )
        pyspark_format = strptime_to_pyspark_format(format)
        # NOTE(review): asymmetric rebinding -- `to_timestamp_ntz` expects the
        # format as a Column (hence `lit(...)`), while `to_timestamp` takes a
        # plain string. After this line `format` is a Column or a str.
        format = (
            self._compliant_expr._F.lit(pyspark_format) if is_naive else pyspark_format
        )
        return self._compliant_expr._from_call(
            # Normalize ISO "T" separators to spaces before parsing, since
            # Spark's parser expects a space between date and time.
            lambda _input: function(
                self._compliant_expr._F.replace(
                    _input,
                    self._compliant_expr._F.lit("T"),
                    self._compliant_expr._F.lit(" "),
                ),
                format=format,
            ),
            "to_datetime",
            expr_kind=self._compliant_expr._expr_kind,
        )
This will be the underlying + backend for the resulting Narwhals DataFrame. If None, then the following + default conversions will be applied: + + - `polars.LazyFrame` -> `polars.DataFrame` + - `dask.DataFrame` -> `pandas.DataFrame` + - `duckdb.PyRelation` -> `pyarrow.Table` + - `pyspark.DataFrame` -> `pyarrow.Table` + + `backend` can be specified in various ways: + + - As `Implementation.` with `BACKEND` being `PANDAS`, `PYARROW` + or `POLARS`. + - As a string: `"pandas"`, `"pyarrow"` or `"polars"` + - Directly as a module `pandas`, `pyarrow` or `polars`. + kwargs: backend specific kwargs to pass along. To know more please check the + backend specific documentation: + + - [polars.LazyFrame.collect](https://docs.pola.rs/api/python/dev/reference/lazyframe/api/polars.LazyFrame.collect.html) + - [dask.dataframe.DataFrame.compute](https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html) + Returns: DataFrame Examples: - >>> import narwhals as nw >>> import polars as pl >>> import dask.dataframe as dd + >>> import narwhals as nw + >>> from narwhals.typing import IntoDataFrame, IntoFrame + >>> >>> data = { ... "a": ["a", "b", "a", "b", "b", "c"], ... 
"b": [1, 2, 3, 4, 5, 6], @@ -3826,28 +3858,14 @@ def collect(self: Self) -> DataFrame[Any]: >>> lf_pl = pl.LazyFrame(data) >>> lf_dask = dd.from_dict(data, npartitions=2) - >>> lf = nw.from_native(lf_pl) - >>> lf # doctest:+ELLIPSIS + >>> nw.from_native(lf_pl) # doctest:+ELLIPSIS ┌─────────────────────────────┐ | Narwhals LazyFrame | |-----------------------------| |>> df = lf.group_by("a").agg(nw.all().sum()).collect() - >>> df.to_native().sort("a") - shape: (3, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ a ┆ 4 ┆ 10 │ - │ b ┆ 11 ┆ 10 │ - │ c ┆ 6 ┆ 1 │ - └─────┴─────┴─────┘ - >>> lf = nw.from_native(lf_dask) - >>> lf + >>> nw.from_native(lf_dask) ┌───────────────────────────────────┐ | Narwhals LazyFrame | |-----------------------------------| @@ -3860,15 +3878,96 @@ def collect(self: Self) -> DataFrame[Any]: |Dask Name: frompandas, 1 expression| |Expr=df | └───────────────────────────────────┘ - >>> df = lf.group_by("a").agg(nw.col("b", "c").sum()).collect() - >>> df.to_native() + + Let's define a dataframe-agnostic that does some grouping computation and + finally collects to a DataFrame: + + >>> def agnostic_group_by_and_collect(lf_native: IntoFrame) -> IntoDataFrame: + ... lf = nw.from_native(lf_native) + ... return ( + ... lf.group_by("a") + ... .agg(nw.col("b", "c").sum()) + ... .sort("a") + ... .collect() + ... .to_native() + ... ) + + We can then pass any supported library such as Polars or Dask + to `agnostic_group_by_and_collect`: + + >>> agnostic_group_by_and_collect(lf_pl) + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ a ┆ 4 ┆ 10 │ + │ b ┆ 11 ┆ 10 │ + │ c ┆ 6 ┆ 1 │ + └─────┴─────┴─────┘ + + >>> agnostic_group_by_and_collect(lf_dask) a b c 0 a 4 10 1 b 11 10 2 c 6 1 + + Now, let's suppose that we want to run lazily, yet without + query optimization (e.g. for debugging purpose) _and_ collect to pyarrow. 
+ As this is achieved differently in polars and dask, to keep a unified workflow + we can specify the native kwargs for each lazy backend, and specify + backend="pyarrow" in order to collect to pyarrow instead of their default. + + >>> collect_kwargs = { + ... nw.Implementation.POLARS: {"no_optimization": True}, + ... nw.Implementation.DASK: {"optimize_graph": False}, + ... nw.Implementation.PYARROW: {}, + ... } + + >>> def agnostic_collect_no_opt(lf_native: IntoFrame) -> IntoDataFrame: + ... lf = nw.from_native(lf_native) + ... return ( + ... lf.group_by("a") + ... .agg(nw.col("b", "c").sum()) + ... .sort("a") + ... .collect( + ... backend="pyarrow", **collect_kwargs.get(lf.implementation, {}) + ... ) + ... .to_native() + ... ) + + >>> agnostic_collect_no_opt(lf_pl) + pyarrow.Table + a: large_string + b: int64 + c: int64 + ---- + a: [["a","b","c"]] + b: [[4,11,6]] + c: [[10,10,1]] + + >>> agnostic_collect_no_opt(lf_dask) + pyarrow.Table + a: large_string + b: int64 + c: int64 + ---- + a: [["a","b","c"]] + b: [[4,11,6]] + c: [[10,10,1]] """ + eager_backend = None if backend is None else Implementation.from_backend(backend) + supported_eager_backends = ( + Implementation.POLARS, + Implementation.PANDAS, + Implementation.PYARROW, + ) + if eager_backend is not None and eager_backend not in supported_eager_backends: + msg = f"Unsupported `backend` value.\nExpected one of {supported_eager_backends} or None, got: {eager_backend}." + raise ValueError(msg) return self._dataframe( - self._compliant_frame.collect(), + self._compliant_frame.collect(backend=eager_backend, **kwargs), level="full", ) @@ -5285,9 +5384,9 @@ def clone(self: Self) -> Self: return super().clone() def lazy(self: Self) -> Self: - """Lazify the DataFrame (if possible). + """Restrict available API methods to lazy-only ones. - If a library does not support lazy execution, then this is a no-op. + This is a no-op, and exists only for compatibility with `DataFrame.lazy`. Returns: A LazyFrame. 
    def collect(
        self: Self,
        backend: ModuleType | Implementation | str | None = None,
        **kwargs: Any,
    ) -> DataFrame[Any]:
        r"""Materialize this LazyFrame into a DataFrame.

        As each underlying lazyframe has different arguments to set when materializing
        the lazyframe into a dataframe, we allow to pass them as kwargs (see examples
        in `narwhals.LazyFrame.collect` for how to generalize the specification).

        Arguments:
            backend: specifies which eager backend to collect to. This will be the
                underlying backend for the resulting Narwhals DataFrame. If None, then
                the following default conversions will be applied:

                - `polars.LazyFrame` -> `polars.DataFrame`
                - `dask.DataFrame` -> `pandas.DataFrame`
                - `duckdb.PyRelation` -> `pyarrow.Table`
                - `pyspark.DataFrame` -> `pyarrow.Table`

                `backend` can be specified in various ways:

                - As the `Implementation` enum member, one of `PANDAS`, `PYARROW`
                  or `POLARS`.
                - As a string: `"pandas"`, `"pyarrow"` or `"polars"`
                - Directly as a module `pandas`, `pyarrow` or `polars`.
            kwargs: backend specific kwargs to pass along. To know more please check the
                backend specific documentation:

                - [polars.LazyFrame.collect](https://docs.pola.rs/api/python/dev/reference/lazyframe/api/polars.LazyFrame.collect.html)
                - [dask.dataframe.DataFrame.compute](https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html)

        Returns:
            DataFrame
        """
        # Delegate to the main implementation; only the return type is narrowed
        # to the stable-v1 DataFrame, hence the ignore.
        return super().collect(backend=backend, **kwargs)  # type: ignore[return-value]
    @classmethod
    def from_string(
        cls: type[Self], backend_name: str
    ) -> Implementation:  # pragma: no cover
        """Instantiate an Implementation member from a backend name string.

        Arguments:
            backend_name: Name of backend, expressed as string.

        Returns:
            Implementation: the matching member, or `Implementation.UNKNOWN`
                when the name is not recognized.
        """
        mapping = {
            "pandas": Implementation.PANDAS,
            "modin": Implementation.MODIN,
            "cudf": Implementation.CUDF,
            "pyarrow": Implementation.PYARROW,
            "pyspark": Implementation.PYSPARK,
            "polars": Implementation.POLARS,
            "dask": Implementation.DASK,
            "duckdb": Implementation.DUCKDB,
            "ibis": Implementation.IBIS,
        }
        return mapping.get(backend_name, Implementation.UNKNOWN)

    @classmethod
    def from_backend(
        cls: type[Self], backend: str | Implementation | ModuleType
    ) -> Implementation:
        """Instantiate from native namespace module, string, or Implementation.

        Arguments:
            backend: Backend to instantiate Implementation from. May be a
                backend name string, an Implementation member (returned
                unchanged), or a native namespace module.

        Returns:
            Implementation.
        """
        # Dispatch on the argument's type; module input falls through to the
        # native-namespace lookup.
        return (
            cls.from_string(backend)
            if isinstance(backend, str)
            else backend
            if isinstance(backend, Implementation)
            else cls.from_native_namespace(backend)
        )
+ or ("pyspark" in str(constructor) and data["a"][0] == "20240101123456") + ): request.applymarker(pytest.mark.xfail) + if "cudf" in str(constructor): expected = expected_cudf - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) - if "pyspark" in str(constructor) and data["a"][0] == "20240101123456": - request.applymarker(pytest.mark.xfail) + elif "pyspark" in str(constructor): + expected = expected_pyspark + result = ( nw.from_native(constructor(data)) .lazy() @@ -138,7 +147,14 @@ def test_to_datetime_infer_fmt_from_date( if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"z": ["2020-01-01", "2020-01-02", None]} - expected = [datetime(2020, 1, 1), datetime(2020, 1, 2), None] + if "pyspark" in str(constructor): + expected = [ + datetime(2020, 1, 1, tzinfo=timezone.utc), + datetime(2020, 1, 2, tzinfo=timezone.utc), + None, + ] + else: + expected = [datetime(2020, 1, 1), datetime(2020, 1, 2), None] result = ( nw.from_native(constructor(data)).lazy().select(nw.col("z").str.to_datetime()) ) diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py new file mode 100644 index 000000000..c357d4adf --- /dev/null +++ b/tests/frame/collect_test.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pandas as pd +import polars as pl +import pyarrow as pa +import pytest + +import narwhals as nw +import narwhals.stable.v1 as nw_v1 +from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_modin +from narwhals.utils import Implementation +from tests.utils import PANDAS_VERSION +from tests.utils import Constructor +from tests.utils import assert_equal_data + +if TYPE_CHECKING: + from types import ModuleType + +if PANDAS_VERSION < (1,): # pragma: no cover + pytest.skip(allow_module_level=True) + + +data = {"a": [1, 2], "b": [3, 4]} + + +def test_collect_to_default_backend(constructor: Constructor) -> None: + df = nw.from_native(constructor(data)) 
+ result = df.lazy().collect().to_native() + + if "polars" in str(constructor): + expected_cls = pl.DataFrame + elif any(x in str(constructor) for x in ("pandas", "dask")): + expected_cls = pd.DataFrame + elif "modin" in str(constructor): + mpd = get_modin() + expected_cls = mpd.DataFrame + elif "cudf" in str(constructor): + cudf = get_cudf() + expected_cls = cudf.DataFrame + else: # pyarrow, duckdb, and PySpark + expected_cls = pa.Table + + assert isinstance(result, expected_cls) + + +@pytest.mark.filterwarnings( + "ignore:is_sparse is deprecated and will be removed in a future version." +) +@pytest.mark.parametrize( + ("backend", "expected_cls"), + [ + ("pyarrow", pa.Table), + ("polars", pl.DataFrame), + ("pandas", pd.DataFrame), + (Implementation.PYARROW, pa.Table), + (Implementation.POLARS, pl.DataFrame), + (Implementation.PANDAS, pd.DataFrame), + (pa, pa.Table), + (pl, pl.DataFrame), + (pd, pd.DataFrame), + ], +) +def test_collect_to_valid_backend( + constructor: Constructor, + backend: ModuleType | Implementation | str | None, + expected_cls: type, +) -> None: + df = nw.from_native(constructor(data)) + result = df.lazy().collect(backend=backend).to_native() + assert isinstance(result, expected_cls) + + +@pytest.mark.parametrize( + "backend", ["foo", Implementation.DASK, Implementation.MODIN, pytest] +) +def test_collect_to_invalid_backend( + constructor: Constructor, + backend: ModuleType | Implementation | str | None, +) -> None: + df = nw.from_native(constructor(data)) + + with pytest.raises(ValueError, match="Unsupported `backend` value"): + df.lazy().collect(backend=backend).to_native() + + +def test_collect_with_kwargs(constructor: Constructor) -> None: + collect_kwargs = { + nw.Implementation.POLARS: {"no_optimization": True}, + nw.Implementation.DASK: {"optimize_graph": False}, + nw.Implementation.PYARROW: {}, + } + + df = nw_v1.from_native(constructor(data)) + + result = ( + df.lazy() + .select(nw_v1.col("a", "b").sum()) + 
.collect(**collect_kwargs.get(df.implementation, {})) # type: ignore[arg-type] + ) + + expected = {"a": [3], "b": [7]} + assert_equal_data(result, expected) diff --git a/tests/utils.py b/tests/utils.py index 7174fbb9e..59fe42eb3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -88,12 +88,12 @@ def assert_equal_data(result: Any, expected: dict[str, Any]) -> None: if is_duckdb: result = from_native(result.to_native().arrow()) if hasattr(result, "collect"): - if result.implementation is Implementation.POLARS and os.environ.get( - "NARWHALS_POLARS_GPU", False - ): # pragma: no cover - result = result.to_native().collect(engine="gpu") - else: - result = result.collect() + kwargs = { + Implementation.POLARS: ( + {"engine": "gpu"} if os.environ.get("NARWHALS_POLARS_GPU", False) else {} + ) # pragma: no cover + } + result = result.collect(**kwargs.get(result.implementation, {})) if hasattr(result, "columns"): for idx, (col, key) in enumerate(zip(result.columns, expected.keys())):