Skip to content

Commit 1adc576

Browse files
authored
Add more supported type annotations, fix spark connect issue (#542)
* Add more supported type annotations, fix spark connect issue * update * update * update * update * update * update * update
1 parent 48b7ab6 commit 1adc576

File tree

13 files changed

+95
-32
lines changed

13 files changed

+95
-32
lines changed

.devcontainer/devcontainer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
],
3939
"postCreateCommand": "make devenv",
4040
"features": {
41-
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
41+
"ghcr.io/devcontainers/features/docker-in-docker:2.11.0": {},
4242
"ghcr.io/devcontainers/features/java:1": {
4343
"version": "11"
4444
}

.github/workflows/test_all.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
runs-on: ubuntu-latest
2626
strategy:
2727
matrix:
28-
python-version: [3.8, "3.10"] # TODO: add back 3.11 when dask-sql is compatible
28+
python-version: [3.8, "3.10", "3.11"]
2929

3030
steps:
3131
- uses: actions/checkout@v2

RELEASE.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Release Notes
22

3+
## 0.9.1
4+
5+
- [543](https://github.com/fugue-project/fugue/issues/543) Support type hinting with standard collections
6+
- [544](https://github.com/fugue-project/fugue/issues/544) Fix Spark connect import issue on worker side
7+
38
## 0.9.0
49

510
- [482](https://github.com/fugue-project/fugue/issues/482) Move Fugue SQL dependencies into extra `[sql]` and functions to become soft dependencies

fugue/dataframe/function_wrapper.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
PositionalParam,
2121
function_wrapper,
2222
)
23+
from triad.utils.convert import compare_annotations
2324
from triad.utils.iter import EmptyAwareIterable, make_empty_aware
2425

2526
from ..constants import FUGUE_ENTRYPOINT
@@ -37,6 +38,14 @@
3738
from .pandas_dataframe import PandasDataFrame
3839

3940

41+
def _compare_iter(tp: Any) -> Any:
42+
return lambda x: compare_annotations(
43+
x, Iterable[tp] # type:ignore
44+
) or compare_annotations(
45+
x, Iterator[tp] # type:ignore
46+
)
47+
48+
4049
@function_wrapper(FUGUE_ENTRYPOINT)
4150
class DataFrameFunctionWrapper(FunctionWrapper):
4251
@property
@@ -228,10 +237,7 @@ def count(self, df: List[List[Any]]) -> int:
228237
return len(df)
229238

230239

231-
@fugue_annotated_param(
232-
Iterable[List[Any]],
233-
matcher=lambda x: x == Iterable[List[Any]] or x == Iterator[List[Any]],
234-
)
240+
@fugue_annotated_param(Iterable[List[Any]], matcher=_compare_iter(List[Any]))
235241
class _IterableListParam(_LocalNoSchemaDataFrameParam):
236242
@no_type_check
237243
def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[List[Any]]:
@@ -288,10 +294,7 @@ def count(self, df: List[Dict[str, Any]]) -> int:
288294
return len(df)
289295

290296

291-
@fugue_annotated_param(
292-
Iterable[Dict[str, Any]],
293-
matcher=lambda x: x == Iterable[Dict[str, Any]] or x == Iterator[Dict[str, Any]],
294-
)
297+
@fugue_annotated_param(Iterable[Dict[str, Any]], matcher=_compare_iter(Dict[str, Any]))
295298
class _IterableDictParam(_LocalNoSchemaDataFrameParam):
296299
@no_type_check
297300
def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[Dict[str, Any]]:
@@ -360,10 +363,7 @@ def format_hint(self) -> Optional[str]:
360363
return "pandas"
361364

362365

363-
@fugue_annotated_param(
364-
Iterable[pd.DataFrame],
365-
matcher=lambda x: x == Iterable[pd.DataFrame] or x == Iterator[pd.DataFrame],
366-
)
366+
@fugue_annotated_param(Iterable[pd.DataFrame], matcher=_compare_iter(pd.DataFrame))
367367
class _IterablePandasParam(LocalDataFrameParam):
368368
@no_type_check
369369
def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pd.DataFrame]:
@@ -419,10 +419,7 @@ def format_hint(self) -> Optional[str]:
419419
return "pyarrow"
420420

421421

422-
@fugue_annotated_param(
423-
Iterable[pa.Table],
424-
matcher=lambda x: x == Iterable[pa.Table] or x == Iterator[pa.Table],
425-
)
422+
@fugue_annotated_param(Iterable[pa.Table], matcher=_compare_iter(pa.Table))
426423
class _IterableArrowParam(LocalDataFrameParam):
427424
@no_type_check
428425
def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pa.Table]:

fugue_spark/_utils/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
try:
44
from pyspark.sql.connect.session import SparkSession as SparkConnectSession
55
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
6-
except ImportError: # pragma: no cover
6+
except Exception: # pragma: no cover
77
SparkConnectSession = None
88
SparkConnectDataFrame = None
99
import pyspark.sql as ps

fugue_version/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.9.0"
1+
__version__ = "0.9.1"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def get_version() -> str:
3838
keywords="distributed spark dask ray sql dsl domain specific language",
3939
url="http://github.com/fugue-project/fugue",
4040
install_requires=[
41-
"triad>=0.9.6",
41+
"triad>=0.9.7",
4242
"adagio>=0.2.4",
4343
],
4444
extras_require={

tests/fugue/dataframe/test_function_wrapper.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from __future__ import annotations
2+
13
import copy
4+
import sys
25
from typing import Any, Dict, Iterable, Iterator, List
36

47
import pandas as pd
@@ -29,7 +32,10 @@
2932

3033

3134
def test_function_wrapper():
32-
for f in [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]:
35+
fs = [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]
36+
if sys.version_info >= (3, 9):
37+
fs.append(f33)
38+
for f in fs:
3339
df = ArrayDataFrame([[0]], "a:int")
3440
w = DataFrameFunctionWrapper(f, "^[ldsp][ldsp]$", "[ldspq]")
3541
res = w.run([df], dict(a=df), ignore_unknown=False, output_schema="a:int")
@@ -372,6 +378,14 @@ def f32(
372378
return ArrayDataFrame(arr, "a:int").as_dict_iterable()
373379

374380

381+
def f33(
382+
e: list[dict[str, Any]], a: Iterable[dict[str, Any]]
383+
) -> EmptyAwareIterable[Dict[str, Any]]:
384+
e += list(a)
385+
arr = [[x["a"]] for x in e]
386+
return ArrayDataFrame(arr, "a:int").as_dict_iterable()
387+
388+
375389
def f35(e: pd.DataFrame, a: LocalDataFrame) -> Iterable[pd.DataFrame]:
376390
e = PandasDataFrame(e, "a:int").as_pandas()
377391
a = ArrayDataFrame(a, "a:int").as_pandas()

tests/fugue_dask/test_execution_engine.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from fugue_dask.execution_engine import DaskExecutionEngine
2525
from fugue_test.builtin_suite import BuiltInTests
2626
from fugue_test.execution_suite import ExecutionEngineTests
27+
from fugue.column import col, all_cols
28+
import fugue.column.functions as ff
2729

2830
_CONF = {
2931
"fugue.rpc.server": "fugue.rpc.flask.FlaskRPCServer",
@@ -50,6 +52,46 @@ def test_get_parallelism(self):
5052
def test__join_outer_pandas_incompatible(self):
5153
return
5254

55+
# TODO: dask-sql 2024.5.0 has a bug, can't pass the HAVING tests
56+
def test_select(self):
57+
try:
58+
import qpd
59+
import dask_sql
60+
except ImportError:
61+
return
62+
63+
a = ArrayDataFrame(
64+
[[1, 2], [None, 2], [None, 1], [3, 4], [None, 4]], "a:double,b:int"
65+
)
66+
67+
# simple
68+
b = fa.select(a, col("b"), (col("b") + 1).alias("c").cast(str))
69+
self.df_eq(
70+
b,
71+
[[2, "3"], [2, "3"], [1, "2"], [4, "5"], [4, "5"]],
72+
"b:int,c:str",
73+
throw=True,
74+
)
75+
76+
# with distinct
77+
b = fa.select(
78+
a, col("b"), (col("b") + 1).alias("c").cast(str), distinct=True
79+
)
80+
self.df_eq(
81+
b,
82+
[[2, "3"], [1, "2"], [4, "5"]],
83+
"b:int,c:str",
84+
throw=True,
85+
)
86+
87+
# wildcard
88+
b = fa.select(a, all_cols(), where=col("a") + col("b") == 3)
89+
self.df_eq(b, [[1, 2]], "a:double,b:int", throw=True)
90+
91+
# aggregation
92+
b = fa.select(a, col("a"), ff.sum(col("b")).cast(float).alias("b"))
93+
self.df_eq(b, [[1, 2], [3, 4], [None, 7]], "a:double,b:double", throw=True)
94+
5395
def test_to_df(self):
5496
e = self.engine
5597
a = e.to_df([[1, 2], [3, 4]], "a:int,b:int")

tests/fugue_duckdb/test_dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class DuckDataFrameTests(DataFrameTests.Tests):
1616
def df(self, data: Any = None, schema: Any = None) -> DuckDataFrame:
1717
df = ArrowDataFrame(data, schema)
18-
return DuckDataFrame(duckdb.from_arrow(df.native, self.context.session))
18+
return DuckDataFrame(duckdb.from_arrow(df.native, connection=self.context.session))
1919

2020
def test_as_array_special_values(self):
2121
for func in [
@@ -69,7 +69,7 @@ def test_duck_as_local(self):
6969
class NativeDuckDataFrameTests(DataFrameTests.NativeTests):
7070
def df(self, data: Any = None, schema: Any = None) -> DuckDataFrame:
7171
df = ArrowDataFrame(data, schema)
72-
return DuckDataFrame(duckdb.from_arrow(df.native, self.context.session)).native
72+
return DuckDataFrame(duckdb.from_arrow(df.native, connection=self.context.session)).native
7373

7474
def to_native_df(self, pdf: pd.DataFrame) -> Any:
7575
return duckdb.from_df(pdf)

tests/fugue_duckdb/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_type_conversion(backend_context):
4545

4646
def assert_(tp):
4747
dt = duckdb.from_arrow(
48-
pa.Table.from_pydict(dict(a=pa.nulls(2, tp))), con
48+
pa.Table.from_pydict(dict(a=pa.nulls(2, tp))), connection=con
4949
).types[0]
5050
assert to_pa_type(dt) == tp
5151
dt = to_duck_type(tp)

tests/fugue_ibis/mock/execution_engine.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,17 @@ def sample(
8181
f"one and only one of n and frac should be non-negative, {n}, {frac}"
8282
),
8383
)
84-
tn = self.get_temp_table_name()
84+
idf = self.to_df(df)
85+
tn = f"({idf.native.compile()})"
86+
if seed is not None:
87+
_seed = f",{seed}"
88+
else:
89+
_seed = ""
8590
if frac is not None:
86-
sql = f"SELECT * FROM {tn} USING SAMPLE bernoulli({frac*100} PERCENT)"
91+
sql = f"SELECT * FROM {tn} USING SAMPLE {frac*100}% (bernoulli{_seed})"
8792
else:
88-
sql = f"SELECT * FROM {tn} USING SAMPLE reservoir({n} ROWS)"
89-
if seed is not None:
90-
sql += f" REPEATABLE ({seed})"
91-
idf = self.to_df(df)
92-
_res = f"WITH {tn} AS ({idf.native.compile()}) " + sql
93+
sql = f"SELECT * FROM {tn} USING SAMPLE {n} ROWS (reservoir{_seed})"
94+
_res = f"SELECT * FROM ({sql})" # ibis has a bug to inject LIMIT
9395
return self.to_df(self.backend.sql(_res))
9496

9597
def _register_df(

tests/fugue_ibis/test_execution_engine.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def test_properties(self):
2323
assert not self.engine.map_engine.is_distributed
2424
assert not self.engine.sql_engine.is_distributed
2525

26+
assert self.engine.sql_engine.get_temp_table_name(
27+
) != self.engine.sql_engine.get_temp_table_name()
28+
2629
def test_select(self):
2730
# it can't work properly with DuckDB (hugeint is not recognized)
2831
pass

0 commit comments

Comments
 (0)