Skip to content

Commit

Permalink
patch: stableify concat function (#869)
Browse files Browse the repository at this point in the history
* patch: stableify concat function

* @overload

* test `to_lazy`

* supposed to raise due to type mismatch
  • Loading branch information
FBruzzesi authored Aug 26, 2024
1 parent 570ca78 commit 9dfd6f5
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
25 changes: 24 additions & 1 deletion narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from narwhals.expr import Then as NwThen
from narwhals.expr import When as NwWhen
from narwhals.expr import when as nw_when
from narwhals.functions import concat
from narwhals.functions import show_versions
from narwhals.schema import Schema as NwSchema
from narwhals.series import Series as NwSeries
Expand Down Expand Up @@ -1338,6 +1337,30 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
return _stableify(nw.mean_horizontal(*exprs))


@overload
def concat(
items: Iterable[DataFrame[Any]],
*,
how: Literal["horizontal", "vertical"] = "vertical",
) -> DataFrame[Any]: ...


@overload
def concat(
items: Iterable[LazyFrame[Any]],
*,
how: Literal["horizontal", "vertical"] = "vertical",
) -> LazyFrame[Any]: ...


def concat(
items: Iterable[DataFrame[Any] | LazyFrame[Any]],
*,
how: Literal["horizontal", "vertical"] = "vertical",
) -> DataFrame[Any] | LazyFrame[Any]:
return _stableify(nw.concat(items, how=how)) # type: ignore[no-any-return]


def is_ordered_categorical(series: Series) -> bool:
"""
Return whether indices of categories are semantically meaningful.
Expand Down
6 changes: 3 additions & 3 deletions tests/frame/concat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ def test_concat_horizontal(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_left = nw.from_native(constructor(data))
df_left = nw.from_native(constructor(data)).lazy()

data_right = {"c": [6, 12, -1], "d": [0, -4, 2]}
df_right = nw.from_native(constructor(data_right))
df_right = nw.from_native(constructor(data_right)).lazy()

result = nw.concat([df_left, df_right], how="horizontal")
expected = {
Expand All @@ -34,7 +34,7 @@ def test_concat_vertical(constructor: Any, request: Any) -> None:
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_left = (
nw.from_native(constructor(data)).rename({"a": "c", "b": "d"}).drop("z").lazy()
nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z")
)

data_right = {"c": [6, 12, -1], "d": [0, -4, 2]}
Expand Down
2 changes: 1 addition & 1 deletion tests/frame/test_invalid.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_validate_laziness() -> None:
NotImplementedError,
match=("The items to concatenate should either all be eager, or all lazy"),
):
nw.concat([nw.from_native(df, eager_only=True), nw.from_native(df).lazy()])
nw.concat([nw.from_native(df, eager_only=True), nw.from_native(df).lazy()]) # type: ignore[list-item]


@pytest.mark.skipif(
Expand Down

0 comments on commit 9dfd6f5

Please sign in to comment.