Skip to content

Commit 01ef676

Browse files
committed
better typing
1 parent e877de8 commit 01ef676

File tree

6 files changed

+35
-37
lines changed

6 files changed

+35
-37
lines changed

narwhals/_expression_parsing.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from narwhals.utils import Implementation
1919

2020
if TYPE_CHECKING:
21-
from typing_extensions import TypeAlias
22-
2321
from narwhals._arrow.expr import ArrowExpr
2422
from narwhals._pandas_like.expr import PandasLikeExpr
2523
from narwhals.typing import CompliantDataFrame
@@ -28,13 +26,9 @@
2826
from narwhals.typing import CompliantNamespace
2927
from narwhals.typing import CompliantSeries
3028
from narwhals.typing import CompliantSeriesT_co
29+
from narwhals.typing import IntoCompliantExpr
3130
from narwhals.typing import IntoExpr
3231

33-
IntoCompliantExpr: TypeAlias = (
34-
CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co
35-
)
36-
CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr[Any])
37-
3832
ArrowOrPandasLikeExpr = TypeVar(
3933
"ArrowOrPandasLikeExpr", bound=Union[ArrowExpr, PandasLikeExpr]
4034
)

narwhals/_polars/dataframe.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,7 @@ def __getitem__(self: Self, item: Any) -> Any:
214214
return self._from_native_object(result)
215215

216216
def simple_select(self, *column_names: str) -> Self:
217-
try:
218-
return self._from_native_frame(self._native_frame.select(*column_names))
219-
except Exception as e:
220-
msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
221-
raise ColumnNotFoundError(msg) from e
217+
return self._from_native_frame(self._native_frame.select(*column_names))
222218

223219
def get_column(self: Self, name: str) -> PolarsSeries:
224220
from narwhals._polars.series import PolarsSeries
@@ -484,7 +480,4 @@ def unpivot(
484480
)
485481

486482
def simple_select(self, *column_names: str) -> Self:
487-
try:
488-
return self._from_native_frame(self._native_frame.select(*column_names))
489-
except Exception as e:
490-
raise ColumnNotFoundError(str(e)) from e
483+
return self._from_native_frame(self._native_frame.select(*column_names))

narwhals/dataframe.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from narwhals.group_by import GroupBy
4545
from narwhals.group_by import LazyGroupBy
4646
from narwhals.series import Series
47+
from narwhals.typing import IntoCompliantExpr
4748
from narwhals.typing import IntoDataFrame
4849
from narwhals.typing import IntoExpr
4950
from narwhals.typing import IntoFrame
@@ -70,28 +71,29 @@ def __narwhals_namespace__(self: Self) -> Any:
7071
def _from_compliant_dataframe(self: Self, df: Any) -> Self:
7172
# construct, preserving properties
7273
return self.__class__( # type: ignore[call-arg]
73-
df,
74-
level=self._level,
74+
df, level=self._level
7575
)
7676

