Skip to content

Commit

Permalink
feat: LazyFrame.collect with backend and **kwargs (#1734)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Feb 2, 2025
1 parent e0f37bf commit 8ca9422
Show file tree
Hide file tree
Showing 14 changed files with 629 additions and 92 deletions.
46 changes: 40 additions & 6 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import scale_bytes
from narwhals.utils import validate_backend_version

Expand Down Expand Up @@ -559,12 +560,45 @@ def lazy(self: Self, *, backend: Implementation | None = None) -> CompliantLazyF
)
raise AssertionError # pragma: no cover

def collect(self: Self) -> ArrowDataFrame:
    """Return a new ``ArrowDataFrame`` wrapping the same native frame.

    Zero-argument form of ``collect``: no conversion is performed, and this
    object's ``_backend_version`` and ``_version`` are forwarded unchanged.
    """
    native_table = self._native_frame
    collected = ArrowDataFrame(
        native_table,
        backend_version=self._backend_version,
        version=self._version,
    )
    return collected
def collect(
    self: Self,
    backend: Implementation | None,
    **kwargs: Any,
) -> CompliantDataFrame:
    """Return this PyArrow-backed frame as an eager frame of ``backend``.

    Args:
        backend: Target eager implementation. ``None`` and
            ``Implementation.PYARROW`` keep the native frame as is,
            ``Implementation.PANDAS`` converts it with ``to_pandas()`` and
            ``Implementation.POLARS`` converts it with ``pl.from_arrow``.
        **kwargs: Accepted for signature parity; not read by this
            implementation.

    Returns:
        An ``ArrowDataFrame``, ``PandasLikeDataFrame`` or ``PolarsDataFrame``
        carrying ``self._version``.

    Raises:
        ValueError: If ``backend`` is none of the values listed above.
    """
    if backend is Implementation.PYARROW or backend is None:
        from narwhals._arrow.dataframe import ArrowDataFrame

        # Same native object and same backend version: nothing to convert.
        return ArrowDataFrame(
            native_dataframe=self._native_frame,
            backend_version=self._backend_version,
            version=self._version,
        )

    if backend is Implementation.PANDAS:
        import pandas as pd  # ignore-banned-import

        from narwhals._pandas_like.dataframe import PandasLikeDataFrame

        return PandasLikeDataFrame(
            native_dataframe=self._native_frame.to_pandas(),
            implementation=Implementation.PANDAS,
            # The wrapper must report the version of the library it now holds.
            backend_version=parse_version(pd.__version__),
            version=self._version,
        )

    if backend is Implementation.POLARS:
        import polars as pl  # ignore-banned-import

        from narwhals._polars.dataframe import PolarsDataFrame

        return PolarsDataFrame(
            df=pl.from_arrow(self._native_frame),  # type: ignore[arg-type]
            backend_version=parse_version(pl.__version__),
            version=self._version,
        )

    # An unsupported argument value is a ValueError, matching the other
    # backends' `collect` implementations, which raise the same message.
    msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
    raise ValueError(msg)  # pragma: no cover

def clone(self: Self) -> Self:
msg = "clone is not yet supported on PyArrow tables"
Expand Down
53 changes: 43 additions & 10 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from narwhals._dask.utils import parse_exprs_and_named_exprs
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals.typing import CompliantDataFrame
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
Expand All @@ -29,7 +30,6 @@
from narwhals._dask.expr import DaskExpr
from narwhals._dask.group_by import DaskLazyGroupBy
from narwhals._dask.namespace import DaskNamespace
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals.dtypes import DType
from narwhals.utils import Version

