diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 45d2440b8..32e53b372 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -161,7 +161,7 @@ def concat_str( pl_exprs: list[pl.Expr] = [ expr._native_expr # type: ignore[attr-defined] - for expr in (*parse_into_exprs(*exprs, namespace=self),) + for expr in parse_into_exprs(*exprs, namespace=self) ] if self._backend_version < (0, 20, 6): diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index d8c68e5b6..33af048ed 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -74,7 +74,7 @@ def _from_compliant_dataframe(self: Self, df: Any) -> Self: df, level=self._level ) - def _flatten_parse_col_names_into_expr_and_extract( + def _flatten_and_extract( self, *exprs: IntoExpr | Any, **named_exprs: IntoExpr | Any ) -> tuple[tuple[IntoCompliantExpr[Any]], dict[str, IntoCompliantExpr[Any]]]: """Process `args` and `kwargs`, extracting underlying objects as we go, interpreting strings as column names.""" @@ -91,14 +91,6 @@ def _flatten_parse_col_names_into_expr_and_extract( } return compliant_exprs, compliant_named_exprs - def _flatten_and_extract( - self: Self, *args: Any, **kwargs: Any - ) -> tuple[tuple[IntoCompliantExpr[Any]], dict[str, IntoCompliantExpr[Any]]]: - """Process `args` and `kwargs`, extracting underlying objects as we go.""" - args = [self._extract_compliant(v) for v in flatten(args)] # type: ignore[assignment] - kwargs = {k: self._extract_compliant(v) for k, v in kwargs.items()} - return args, kwargs - @abstractmethod def _extract_compliant(self: Self, arg: Any) -> Any: raise NotImplementedError @@ -138,8 +130,8 @@ def columns(self: Self) -> list[str]: def with_columns( self: Self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr ) -> Self: - compliant_exprs, compliant_named_exprs = ( - self._flatten_parse_col_names_into_expr_and_extract(*exprs, **named_exprs) + compliant_exprs, compliant_named_exprs = self._flatten_and_extract( + *exprs, **named_exprs ) return self._from_compliant_dataframe( self._compliant_frame.with_columns(*compliant_exprs, **compliant_named_exprs), @@ -162,8 +154,8 @@ def select( msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?" raise ColumnNotFoundError(msg) from e - compliant_exprs, compliant_named_exprs = ( - self._flatten_parse_col_names_into_expr_and_extract(*exprs, **named_exprs) + compliant_exprs, compliant_named_exprs = self._flatten_and_extract( + *exprs, **named_exprs ) return self._from_compliant_dataframe( self._compliant_frame.select(*compliant_exprs, **compliant_named_exprs), diff --git a/narwhals/typing.py b/narwhals/typing.py index 041a4480a..9f9e13815 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -95,11 +95,6 @@ def __mod__(self, other: Any) -> Self: ... def __pow__(self, other: Any) -> Self: ... -IntoCompliantExpr: TypeAlias = ( - CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co -) - - class CompliantNamespace(Protocol, Generic[CompliantSeriesT_co]): def col(self, *column_names: str) -> CompliantExpr[CompliantSeriesT_co]: ... def lit( @@ -287,6 +282,14 @@ class DTypes: Unknown: type[dtypes.Unknown] +if TYPE_CHECKING: + # This one needs to be in TYPE_CHECKING to pass on 3.9, + # and can only be defined after CompliantExpr has been defined + IntoCompliantExpr: TypeAlias = ( + CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co + ) + + __all__ = [ "CompliantDataFrame", "CompliantLazyFrame",