From 9b9c1e5c7e493f44b0c6bce664a67e4ab5fc05cf Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Fri, 23 Aug 2024 15:48:54 +0200 Subject: [PATCH] patch: lit broadcast --- narwhals/_arrow/namespace.py | 4 ++-- narwhals/_pandas_like/namespace.py | 12 +++++++++--- tests/frame/lit_test.py | 18 +++++++++++++++++- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index a17a6a6265..aa6abd5c6d 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -147,9 +147,9 @@ def all(self) -> ArrowExpr: ) def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr: - def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: + def _lit_arrow_series(df: ArrowDataFrame) -> ArrowSeries: arrow_series = ArrowSeries._from_iterable( - data=[value], + data=[value] * len(df), name="lit", backend_version=self._backend_version, ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 3c378e3c90..481037a5ba 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -15,10 +15,10 @@ from narwhals._pandas_like.utils import create_native_series from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat +from narwhals.utils import Implementation if TYPE_CHECKING: from narwhals._pandas_like.typing import IntoPandasLikeExpr - from narwhals.utils import Implementation class PandasLikeNamespace: @@ -130,11 +130,17 @@ def all(self) -> PandasLikeExpr: ) def lit(self, value: Any, dtype: dtypes.DType | None) -> PandasLikeExpr: + if self._implementation is Implementation.CUDF: + import cupy as np # ignore-banned-import + else: + import numpy as np # ignore-banned-import + def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: + native_frame = df._native_frame pandas_series = PandasLikeSeries._from_iterable( - data=[value], + data=np.full(native_frame.shape[0], value), name="lit", - index=df._native_frame.index[0:1], + index=native_frame.index, implementation=self._implementation, backend_version=self._backend_version, ) diff --git a/tests/frame/lit_test.py b/tests/frame/lit_test.py index 328e4d8e03..98945c934a 100644 --- a/tests/frame/lit_test.py +++ b/tests/frame/lit_test.py @@ -24,7 +24,7 @@ def test_lit( request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_raw = constructor(data) - df = nw.from_native(df_raw).lazy() + df = nw.from_native(df_raw) result = df.with_columns(nw.lit(2, dtype).alias("lit")) expected = { "a": [1, 3, 2], @@ -35,6 +35,22 @@ def test_lit( compare_dicts(result, expected) +def test_lit_operation(constructor: Any) -> None: + data = {"a": [1, 3, 2]} + df_raw = constructor(data) + df = nw.from_native(df_raw) + + result = df.select( + left_lit=nw.lit(1) + nw.col("a"), + right_lit=nw.col("a") - nw.lit(1), + ) + expected = { + "left_lit": [2, 4, 3], + "right_lit": [0, 2, 1], + } + compare_dicts(result, expected) + + def test_lit_error(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_raw = constructor(data)