From f23b70b80a717dda1f64ee9c0bc315f2a194fea2 Mon Sep 17 00:00:00 2001 From: Tato Gurgenidze <75217126+tato-g@users.noreply.github.com> Date: Wed, 2 Oct 2024 17:26:02 +0400 Subject: [PATCH] add ascending sort (#8) --- subframe/table.py | 35 ++++++++++++++++++++++++++--------- tests/test_execution.py | 32 ++++++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/subframe/table.py b/subframe/table.py index fb529d7..6d34e46 100644 --- a/subframe/table.py +++ b/subframe/table.py @@ -93,8 +93,7 @@ def select( input=self.plan.input, common=stalg.RelCommon( emit=stalg.RelCommon.Emit( - output_mapping=[next(mapping_counter) - for _ in combined_exprs] + output_mapping=[next(mapping_counter) for _ in combined_exprs] ) ), expressions=[c.expression for c in combined_exprs], @@ -158,8 +157,7 @@ def aggregate(self, metrics: list[AggregateValue], by: list[Value | str]): input=self.plan.input, groupings=[ stalg.AggregateRel.Grouping( - grouping_expressions=[ - val.expression for val in combined_exprs] + grouping_expressions=[val.expression for val in combined_exprs] ) ], measures=[ @@ -169,8 +167,7 @@ def aggregate(self, metrics: list[AggregateValue], by: list[Value | str]): ) ) - names = [c._name for c in combined_exprs] + \ - [expr.name for expr in metrics] + names = [c._name for c in combined_exprs] + [expr.name for expr in metrics] schema = [c.data_type for c in combined_exprs] + [ expr.data_type for expr in metrics @@ -199,7 +196,7 @@ def limit(self, n: int | None, offset: int): ) def union(self, table: "Table", *rest: "Table", distinct: bool = True): - tables = [table] + rest + tables = [table] + list(rest) rel = stalg.Rel( set=stalg.SetRel( inputs=[self.plan.input] + [t.plan.input for t in tables], @@ -218,7 +215,7 @@ def union(self, table: "Table", *rest: "Table", distinct: bool = True): ) def intersect(self, table: "Table", *rest: "Table", distinct: bool = True): - tables = [table] + rest + tables = [table] + list(rest) rel = stalg.Rel( set=stalg.SetRel( inputs=[self.plan.input] + [t.plan.input for t in tables], @@ -237,7 +234,7 @@ def intersect(self, table: "Table", *rest: "Table", distinct: bool = True): ) def difference(self, table: "Table", *rest: "Table", distinct: bool = True): - tables = [table] + rest + tables = [table] + list(rest) rel = stalg.Rel( set=stalg.SetRel( inputs=[self.plan.input] + [t.plan.input for t in tables], @@ -255,6 +252,26 @@ def difference(self, table: "Table", *rest: "Table", distinct: bool = True): extensions=self._merged_extensions(tables), ) + def order_by(self, *by: str): + rel = stalg.Rel( + sort=stalg.SortRel( + input=self.plan.input, + sorts=[ + stalg.SortField( + expr=self[e].expression, + direction=stalg.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST, + ) + for e in by + ], + ) + ) + + return Table( + plan=stalg.RelRoot(input=rel, names=self.plan.names), + struct=self.struct, + extensions=self.extensions, + ) + def as_scalar(self): expression = stalg.Expression( subquery=stalg.Expression.Subquery( diff --git a/tests/test_execution.py b/tests/test_execution.py index f455444..f8ec002 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -20,8 +20,7 @@ ("order_total", "float", [10.0, 32.3, 32.0, 140.0]), ] -stores_raw = [("store_id", "int64", [1, 2, 3]), - ("city", "string", ["NY", "LA", "NY"])] +stores_raw = [("store_id", "int64", [1, 2, 3]), ("city", "string", ["NY", "LA", "NY"])] customers_raw = [ ("customer_id", "int64", [10, 11, 13]), @@ -426,14 +425,36 @@ def transform(module): run_parity_test(request.getfixturevalue(consumer), ibis_expr, sf_expr) +@pytest.mark.parametrize( + "consumer", + [ + "acero_consumer", + "datafusion_consumer", + pytest.param( + "duckdb_consumer", + marks=[pytest.mark.xfail(Exception, reason="Unimplemented")], + ), + ], +) +def test_order_by(consumer, request): + + def transform(module): + table = _orders(module) + return table.order_by("fk_customer_id") + + ibis_expr = transform(ibis) + sf_expr = transform(subframe) + + run_parity_test(request.getfixturevalue(consumer), ibis_expr, sf_expr) + + @pytest.mark.parametrize( "consumer", [ pytest.param( "acero_consumer", marks=[ - pytest.mark.xfail(pa.ArrowNotImplementedError, - reason="Unimplemented") + pytest.mark.xfail(pa.ArrowNotImplementedError, reason="Unimplemented") ], ), "datafusion_consumer", @@ -451,8 +472,7 @@ def transform(module): return orders.select( orders["fk_store_id"], - stores.aggregate( - by=[], metrics=[stores["store_id"].max()]).as_scalar(), + stores.aggregate(by=[], metrics=[stores["store_id"].max()]).as_scalar(), ) ibis_expr = transform(ibis)