Expand Down Expand Up @@ -79,16 +79,49 @@ def with_columns(self: Self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self:
df = df.assign(**new_series)
return self._from_native_frame(df)

# Pre-change definition (the lines this diff removes).
def collect(self: Self) -> PandasLikeDataFrame:
    """Compute the Dask frame and wrap the result as a pandas-backed frame."""
    from narwhals._pandas_like.dataframe import PandasLikeDataFrame

    result = self._native_frame.compute()
    return PandasLikeDataFrame(
        result,
        implementation=Implementation.PANDAS,
        # NOTE(review): `pd` is not imported inside this version; presumably
        # it comes from a module-level import outside this view - confirm.
        backend_version=parse_version(pd.__version__),
        version=self._version,
    )

# Post-change definition (the lines this diff adds).
def collect(
    self: Self,
    backend: Implementation | None,
    **kwargs: Any,
) -> CompliantDataFrame:
    """Compute the Dask frame and return it as an eager frame of ``backend``.

    Args:
        backend: Target eager implementation. ``None`` and
            ``Implementation.PANDAS`` wrap the computed result directly,
            ``Implementation.POLARS`` converts it with ``pl.from_pandas`` and
            ``Implementation.PYARROW`` with ``pa.Table.from_pandas``.
        **kwargs: Forwarded verbatim to the native frame's ``compute``.

    Returns:
        A ``PandasLikeDataFrame``, ``PolarsDataFrame`` or ``ArrowDataFrame``
        carrying ``self._version``.

    Raises:
        ValueError: If ``backend`` is none of the values listed above.
    """
    import pandas as pd

    # The computation runs once, before any backend dispatch.
    result = self._native_frame.compute(**kwargs)

    if backend is None or backend is Implementation.PANDAS:
        from narwhals._pandas_like.dataframe import PandasLikeDataFrame

        return PandasLikeDataFrame(
            result,
            implementation=Implementation.PANDAS,
            backend_version=parse_version(pd.__version__),
            version=self._version,
        )

    if backend is Implementation.POLARS:
        import polars as pl  # ignore-banned-import

        from narwhals._polars.dataframe import PolarsDataFrame

        return PolarsDataFrame(
            pl.from_pandas(result),
            backend_version=parse_version(pl.__version__),
            version=self._version,
        )

    if backend is Implementation.PYARROW:
        import pyarrow as pa  # ignore-banned-import

        from narwhals._arrow.dataframe import ArrowDataFrame

        return ArrowDataFrame(
            pa.Table.from_pandas(result),
            backend_version=parse_version(pa.__version__),
            version=self._version,
        )

    msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
    raise ValueError(msg)  # pragma: no cover

@property
def columns(self: Self) -> list[str]:
Expand Down
50 changes: 39 additions & 11 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from narwhals._duckdb.utils import parse_exprs_and_named_exprs
from narwhals.dependencies import get_duckdb
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import CompliantDataFrame
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import generate_temporary_column_name
Expand Down Expand Up @@ -79,20 +80,47 @@ def __getitem__(self: Self, item: str) -> DuckDBInterchangeSeries:
self._native_frame.select(item), version=self._version
)

# Pre-change definition (the lines this diff removes).
# NOTE(review): annotated as returning `pa.Table`, but the body returns an
# `ArrowDataFrame`.
def collect(self: Self) -> pa.Table:
    """Materialise the DuckDB relation as a PyArrow-backed frame."""
    try:
        import pyarrow as pa  # ignore-banned-import
    except ModuleNotFoundError as exc:  # pragma: no cover
        msg = "PyArrow>=11.0.0 is required to collect `LazyFrame` backed by DuckDB"
        raise ModuleNotFoundError(msg) from exc

    from narwhals._arrow.dataframe import ArrowDataFrame

    return ArrowDataFrame(
        native_dataframe=self._native_frame.arrow(),
        backend_version=parse_version(pa.__version__),
        version=self._version,
    )

