Skip to content

Commit

Permalink
add ascending sort (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
tato-g authored Oct 2, 2024
1 parent 189c87b commit f23b70b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
35 changes: 26 additions & 9 deletions subframe/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def select(
input=self.plan.input,
common=stalg.RelCommon(
emit=stalg.RelCommon.Emit(
output_mapping=[next(mapping_counter)
for _ in combined_exprs]
output_mapping=[next(mapping_counter) for _ in combined_exprs]
)
),
expressions=[c.expression for c in combined_exprs],
Expand Down Expand Up @@ -158,8 +157,7 @@ def aggregate(self, metrics: list[AggregateValue], by: list[Value | str]):
input=self.plan.input,
groupings=[
stalg.AggregateRel.Grouping(
grouping_expressions=[
val.expression for val in combined_exprs]
grouping_expressions=[val.expression for val in combined_exprs]
)
],
measures=[
Expand All @@ -169,8 +167,7 @@ def aggregate(self, metrics: list[AggregateValue], by: list[Value | str]):
)
)

names = [c._name for c in combined_exprs] + \
[expr.name for expr in metrics]
names = [c._name for c in combined_exprs] + [expr.name for expr in metrics]

schema = [c.data_type for c in combined_exprs] + [
expr.data_type for expr in metrics
Expand Down Expand Up @@ -199,7 +196,7 @@ def limit(self, n: int | None, offset: int):
)

def union(self, table: "Table", *rest: "Table", distinct: bool = True):
tables = [table] + rest
tables = [table] + list(rest)
rel = stalg.Rel(
set=stalg.SetRel(
inputs=[self.plan.input] + [t.plan.input for t in tables],
Expand All @@ -218,7 +215,7 @@ def union(self, table: "Table", *rest: "Table", distinct: bool = True):
)

def intersect(self, table: "Table", *rest: "Table", distinct: bool = True):
tables = [table] + rest
tables = [table] + list(rest)
rel = stalg.Rel(
set=stalg.SetRel(
inputs=[self.plan.input] + [t.plan.input for t in tables],
Expand All @@ -237,7 +234,7 @@ def intersect(self, table: "Table", *rest: "Table", distinct: bool = True):
)

def difference(self, table: "Table", *rest: "Table", distinct: bool = True):
tables = [table] + rest
tables = [table] + list(rest)
rel = stalg.Rel(
set=stalg.SetRel(
inputs=[self.plan.input] + [t.plan.input for t in tables],
Expand All @@ -255,6 +252,26 @@ def difference(self, table: "Table", *rest: "Table", distinct: bool = True):
extensions=self._merged_extensions(tables),
)

def order_by(self, *by: str):
    """Return a new Table whose plan sorts this table by the named columns.

    Each name in ``by`` is resolved through ``self[name]`` and becomes one
    sort field, in the order given. Every field uses the
    ``SORT_DIRECTION_ASC_NULLS_LAST`` direction. The output names, struct
    and extensions are carried over from this table unchanged.
    """
    ascending = stalg.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST

    sort_rel = stalg.SortRel(
        input=self.plan.input,
        sorts=[
            stalg.SortField(expr=self[column].expression, direction=ascending)
            for column in by
        ],
    )
    sorted_root = stalg.RelRoot(
        input=stalg.Rel(sort=sort_rel),
        names=self.plan.names,
    )

    return Table(
        plan=sorted_root,
        struct=self.struct,
        extensions=self.extensions,
    )

def as_scalar(self):
expression = stalg.Expression(
subquery=stalg.Expression.Subquery(
Expand Down
32 changes: 26 additions & 6 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
("order_total", "float", [10.0, 32.3, 32.0, 140.0]),
]

stores_raw = [("store_id", "int64", [1, 2, 3]),
("city", "string", ["NY", "LA", "NY"])]
stores_raw = [("store_id", "int64", [1, 2, 3]), ("city", "string", ["NY", "LA", "NY"])]

customers_raw = [
("customer_id", "int64", [10, 11, 13]),
Expand Down Expand Up @@ -426,14 +425,36 @@ def transform(module):
run_parity_test(request.getfixturevalue(consumer), ibis_expr, sf_expr)


@pytest.mark.parametrize(
    "consumer",
    [
        "acero_consumer",
        "datafusion_consumer",
        pytest.param(
            "duckdb_consumer",
            # The expected exception type must be passed as ``raises=``:
            # the first positional argument of ``xfail`` is ``condition``,
            # so a bare exception class there is just an always-true
            # condition and the exception type is never checked.
            marks=[pytest.mark.xfail(raises=Exception, reason="Unimplemented")],
        ),
    ],
)
def test_order_by(consumer, request):
    """Check that ``order_by`` on a single column matches between ibis and subframe.

    The same ``transform`` is applied to both modules and the resulting
    expressions are handed to ``run_parity_test`` together with the
    consumer fixture selected by the ``consumer`` parameter.
    """

    def transform(module):
        table = _orders(module)
        return table.order_by("fk_customer_id")

    ibis_expr = transform(ibis)
    sf_expr = transform(subframe)

    run_parity_test(request.getfixturevalue(consumer), ibis_expr, sf_expr)


@pytest.mark.parametrize(
"consumer",
[
pytest.param(
"acero_consumer",
marks=[
pytest.mark.xfail(pa.ArrowNotImplementedError,
reason="Unimplemented")
pytest.mark.xfail(pa.ArrowNotImplementedError, reason="Unimplemented")
],
),
"datafusion_consumer",
Expand All @@ -451,8 +472,7 @@ def transform(module):

return orders.select(
orders["fk_store_id"],
stores.aggregate(
by=[], metrics=[stores["store_id"].max()]).as_scalar(),
stores.aggregate(by=[], metrics=[stores["store_id"].max()]).as_scalar(),
)

ibis_expr = transform(ibis)
Expand Down

0 comments on commit f23b70b

Please sign in to comment.