Skip to content

Commit

Permalink
ref
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 23, 2025
1 parent 2ba3872 commit a2a7cf0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
2 changes: 1 addition & 1 deletion narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def concat_str(

pl_exprs: list[pl.Expr] = [
expr._native_expr # type: ignore[attr-defined]
for expr in (*parse_into_exprs(*exprs, namespace=self),)
for expr in parse_into_exprs(*exprs, namespace=self)
]

if self._backend_version < (0, 20, 6):
Expand Down
18 changes: 5 additions & 13 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _from_compliant_dataframe(self: Self, df: Any) -> Self:
df, level=self._level
)

def _flatten_parse_col_names_into_expr_and_extract(
def _flatten_and_extract(
self, *exprs: IntoExpr | Any, **named_exprs: IntoExpr | Any
) -> tuple[tuple[IntoCompliantExpr[Any]], dict[str, IntoCompliantExpr[Any]]]:
"""Process `args` and `kwargs`, extracting underlying objects as we go, interpreting strings as column names."""
Expand All @@ -91,14 +91,6 @@ def _flatten_parse_col_names_into_expr_and_extract(
}
return compliant_exprs, compliant_named_exprs

def _flatten_and_extract(
self: Self, *args: Any, **kwargs: Any
) -> tuple[tuple[IntoCompliantExpr[Any]], dict[str, IntoCompliantExpr[Any]]]:
"""Process `args` and `kwargs`, extracting underlying objects as we go."""
args = [self._extract_compliant(v) for v in flatten(args)] # type: ignore[assignment]
kwargs = {k: self._extract_compliant(v) for k, v in kwargs.items()}
return args, kwargs

@abstractmethod
def _extract_compliant(self: Self, arg: Any) -> Any:
raise NotImplementedError
Expand Down Expand Up @@ -138,8 +130,8 @@ def columns(self: Self) -> list[str]:
def with_columns(
self: Self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr
) -> Self:
compliant_exprs, compliant_named_exprs = (
self._flatten_parse_col_names_into_expr_and_extract(*exprs, **named_exprs)
compliant_exprs, compliant_named_exprs = self._flatten_and_extract(
*exprs, **named_exprs
)
return self._from_compliant_dataframe(
self._compliant_frame.with_columns(*compliant_exprs, **compliant_named_exprs),
Expand All @@ -162,8 +154,8 @@ def select(
msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
raise ColumnNotFoundError(msg) from e

compliant_exprs, compliant_named_exprs = (
self._flatten_parse_col_names_into_expr_and_extract(*exprs, **named_exprs)
compliant_exprs, compliant_named_exprs = self._flatten_and_extract(
*exprs, **named_exprs
)
return self._from_compliant_dataframe(
self._compliant_frame.select(*compliant_exprs, **compliant_named_exprs),
Expand Down
13 changes: 8 additions & 5 deletions narwhals/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ def __mod__(self, other: Any) -> Self: ...
def __pow__(self, other: Any) -> Self: ...


IntoCompliantExpr: TypeAlias = (
CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co
)


class CompliantNamespace(Protocol, Generic[CompliantSeriesT_co]):
def col(self, *column_names: str) -> CompliantExpr[CompliantSeriesT_co]: ...
def lit(
Expand Down Expand Up @@ -287,6 +282,14 @@ class DTypes:
Unknown: type[dtypes.Unknown]


if TYPE_CHECKING:
# This one needs to be in TYPE_CHECKING to pass on 3.9,
# and can only be defined after CompliantExpr has been defined
IntoCompliantExpr: TypeAlias = (
CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co
)


__all__ = [
"CompliantDataFrame",
"CompliantLazyFrame",
Expand Down

0 comments on commit a2a7cf0

Please sign in to comment.