# Post-change definition (the lines this diff adds).
def collect(
    self: Self,
    backend: ModuleType | Implementation | str | None,
    **kwargs: Any,
) -> CompliantDataFrame:
    """Materialise the DuckDB relation as an eager frame of ``backend``.

    Args:
        backend: Target eager implementation. ``None`` and
            ``Implementation.PYARROW`` use the relation's ``arrow()``,
            ``Implementation.PANDAS`` uses ``df()`` and
            ``Implementation.POLARS`` uses ``pl()``. Only these values are
            handled, although the annotation is wider.
        **kwargs: Accepted for signature parity; not read by this
            implementation.

    Returns:
        An ``ArrowDataFrame``, ``PandasLikeDataFrame`` or ``PolarsDataFrame``
        carrying ``self._version``.

    Raises:
        ValueError: If ``backend`` is none of the handled values.
    """
    if backend is None or backend is Implementation.PYARROW:
        import pyarrow as pa  # ignore-banned-import

        from narwhals._arrow.dataframe import ArrowDataFrame

        return ArrowDataFrame(
            native_dataframe=self._native_frame.arrow(),
            backend_version=parse_version(pa.__version__),
            version=self._version,
        )

    if backend is Implementation.PANDAS:
        import pandas as pd  # ignore-banned-import

        from narwhals._pandas_like.dataframe import PandasLikeDataFrame

        return PandasLikeDataFrame(
            native_dataframe=self._native_frame.df(),
            implementation=Implementation.PANDAS,
            backend_version=parse_version(pd.__version__),
            version=self._version,
        )

    if backend is Implementation.POLARS:
        import polars as pl  # ignore-banned-import

        from narwhals._polars.dataframe import PolarsDataFrame

        return PolarsDataFrame(
            df=self._native_frame.pl(),
            backend_version=parse_version(pl.__version__),
            version=self._version,
        )

    msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
    raise ValueError(msg)  # pragma: no cover

def head(self: Self, n: int) -> Self:
return self._from_native_frame(self._native_frame.limit(n))
Expand Down
55 changes: 48 additions & 7 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import scale_bytes
from narwhals.utils import validate_backend_version

Expand Down Expand Up @@ -501,13 +502,53 @@ def sort(
)

# --- convert ---
def collect(self: Self) -> PandasLikeDataFrame:
    """Return a new ``PandasLikeDataFrame`` wrapping the same native frame.

    Zero-argument form of ``collect``: the implementation, backend version
    and narwhals version of this object are forwarded unchanged.
    """
    native_frame = self._native_frame
    collected = PandasLikeDataFrame(
        native_frame,
        implementation=self._implementation,
        backend_version=self._backend_version,
        version=self._version,
    )
    return collected
def collect(
    self: Self,
    backend: Implementation | None,
    **kwargs: Any,
) -> CompliantDataFrame:
    """Return this pandas-like frame as an eager frame of ``backend``.

    Args:
        backend: Target eager implementation. ``None`` keeps the current
            native frame and implementation; ``Implementation.PANDAS``,
            ``Implementation.PYARROW`` and ``Implementation.POLARS`` convert
            through ``to_pandas()``, ``to_arrow()`` and ``to_polars()``.
        **kwargs: Accepted for signature parity; not read by this
            implementation.

    Returns:
        A ``PandasLikeDataFrame``, ``ArrowDataFrame`` or ``PolarsDataFrame``
        carrying ``self._version``.

    Raises:
        ValueError: If ``backend`` is none of the values listed above.
    """
    if backend is None:
        # No target requested: re-wrap the native frame without converting.
        unchanged = PandasLikeDataFrame(
            self._native_frame,
            implementation=self._implementation,
            backend_version=self._backend_version,
            version=self._version,
        )
        return unchanged
    elif backend is Implementation.PANDAS:
        import pandas as pd  # ignore-banned-import

        pandas_native = self.to_pandas()
        pandas_version = parse_version(pd.__version__)
        return PandasLikeDataFrame(
            pandas_native,
            implementation=Implementation.PANDAS,
            backend_version=pandas_version,
            version=self._version,
        )
    elif backend is Implementation.PYARROW:
        import pyarrow as pa  # ignore-banned-import

        from narwhals._arrow.dataframe import ArrowDataFrame

        arrow_native = self.to_arrow()
        arrow_version = parse_version(pa.__version__)
        return ArrowDataFrame(
            native_dataframe=arrow_native,
            backend_version=arrow_version,
            version=self._version,
        )
    elif backend is Implementation.POLARS:
        import polars as pl  # ignore-banned-import

        from narwhals._polars.dataframe import PolarsDataFrame

        polars_native = self.to_polars()
        polars_version = parse_version(pl.__version__)
        return PolarsDataFrame(
            df=polars_native,
            backend_version=polars_version,
            version=self._version,
        )
    else:
        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