7777
def _flatten_parse_col_names_into_expr_and_extract(
7878
self, *exprs: IntoExpr | Any, **named_exprs: IntoExpr | Any
79-
) -> Any:
79+
) -> tuple[tuple[IntoCompliantExpr[Any]], dict[str, IntoCompliantExpr[Any]]]:
8080
"""Process `args` and `kwargs`, extracting underlying objects as we go, interpreting strings as column names."""
8181
plx = self.__narwhals_namespace__()
82-
exprs = tuple(
82+
compliant_exprs = tuple(
8383
plx.col(expr) if isinstance(expr, str) else self._extract_compliant(expr)
8484
for expr in flatten(exprs)
8585
)
86-
named_exprs = {
86+
compliant_named_exprs = {
8787
key: plx.col(value)
8888
if isinstance(value, str)
8989
else self._extract_compliant(value)
9090
for key, value in named_exprs.items()
9191
}
92-
return exprs, named_exprs
92+
return compliant_exprs, compliant_named_exprs
9393

94-
def _flatten_and_extract(self: Self, *args: Any, **kwargs: Any) -> Any:
94+
def _flatten_and_extract(
95+
self: Self, *args: Any, **kwargs: Any
96+
) -> tuple[tuple[IntoCompliantExpr[Any]], dict[str, IntoCompliantExpr[Any]]]:
9597
"""Process `args` and `kwargs`, extracting underlying objects as we go."""
9698
args = [self._extract_compliant(v) for v in flatten(args)] # type: ignore[assignment]
9799
kwargs = {k: self._extract_compliant(v) for k, v in kwargs.items()}
@@ -136,35 +138,35 @@ def columns(self: Self) -> list[str]:
136138
def with_columns(
137139
self: Self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr
138140
) -> Self:
139-
exprs, named_exprs = self._flatten_parse_col_names_into_expr_and_extract(
140-
*exprs, **named_exprs
141+
compliant_exprs, compliant_named_exprs = (
142+
self._flatten_parse_col_names_into_expr_and_extract(*exprs, **named_exprs)
141143
)
142144
return self._from_compliant_dataframe(
143-
self._compliant_frame.with_columns(*exprs, **named_exprs),
145+
self._compliant_frame.with_columns(*compliant_exprs, **compliant_named_exprs),
144146
)
145147

146148
def select(
147149
self: Self,
148150
*exprs: IntoExpr | Iterable[IntoExpr],
149151
**named_exprs: IntoExpr,
150152
) -> Self:
151-
flat_exprs = list(flatten(exprs))
152-
if flat_exprs and all(isinstance(x, str) for x in flat_exprs) and not named_exprs:
153+
exprs = tuple(flatten(exprs))
154+
if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
153155
# fast path!
154156
try:
155157
return self._from_compliant_dataframe(
156-
self._compliant_frame.simple_select(*flat_exprs),
158+
self._compliant_frame.simple_select(*exprs),
157159
)
158160
except Exception as e:
159161
# Column not found is the only thing that can realistically be raised here.
160162
msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
161163
raise ColumnNotFoundError(msg) from e
162164

163-
flat_exprs, named_exprs = self._flatten_parse_col_names_into_expr_and_extract(
164-
*flat_exprs, **named_exprs
165+
compliant_exprs, compliant_named_exprs = (
166+
self._flatten_parse_col_names_into_expr_and_extract(*exprs, **named_exprs)
165167
)
166168
return self._from_compliant_dataframe(
167-
self._compliant_frame.select(*flat_exprs, **named_exprs),
169+
self._compliant_frame.select(*compliant_exprs, **compliant_named_exprs),
168170
)
169171

170172
def rename(self: Self, mapping: dict[str, str]) -> Self:

narwhals/group_by.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,11 @@ def agg(
122122
"but `df.group_by('a').agg(nw.col('b'))` is not."
123123
)
124124
raise InvalidOperationError(msg)
125-
aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs)
125+
compliant_aggs, compliant_named_aggs = self._df._flatten_and_extract(
126+
*aggs, **named_aggs
127+
)
126128
return self._df._from_compliant_dataframe( # type: ignore[return-value]
127-
self._grouped.agg(*aggs, **named_aggs),
129+
self._grouped.agg(*compliant_aggs, **compliant_named_aggs),
128130
)
129131

130132
def __iter__(self: Self) -> Iterator[tuple[Any, DataFrameT]]:
@@ -218,7 +220,9 @@ def agg(
218220
"but `df.group_by('a').agg(nw.col('b'))` is not."
219221
)
220222
raise InvalidOperationError(msg)
221-
aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs)
223+
compliant_aggs, compliant_named_aggs = self._df._flatten_and_extract(
224+
*aggs, **named_aggs
225+
)
222226
return self._df._from_compliant_dataframe( # type: ignore[return-value]
223-
self._grouped.agg(*aggs, **named_aggs),
227+
self._grouped.agg(*compliant_aggs, **compliant_named_aggs),
224228
)

narwhals/typing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ def __mod__(self, other: Any) -> Self: ...
9595
def __pow__(self, other: Any) -> Self: ...
9696

9797

98+
IntoCompliantExpr: TypeAlias = (
99+
CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co
100+
)
101+
102+
98103
class CompliantNamespace(Protocol, Generic[CompliantSeriesT_co]):
99104
def col(self, *column_names: str) -> CompliantExpr[CompliantSeriesT_co]: ...
100105
def lit(

tests/expr_and_series/concat_str_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_concat_str(
3838
nw.concat_str(
3939
[
4040
nw.col("a") * 2,
41-
nw.col("b"),
41+
"b",
4242
nw.col("c"),
4343
],
4444
separator=" ",

0 commit comments

Comments
 (0)