Skip to content

Commit d5619d8

Browse files
committed
simplify dask imports
1 parent fa0d34b commit d5619d8

File tree

4 files changed

+15
-34
lines changed

4 files changed

+15
-34
lines changed

narwhals/_dask/dataframe.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import Literal
88
from typing import Sequence
99

10+
import dask.dataframe as dd
11+
1012
from narwhals._dask.utils import add_row_index
1113
from narwhals._dask.utils import parse_exprs_and_named_exprs
1214
from narwhals._pandas_like.utils import native_to_narwhals_dtype
@@ -23,7 +25,6 @@
2325
if TYPE_CHECKING:
2426
from types import ModuleType
2527

26-
import dask.dataframe as dd
2728
from typing_extensions import Self
2829

2930
from narwhals._dask.expr import DaskExpr
@@ -111,8 +112,6 @@ def select(
111112
*exprs: IntoDaskExpr,
112113
**named_exprs: IntoDaskExpr,
113114
) -> Self:
114-
import dask.dataframe as dd
115-
116115
if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
117116
# This is a simple slice => fastpath!
118117
return self._from_native_frame(

narwhals/_dask/expr_str.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44

5+
import dask.dataframe as dd
6+
57
if TYPE_CHECKING:
68
from typing_extensions import Self
79

@@ -93,8 +95,6 @@ def slice(self, offset: int, length: int | None = None) -> DaskExpr:
9395
)
9496

9597
def to_datetime(self: Self, format: str | None) -> DaskExpr: # noqa: A002
96-
import dask.dataframe as dd
97-
9898
return self._compliant_expr._from_call(
9999
lambda _input, format: dd.to_datetime(_input, format=format),
100100
"to_datetime",

narwhals/_dask/group_by.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,20 @@
66
from typing import Callable
77
from typing import Sequence
88

9+
import dask.dataframe as dd
10+
11+
try:
12+
import dask.dataframe.dask_expr as dx
13+
except ModuleNotFoundError: # pragma: no cover
14+
import dask_expr as dx
15+
16+
917
from narwhals._expression_parsing import is_simple_aggregation
1018
from narwhals._expression_parsing import parse_into_exprs
1119
from narwhals.exceptions import AnonymousExprError
1220
from narwhals.utils import remove_prefix
1321

1422
if TYPE_CHECKING:
15-
import dask.dataframe as dd
16-
17-
try:
18-
import dask.dataframe.dask_expr as dx
19-
except ModuleNotFoundError:
20-
import dask_expr as dx
21-
2223
import pandas as pd
2324

2425
from narwhals._dask.dataframe import DaskLazyFrame
@@ -27,8 +28,6 @@
2728

2829

2930
def n_unique() -> dd.Aggregation:
30-
import dask.dataframe as dd
31-
3231
def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int:
3332
return s.nunique(dropna=False) # type: ignore[no-any-return]
3433

@@ -49,11 +48,6 @@ def var(
4948
]:
5049
from functools import partial
5150

52-
try:
53-
import dask.dataframe.dask_expr as dx
54-
except ModuleNotFoundError: # pragma: no cover
55-
import dask_expr as dx
56-
5751
return partial(dx._groupby.GroupBy.var, ddof=ddof)
5852

5953

@@ -64,11 +58,6 @@ def std(
6458
]:
6559
from functools import partial
6660

67-
try:
68-
import dask.dataframe.dask_expr as dx
69-
except ModuleNotFoundError: # pragma: no cover
70-
import dask_expr as dx
71-
7261
return partial(dx._groupby.GroupBy.std, ddof=ddof)
7362

7463

narwhals/_dask/namespace.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from typing import Sequence
99
from typing import cast
1010

11+
import dask.dataframe as dd
12+
import pandas as pd
13+
1114
from narwhals._dask.dataframe import DaskLazyFrame
1215
from narwhals._dask.expr import DaskExpr
1316
from narwhals._dask.selectors import DaskSelectorNamespace
@@ -69,9 +72,6 @@ def nth(self, *column_indices: int) -> DaskExpr:
6972
)
7073

7174
def lit(self, value: Any, dtype: DType | None) -> DaskExpr:
72-
import dask.dataframe as dd
73-
import pandas as pd
74-
7575
def func(df: DaskLazyFrame) -> list[dx.Series]:
7676
return [
7777
dd.from_pandas(
@@ -99,7 +99,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
9999
)
100100

101101
def len(self) -> DaskExpr:
102-
import dask.dataframe as dd
103102
import pandas as pd
104103

105104
def func(df: DaskLazyFrame) -> list[dx.Series]:
@@ -188,8 +187,6 @@ def concat(
188187
*,
189188
how: Literal["horizontal", "vertical", "diagonal"],
190189
) -> DaskLazyFrame:
191-
import dask.dataframe as dd
192-
193190
if len(list(items)) == 0:
194191
msg = "No items to concatenate" # pragma: no cover
195192
raise AssertionError(msg)
@@ -265,8 +262,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
265262
)
266263

267264
def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
268-
import dask.dataframe as dd
269-
270265
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
271266

272267
def func(df: DaskLazyFrame) -> list[dx.Series]:
@@ -287,8 +282,6 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
287282
)
288283

289284
def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
290-
import dask.dataframe as dd
291-
292285
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
293286

294287
def func(df: DaskLazyFrame) -> list[dx.Series]:

0 commit comments

Comments
 (0)