# --- actions ---
def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PandasLikeGroupBy:
Expand Down
49 changes: 42 additions & 7 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from narwhals.utils import Implementation
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import validate_backend_version

if TYPE_CHECKING:
Expand All @@ -29,6 +30,7 @@
from narwhals._polars.group_by import PolarsLazyGroupBy
from narwhals._polars.series import PolarsSeries
from narwhals.dtypes import DType
from narwhals.typing import CompliantDataFrame
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Version

Expand Down Expand Up @@ -440,19 +442,52 @@ def collect_schema(self: Self) -> dict[str, DType]:
for name, dtype in self._native_frame.collect_schema().items()
}

# Pre-change definition (the lines this diff removes).
def collect(self: Self) -> PolarsDataFrame:
    """Collect the native lazy frame and wrap it as a ``PolarsDataFrame``."""
    import polars as pl

    try:
        result = self._native_frame.collect()
    except pl.exceptions.ColumnNotFoundError as e:
        raise ColumnNotFoundError(str(e)) from e

    return PolarsDataFrame(
        result,
        backend_version=self._backend_version,
        version=self._version,
    )

# Post-change definition (the lines this diff adds).
def collect(
    self: Self,
    backend: Implementation | None,
    **kwargs: Any,
) -> CompliantDataFrame:
    """Collect the native lazy frame and return it as a frame of ``backend``.

    Args:
        backend: Target eager implementation. ``None`` and
            ``Implementation.POLARS`` wrap the collected result directly,
            ``Implementation.PANDAS`` converts it with ``to_pandas()`` and
            ``Implementation.PYARROW`` with ``to_arrow()``.
        **kwargs: Forwarded verbatim to the native frame's ``collect``.

    Returns:
        A ``PolarsDataFrame``, ``PandasLikeDataFrame`` or ``ArrowDataFrame``
        carrying ``self._version``.

    Raises:
        ColumnNotFoundError: Re-raised from Polars' own
            ``ColumnNotFoundError``, with the original as the cause.
        ValueError: If ``backend`` is none of the values listed above.
    """
    import polars as pl

    try:
        result = self._native_frame.collect(**kwargs)
    except pl.exceptions.ColumnNotFoundError as e:
        # Translate the Polars exception into the ColumnNotFoundError
        # name in scope in this module.
        raise ColumnNotFoundError(str(e)) from e

    if backend is None or backend is Implementation.POLARS:
        from narwhals._polars.dataframe import PolarsDataFrame

        return PolarsDataFrame(
            result,
            backend_version=self._backend_version,
            version=self._version,
        )

    if backend is Implementation.PANDAS:
        import pandas as pd  # ignore-banned-import

        from narwhals._pandas_like.dataframe import PandasLikeDataFrame

        return PandasLikeDataFrame(
            result.to_pandas(),
            implementation=Implementation.PANDAS,
            backend_version=parse_version(pd.__version__),
            version=self._version,
        )

    if backend is Implementation.PYARROW:
        import pyarrow as pa  # ignore-banned-import

        from narwhals._arrow.dataframe import ArrowDataFrame

        return ArrowDataFrame(
            result.to_arrow(),
            backend_version=parse_version(pa.__version__),
            version=self._version,
        )

    msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
    raise ValueError(msg)  # pragma: no cover

def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsLazyGroupBy:
from narwhals._polars.group_by import PolarsLazyGroupBy
Expand Down
Loading

0 comments on commit 8ca9422

Please sign in to comment.